saex: training SAEs in Jax
A library for training Sparse Autoencoders in Jax
I worked on this project from April to August 2024 as part of MATS.
Sparse autoencoders are a recent method for decomposing neural networks into more human-understandable components - see this blog post for a more in-depth explanation. The main idea is to train a neural network to reconstruct its input, but with a sparsity constraint on the hidden layer activations. This forces the network to learn a sparse representation of the input data, and we often see that individual hidden latents correspond to specific features in the input data.
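The idea above can be sketched in a few lines of JAX - a minimal ReLU SAE with a reconstruction loss plus an L1 sparsity penalty. All names and shapes here are illustrative, not saex's actual API.

```python
# Minimal sketch of a ReLU sparse autoencoder: reconstruct the input
# through a wider, sparsity-penalized latent layer.
import jax
import jax.numpy as jnp

def sae_loss(params, x, l1_coeff=1e-3):
    # Encode: affine map into a wider latent space, then ReLU.
    f = jax.nn.relu(x @ params["W_enc"] + params["b_enc"])
    # Decode: affine map back to the activation space.
    x_hat = f @ params["W_dec"] + params["b_dec"]
    mse = jnp.mean((x - x_hat) ** 2)
    sparsity = jnp.mean(jnp.sum(jnp.abs(f), axis=-1))  # L1 penalty
    return mse + l1_coeff * sparsity

d_model, n_features = 16, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_enc": jax.random.normal(k1, (d_model, n_features)) * 0.1,
    "b_enc": jnp.zeros(n_features),
    "W_dec": jax.random.normal(k2, (n_features, d_model)) * 0.1,
    "b_dec": jnp.zeros(d_model),
}
x = jax.random.normal(k3, (8, d_model))
loss, grads = jax.value_and_grad(sae_loss)(params, x)
```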
saex is an implementation of sparse autoencoders in Jax intended to run on TPUs, like DeepMind's SAE trainer. Its reference implementation is SAE Lens (which was called mats-sae in early 2024). It can train competitive ReLU and Gated SAEs on models up to Phi-3 scale. SAEs trained with it featured in LessWrong posts from our MATS team.
Basics
SAEs are trained offline: there is a model from which we extract activations, and the sparse autoencoder is trained to reconstruct them with a pure reconstruction loss like MSE. When we do a forward pass through the base model with a batch of tokens, we can choose either to store the activations in a buffer or train on them immediately.
saex does the former: it uses an HBM ring buffer to save and periodically recycle activations. saex's caching strategy is to first store activations for a set number of tokens, then train on randomly sampled activations from the buffer. Depending on the configuration, SAE training can be faster or slower than caching.
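A ring buffer like the one described above can be sketched as a fixed-size device array with a moving write offset; this is an illustrative simplification (a real write near the end of the buffer would need explicit wrap-around handling, since `dynamic_update_slice` clamps out-of-range start indices rather than wrapping).

```python
# Sketch of a fixed-size activation ring buffer kept in device memory:
# write batches at a moving offset, then train on random samples.
import jax
import jax.numpy as jnp

def store(buffer, ptr, batch):
    # Overwrite the oldest activations at the current write offset.
    n = batch.shape[0]
    buffer = jax.lax.dynamic_update_slice(buffer, batch, (ptr, 0))
    return buffer, (ptr + n) % buffer.shape[0]

def sample(buffer, key, batch_size):
    # Draw a random training batch from anywhere in the buffer.
    idx = jax.random.randint(key, (batch_size,), 0, buffer.shape[0])
    return buffer[idx]

buffer = jnp.zeros((1024, 8))  # capacity x d_model
buffer, ptr = store(buffer, 0, jnp.ones((256, 8)))
batch = sample(buffer, jax.random.PRNGKey(0), 32)
```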
Training and caching
The trainer calls several JIT-compiled functions: caching a batch of activations, storing it in the buffer, retrieving them, and training on a random sample.
I started with the Adam optimizer as in SAE Lens, but switched to Adadelta (Adam without momentum). The reason for this was the finding that SAEs converge better without momentum. I found that the number of dead features was lower with this modification.
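With the first-moment term disabled (beta1 = 0), Adam's update reduces to an RMS-normalized gradient step. A minimal sketch of that update rule, with illustrative hyperparameters:

```python
# Sketch of Adam with the momentum term disabled (beta1 = 0): the first
# moment is just the current gradient, so only the second-moment (RMS)
# normalization remains.
import jax.numpy as jnp

def adam_no_momentum_update(g, v, t, lr=1e-3, b2=0.999, eps=1e-8):
    # Second-moment accumulator with bias correction, as in Adam.
    v = b2 * v + (1 - b2) * g**2
    v_hat = v / (1 - b2**t)
    # With b1 = 0, the (bias-corrected) first moment equals g.
    return -lr * g / (jnp.sqrt(v_hat) + eps), v

g = jnp.full(4, 0.5)
v = jnp.zeros(4)
update, v = adam_no_momentum_update(g, v, t=1)
```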
We initialize the SAE encoder weights with an orthogonal matrix and the decoder with its pseudoinverse or transpose. For transcoders, we set the decoder to a random matrix.
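A sketch of this initialization: for an orthogonal (row-orthonormal) encoder, the pseudoinverse coincides with the transpose, which is why either works. Shapes are illustrative.

```python
# Orthogonal encoder init; decoder set to the encoder's pseudoinverse.
import jax
import jax.numpy as jnp

d_model, n_features = 16, 64
key = jax.random.PRNGKey(0)
# Orthogonal init: rows of W_enc are orthonormal when d_model < n_features.
W_enc = jax.nn.initializers.orthogonal()(key, (d_model, n_features))
# For a row-orthonormal matrix, pinv(W_enc) equals W_enc.T.
W_dec = jnp.linalg.pinv(W_enc)
```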
I experimented with keeping parameters in bfloat16 during optimization. I arrived at a scheme where the weights and Gated SAE scaling factors are stored in bfloat16, but biases are in fp32. I don't understand why this works, but it retains the accuracy of fp32 training unlike every other scheme I tried. I did not try stochastic rounding, but it would have been a good idea given the stabilizing effect it has on cross-layer transcoders.
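The dtype scheme above can be expressed as a simple tree map over the parameter pytree. The rule used here - 2-D arrays are weights, 1-D arrays are biases - is an illustrative stand-in for however the real code distinguishes them.

```python
# Sketch of the mixed-precision scheme: weights in bfloat16, biases in
# float32. Parameter names and the ndim-based rule are illustrative.
import jax
import jax.numpy as jnp

params = {
    "W_enc": jnp.zeros((16, 64)),
    "b_enc": jnp.zeros(64),
    "W_dec": jnp.zeros((64, 16)),
    "b_dec": jnp.zeros(16),
}

def cast_for_training(params):
    # Matrices (weights) go to bfloat16; vectors (biases) stay in fp32.
    return jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16) if p.ndim >= 2 else p.astype(jnp.float32),
        params,
    )

params = cast_for_training(params)
```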
The model is sharded across the feature dimension. The codebase supports data and model parallelism. Because we do not explicitly select the sparse activations, we don't need the engineering of TopK SAEs. We use the same mesh for sharding SAE activations as for the base model (except for transformers models, see below). The buffer is also sharded across the data parallel dimension.
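Sharding across the feature dimension can be expressed with `jax.sharding`; here a 1-D mesh stands in for the real data/model mesh, and the layout shown is only a plausible sketch of saex's actual partitioning.

```python
# Sketch of sharding an SAE encoder weight across the feature axis.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis for model parallelism over SAE features.
mesh = Mesh(np.array(jax.devices()), ("model",))
# Encoder weight (d_model, n_features): replicate d_model, shard features.
W_enc = jnp.zeros((16, 64))
W_enc = jax.device_put(W_enc, NamedSharding(mesh, P(None, "model")))
```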
We also implemented caching of SAE activations for maximum activating example visualization. This usually consists of running the SAE on a smaller distribution, noting the token indices that activate a feature the most, and storing small (16-32 token) windows around these indices along with feature activations.
Unlike in PyTorch, we cannot use much dynamic control flow or accelerator-host communication. This means we can't keep a growing array of activation windows or store more than a fixed number of contexts in a single step. The solution we ended up with was messy but not too slow. We first select a random activating (above some threshold) feature for each token by permuting the activations of all features and scanning until the last feature that matches the activity condition. We then select a random subset of active token/feature pairs and slice activations of the same feature on the surrounding tokens with jnp.take_along_axis, setting invalid indices to 0. We quantize the activations to 8 bits and store them in a buffer together with the token indices. Finally, we save all tokens, token indices and feature activations in a pyarrow table. We compress this aggressively because the caches can get large relative to the 100GB local SSD of a v4-8 TPU.
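The window-slicing step can be sketched as follows: for each selected center token, gather a fixed-width window of one feature's activations with `jnp.take_along_axis`, clamping indices into range and zeroing the positions that fell outside. Function names and shapes are illustrative.

```python
# Sketch of fixed-shape window gathering around activating tokens.
import jax.numpy as jnp

def gather_windows(acts, centers, radius=2):
    # acts: (n_tokens,) activations of a single feature.
    # centers: (k,) token indices to take windows around.
    offsets = jnp.arange(-radius, radius + 1)      # (w,)
    idx = centers[:, None] + offsets[None, :]      # (k, w)
    valid = (idx >= 0) & (idx < acts.shape[0])
    windows = jnp.take_along_axis(
        jnp.broadcast_to(acts, (centers.shape[0], acts.shape[0])),
        jnp.clip(idx, 0, acts.shape[0] - 1),       # clamp, then mask
        axis=1,
    )
    return jnp.where(valid, windows, 0.0)          # zero invalid positions

acts = jnp.arange(10.0)
w = gather_windows(acts, jnp.array([0, 5]))
```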
Models
saex initially supported loading LLMs from the Flax version of transformers. There were, however, some issues with the library: it was slow to compile, required AST patching to extract activations from sites other than the residual stream (as would be necessary for transcoder training, which we needed to support later on), and there was no simple way to shard the model for sequence parallelism.
micrlhf
For these reasons, I started writing a Llama architecture implementation from scratch at micrlhf-progress. On its own, this would not be a worthwhile project, so I decided to make the library work with llama.cpp's GGUF format and support quantized models natively.
I wrote the library with the Penzai framework because it was recent at the time and looked promising for interpretability. It was indeed easy to log intermediate activations with it, even if all code became more verbose due to having to convert back and forth between positional and named axes.
To speed up compilation, I implemented layer stacking, a JAX-specific optimization technique that combines parameter arrays across different layers and slices them at runtime to dynamically access the correct weights. This speeds up compilation because the layer body is not recompiled for each layer, but may increase backward pass time due to inefficiency on XLA's side.
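The usual way to express layer stacking in JAX is `jax.lax.scan` over a leading layer axis, so XLA compiles the layer body once. A toy sketch with a linear layer standing in for a transformer block:

```python
# Layer stacking via scan: per-layer weights combined into one stacked
# array; the layer body is traced and compiled a single time.
import jax
import jax.numpy as jnp

n_layers, d = 4, 8
# Stacked weights, shape (n_layers, d, d); toy values for clarity.
Ws = jnp.stack([jnp.eye(d) * (i + 1) for i in range(n_layers)])

def layer(x, W):
    # One layer's forward pass; scan slices out each layer's W at runtime.
    return x @ W, None

def forward(x, Ws):
    x, _ = jax.lax.scan(layer, x, Ws)
    return x

y = forward(jnp.ones(d), Ws)  # applies 1*I, 2*I, 3*I, 4*I in sequence
```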
I implemented 8-bit matmul kernels so as to be able to matrix-multiply directly from the quantized tensor in memory without having to unquantize it. However, I did not know what I was doing and the kernels were in all likelihood slower than naively dequantizing and checkpointing the results. Still, I was able to run inference on Llama 3 70B on a v4-8 node with a loading time of 5 minutes and about 15% MFU.
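The "naive" baseline mentioned above - dequantize inside the matmul and rematerialize rather than store the dequantized copy - can be sketched like this. The int8-plus-scale layout is a simplified stand-in for GGUF's actual block quantization formats.

```python
# Dequantize-on-the-fly matmul: weights stay int8 in memory; the bf16
# copy is recomputed in the backward pass thanks to jax.checkpoint.
import jax
import jax.numpy as jnp

@jax.checkpoint
def q8_matmul(x, w_int8, scales):
    # Dequantize: w ≈ w_int8 * scales (per output column here).
    w = w_int8.astype(jnp.bfloat16) * scales
    return x @ w

w = jnp.ones((8, 4), dtype=jnp.int8)
scales = jnp.full((1, 4), 0.5, dtype=jnp.bfloat16)
x = jnp.ones((2, 8), dtype=jnp.bfloat16)
y = q8_matmul(x, w, scales)  # each entry: 8 * 1 * 0.5 = 4
```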
Since then, version 0.2 of the framework came out with support for Llama and layer stacking out of the box, making micrlhf's implementation of the architecture mostly obsolete. But we stayed with the repository for a while: during the project, my MATS partner and I had written a few dozen scripts for:
- prompt optimization
- activation steering (and refusal orthogonalization)
- many-shot jailbreaking
- mechanistically eliciting latent behavior
- sparse circuit finding, as used in our paper all the way until 2025
...and others. Fortunately, we don't have any reason to use these scripts or the library anymore.
Supported SAE types
Initially, the codebase only supported the basic ReLU SAE. As mentioned above, the implementation closely followed SAE Lens and achieved similar performance for the same number of tokens. One difference is the way saex keeps the decoder weights $W_d \in \mathbb{R}^{N \times d}$ unit-norm across the d_model dimension. The approach described in Scaling Monosemanticity is to normalize the decoder after each step and project gradients before each update so that the norm will not be affected to a first-order approximation. SAE Lens also uses this strategy. dictionary_learning, another implementation that was available at the time, instead normalizes the Adam updates. We follow their implementation in that choice and did not ablate it.
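For reference, the Scaling Monosemanticity constraint described above can be sketched in two small functions: project out the gradient component parallel to each (unit-norm) decoder row, then renormalize the rows after the step.

```python
# Unit-norm decoder constraint: gradient projection + renormalization.
import jax.numpy as jnp

def project_grad(W_dec, grad):
    # Remove the gradient component parallel to each unit-norm decoder
    # row, so the update preserves row norms to first order.
    parallel = jnp.sum(grad * W_dec, axis=-1, keepdims=True) * W_dec
    return grad - parallel

def renormalize(W_dec):
    # Restore exact unit norm across the d_model dimension.
    return W_dec / jnp.linalg.norm(W_dec, axis=-1, keepdims=True)

W_dec = renormalize(jnp.ones((64, 16)))  # (n_features, d_model)
grad = jnp.ones((64, 16))
g = project_grad(W_dec, grad)  # now orthogonal to every decoder row
```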
I added Ghost Gradients to avoid dead features, with some modifications for numerical stability: specifically, replacing the exponential activation with a variant that saturates for large inputs, so the activations don't blow up when they're already large and don't need the regularizer. When the Gated SAE paper came out, I implemented the architecture. I also added support for unnormalized decoder rows, as described in the April 2024 Circuits update. This change required a different implementation of gated SAE decoder freezing, as described in an unpublished update by Deepmind. See this section for an implementation.
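A Gated SAE encoder, as introduced in the Gated SAE paper, splits the encoder into a binary gate and a magnitude path, with the magnitude path sharing the gating weights rescaled by a per-feature factor $\exp(r_{mag})$. A sketch with illustrative parameter names (not saex's actual API):

```python
# Sketch of a Gated SAE forward pass with the standard weight sharing.
import jax
import jax.numpy as jnp

def gated_encode(params, x):
    pre = (x - params["b_dec"]) @ params["W_gate"]
    gate = (pre + params["b_gate"]) > 0                        # which features fire
    mag = jax.nn.relu(pre * jnp.exp(params["r_mag"]) + params["b_mag"])  # how much
    return gate * mag

d_model, n_features = 16, 64
key = jax.random.PRNGKey(0)
params = {
    "W_gate": jax.random.normal(key, (d_model, n_features)) * 0.1,
    "b_gate": jnp.zeros(n_features),
    "r_mag": jnp.zeros(n_features),   # per-feature log scale for magnitudes
    "b_mag": jnp.zeros(n_features),
    "b_dec": jnp.zeros(d_model),
}
f = gated_encode(params, jnp.ones((4, d_model)))
```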