fae: FLUX.1 SAE training on TPUs
An implementation of MMDiT in JAX, plus SAE training on TPUs
This is a post describing fae[1], a project I worked on in late 2024 as part of the TPU Research Cloud (TRC) program. It is a library for training sparse autoencoders (SAEs)[2], and it is the code behind "Interpreting Large Text-to-Image Diffusion Models with Dictionary Learning".
Code: http://github.com/neverix/fae.
FLUX implementation
FLUX.1 came out in August 2024. It is an MMDiT (multimodal diffusion transformer), like Stable Diffusion 3, and it was SOTA at the time. It was quickly adopted and optimized for GPUs, but over the following months no open-source TPU implementation came out.
I applied to TRC to port this model to TPUs and got access to v4-8 and v3-8 TPUs for testing. FLUX was large compared to other open models of the time, with 12 billion parameters in the diffusion transformer alone, which is why quantized versions were published. v4-8 TPUs have 32GB of HBM per chip, so with plain data parallelism (every chip holding a full copy of the weights) the 24GB of bfloat16 parameters would only just fit. For this reason, I looked at ways to reduce memory use early in the project, and settled on a combination of FSDP (fully sharded data parallelism) and 4-bit quantization. In the end, I had a reasonably efficient FLUX.1 runtime that can generate images and is suitable for SAE training.
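The arithmetic behind that decision, as a quick sketch. The 12B parameter count, 32GB per chip and 4 chips per v4-8 come from the text; the NF4 overhead assumes one bfloat16 scale per block of 64 weights, which is an illustrative assumption (bitsandbytes may store its scales differently).

```python
# Back-of-the-envelope weight memory for the FLUX.1 transformer on a v4-8.
# Assumptions: 12e9 parameters, 32 GB of HBM per chip, 4 chips,
# NF4 = 4-bit codes plus one bfloat16 scale per 64-weight block (assumed).

GB = 1e9  # decimal gigabytes, matching the "32GB" figure

N_PARAMS = 12e9
HBM_PER_CHIP = 32 * GB
N_CHIPS = 4


def weight_bytes(n_params: float, bits_per_param: float) -> float:
    """Bytes needed to store n_params weights at the given precision."""
    return n_params * bits_per_param / 8


def nf4_bits_per_param(block_size: int = 64, scale_bits: int = 16) -> float:
    """4-bit codes plus the per-block scale amortized over the block."""
    return 4 + scale_bits / block_size


def per_chip(total_bytes: float, fsdp: bool) -> float:
    """Data parallelism replicates the weights; FSDP splits them across chips."""
    return total_bytes / N_CHIPS if fsdp else total_bytes


if __name__ == "__main__":
    for name, bits in [("bfloat16", 16), ("nf4", nf4_bits_per_param())]:
        total = weight_bytes(N_PARAMS, bits)
        for fsdp in (False, True):
            used = per_chip(total, fsdp)
            print(f"{name:8s} fsdp={fsdp!s:5s} {used / GB:6.2f} GB/chip "
                  f"({(HBM_PER_CHIP - used) / GB:5.2f} GB left)")
```

Under these assumptions, bfloat16 weights take 24GB per chip without sharding (leaving 8GB for activations, text encoders and everything else) and 6GB with FSDP across 4 chips; NF4 brings the unsharded total down to about 6.4GB.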
WIP after this point.
Memory use optimization: FSDP
As mentioned above, each v4 TPU chip has 32GB of HBM. There are 4 of them on a v4-8 node, linked by a fast interconnect, meaning...
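FSDP combines data parallelism (each chip processes its own slice of the batch) with parameter sharding (each chip stores only a quarter of every weight, and the full weight is all-gathered right before it is used). In real JAX code this is usually expressed by placing arrays with jax.sharding.Mesh, NamedSharding and PartitionSpec and letting the compiler insert the collectives. The NumPy sketch below only simulates four devices with Python lists, so the "all-gather" is a plain concatenate; all function names are hypothetical.

```python
import numpy as np

N_DEVICES = 4  # chips on a v4-8


def shard(weight: np.ndarray, n: int = N_DEVICES) -> list[np.ndarray]:
    """Split a weight along its first axis; device i keeps only shard i."""
    return np.split(weight, n, axis=0)


def all_gather(shards: list[np.ndarray]) -> np.ndarray:
    """Stand-in for the collective that rebuilds the full weight before use."""
    return np.concatenate(shards, axis=0)


def fsdp_forward(x: np.ndarray, weight_shards: list[np.ndarray]) -> np.ndarray:
    """One FSDP linear layer over a simulated 4-device mesh."""
    # Data parallelism: every device gets its own slice of the batch.
    x_slices = np.split(x, N_DEVICES, axis=0)
    # Parameter sharding: the full weight exists only transiently, for this layer.
    w = all_gather(weight_shards)
    outs = [xs @ w for xs in x_slices]  # each "device" computes its own slice
    return np.concatenate(outs, axis=0)
```

At rest, each device holds 1/N_DEVICES of the weights, which is what frees up HBM; the price is one all-gather per layer, which the fast interconnect makes cheap.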
Memory use optimization: 4-bit quantization
Quantization format: NormalFloat4 (NF4), the 4-bit format used in bitsandbytes.
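To make the format concrete, here is a minimal NumPy sketch of blockwise NF4 quantization, assuming absmax scaling over blocks of 64 weights. The 16 levels are the NF4 table as listed in bitsandbytes; the helper names, the uint8 storage of codes and the nibble packing order are illustrative choices, not fae's or bitsandbytes' actual layout.

```python
import numpy as np

# The 16 NF4 levels (normal-distribution quantiles rescaled to [-1, 1]).
NF4_LEVELS = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=np.float32)


def quantize_nf4(w, block_size=64):
    """Blockwise absmax NF4. w.size must be divisible by block_size."""
    flat = w.astype(np.float32).reshape(-1, block_size)
    scales = np.abs(flat).max(axis=1, keepdims=True)
    scales = np.where(scales == 0, 1.0, scales).astype(np.float32)  # avoid 0/0
    normed = flat / scales  # every block now lies in [-1, 1]
    # Index of the nearest NF4 level for every element (values 0..15).
    codes = np.abs(normed[..., None] - NF4_LEVELS).argmin(axis=-1).astype(np.uint8)
    return codes, scales


def dequantize_nf4(codes, scales, shape):
    """Look up each code's level and rescale by its block's absmax."""
    return (NF4_LEVELS[codes] * scales).reshape(shape)


def pack_nibbles(codes):
    """Store two 4-bit codes per byte (needs an even number of codes)."""
    flat = codes.reshape(-1)
    return ((flat[0::2] << 4) | flat[1::2]).astype(np.uint8)


def unpack_nibbles(packed):
    """Inverse of pack_nibbles: high nibble first, then low nibble."""
    return np.stack([packed >> 4, packed & 0x0F], axis=-1).reshape(-1)
```

Because the levels are denser near zero, NF4 spends more of its 16 values where normally distributed weights actually concentrate than a uniform 4-bit grid would.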
I wrote a kernel for it, but v4-8s impose a lot of restrictions, and I ran into open compiler bugs.
Eagerly dequantizing the weights turned out to be just as fast on v4-8. There may be a way to run faster by using int8 quantization with separate scales for individual rows and columns, but only on v5-8s.
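Why per-row and per-column scales fit an int8 matmul: if every row of the activations x has its own scale and every column of the weight w has its own scale, both scales are constant along the summed dimension, so they can be pulled out of the integer dot product and applied once to the int32 result. A minimal NumPy sketch of that idea (helper names hypothetical; whether it is actually faster on a given TPU generation is the claim above, not something this sketch shows):

```python
import numpy as np


def quantize_rows_int8(a):
    """Symmetric int8 quantization with one scale per row."""
    scale = np.abs(a).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.round(a / scale), -127, 127).astype(np.int8)
    return q, scale


def int8_matmul(x, w):
    """x: (batch, d_in), scaled per row; w: (d_in, d_out), scaled per column."""
    xq, sx = quantize_rows_int8(x)        # sx: (batch, 1)
    wq_t, sw_t = quantize_rows_int8(w.T)  # rows of w.T are the columns of w
    # Integer matmul with int32 accumulation, as int8 hardware would do it.
    acc = xq.astype(np.int32) @ wq_t.T.astype(np.int32)  # (batch, d_out)
    # Each scale is constant along d_in, so it factors out of the sum.
    return acc * sx * sw_t.T
```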
Layer slicing
Besides memory use, the other big hurdle is compilation time: JAX can take minutes to compile large models, and compile time scales linearly with the number of operations, which for a neural network built from repeated blocks means the number of layers.
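One common way to stop compile time from growing with depth (not necessarily exactly what fae's layer slicing does) is to stack the weights of identical blocks along a leading layer axis and apply a single block function to slices of that stack. With jax.jit, that block is then traced and compiled once and reused for every layer, and jax.lax.scan can express the same loop inside one compiled program. The NumPy sketch below only shows the bookkeeping; `block` is a hypothetical stand-in for an MMDiT block.

```python
import numpy as np


def block(h, w):
    """One residual block (hypothetical stand-in for a transformer block)."""
    return h + np.tanh(h @ w)


def stack_layers(per_layer_weights):
    """Stack L weights of shape (d, d) into one array of shape (L, d, d)."""
    return np.stack(per_layer_weights, axis=0)


def run_sliced(h, stacked, start=0, stop=None):
    """Apply layers [start, stop) by slicing the stacked weights.

    The same `block` is reused for every layer; in JAX, jitting `block` once
    (or scanning it with jax.lax.scan) means only one block gets compiled.
    """
    for w in stacked[start:stop]:
        h = block(h, w)
    return h
```

Slicing also makes it easy to run only a prefix of the model, for example up to the layer whose activations you want.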
SAE training
We train SAEs on residual stream activations.
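For reference, a minimal SAE forward pass on a batch of residual-stream vectors. This sketch assumes a TopK SAE in the style of Gao et al. 2024 (cited further down), with a ReLU on the kept values as many implementations do; the parameter layout, initialization and function names are hypothetical, not fae's actual code.

```python
import numpy as np


def init_sae(d_model, n_features, rng):
    """Hypothetical parameters; the decoder starts as the encoder's transpose."""
    w_enc = rng.standard_normal((d_model, n_features)).astype(np.float32)
    w_enc /= np.sqrt(d_model)
    return {
        "w_enc": w_enc,
        "b_enc": np.zeros(n_features, np.float32),
        "w_dec": w_enc.T.copy(),
        "b_dec": np.zeros(d_model, np.float32),
    }


def topk_sae(params, x, k):
    """Encode x, keep the k largest pre-activations per row, decode."""
    pre = (x - params["b_dec"]) @ params["w_enc"] + params["b_enc"]
    idx = np.argpartition(pre, -k, axis=-1)[..., -k:]            # top-k latent ids
    vals = np.maximum(np.take_along_axis(pre, idx, axis=-1), 0)  # ReLU on kept values
    latents = np.zeros_like(pre)
    np.put_along_axis(latents, idx, vals, axis=-1)               # all other latents are 0
    recon = latents @ params["w_dec"] + params["b_dec"]
    return recon, latents


def mse_loss(x, recon):
    """Reconstruction loss averaged over batch and model dimensions."""
    return float(np.mean((x - recon) ** 2))
```

With 16k features, n_features would be on the order of 16384; the test below uses toy sizes.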
We train SAEs with 16k features. The training data is generated with FLUX.1 schnell, so the inputs are just noise at a single timestep. I also trained SAEs on more timesteps but did not evaluate them.
Initially, training failed: the unexplained variance did not decrease quickly, and many features were dead or very sparse.
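The two symptoms above are usually tracked with fraction of variance unexplained (FVU, where variance explained is 1 - FVU) and per-feature firing counts. A minimal sketch of both; the class name and the "rare" threshold are hypothetical choices, not values from fae.

```python
import numpy as np


def fraction_variance_unexplained(x, recon):
    """FVU = ||x - recon||^2 / ||x - mean(x)||^2."""
    resid = ((x - recon) ** 2).sum()
    total = ((x - x.mean(axis=0, keepdims=True)) ** 2).sum()
    return float(resid / total)


class FeatureFiringTracker:
    """Counts how often each feature fires, to flag dead or rarely active ones."""

    def __init__(self, n_features):
        self.fire_counts = np.zeros(n_features, np.int64)
        self.n_tokens = 0

    def update(self, latents):
        # latents: (n_tokens, n_features), nonzero where a feature fired.
        self.fire_counts += (latents != 0).sum(axis=0)
        self.n_tokens += latents.shape[0]

    def dead(self):
        """Features that never fired on any token seen so far."""
        return self.fire_counts == 0

    def rare(self, max_freq=1e-5):
        """Features that fired, but on fewer than max_freq of tokens."""
        freq = self.fire_counts / max(self.n_tokens, 1)
        return (freq > 0) & (freq < max_freq)
```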
An open question is whether this low-dimensional subspace is the subspace used for storing the inputs.
I discovered fluxlens after training the residual stream SAEs. A difference is that
Gathering activations
We need to gather activations from the model's forward pass. Existing JAX solutions are to accumulate them into a dict (penzai, used in saex, and NAME OF THE LIBRARY) or to store .
Harvest's reap and sow.
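The interface idea behind reap and sow: model code tags intermediate values with `sow`, which does nothing during normal execution, and wrapping a function with `reap` runs it and returns every tagged value alongside the output. The toy below illustrates only that interface with a global collector; it is not Oryx's API (Oryx's harvest is implemented as a JAX transformation rather than a Python side channel), and it is neither thread-safe nor aware of nested reaps.

```python
from typing import Any, Dict, Optional

_active: Optional[Dict[str, Any]] = None  # collector for the current reap, if any


def sow(value, name):
    """Tag an intermediate value; a no-op unless a reap is collecting."""
    if _active is not None:
        _active[name] = value
    return value


def reap(f):
    """Wrap f so that it also returns everything sown during the call."""
    def wrapped(*args, **kwargs):
        global _active
        prev, _active = _active, {}
        try:
            out = f(*args, **kwargs)
            return out, _active
        finally:
            _active = prev  # restore whatever collector was active before
    return wrapped
```

Because `sow` returns its input unchanged, the model code is the same whether or not activations are being collected.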
Layer slicing interaction.
Sparse operations on TPUs
An SAE's decoder naturally performs a sparse matrix multiplication; implementing it as a sparse operation can speed up training 6x by removing almost half of the FLOPs of the forward pass and almost all the FLOPs of the backward pass (see Gao et al. 2024, Section D). This shifts the bottleneck into https://www.neuronpedia.org/graph/info#appendix-e.
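Where the savings come from: with only k of n_features latents active per row, the decoder only needs the k corresponding rows of the decoder matrix, so its cost drops from n_features * d_model to k * d_model multiply-adds per example (and the decoder-weight gradient likewise only touches those k rows). A NumPy sketch of the gather-based decoder next to the dense one; the function names are hypothetical, and on TPUs the gather/scatter has its own costs that this sketch ignores.

```python
import numpy as np


def dense_decode(latents, w_dec):
    """latents: (batch, n_features), mostly zeros. Cost: batch * n_features * d_model."""
    return latents @ w_dec


def sparse_decode(idx, vals, w_dec):
    """idx, vals: (batch, k) active latent ids and values. Cost: batch * k * d_model."""
    rows = w_dec[idx]                               # gather: (batch, k, d_model)
    return np.einsum("bk,bkd->bd", vals, rows)      # weighted sum of the k rows


def decoder_macs(batch, n_features, d_model, k):
    """Multiply-accumulate counts for the two decoders."""
    return {"dense": batch * n_features * d_model, "sparse": batch * k * d_model}
```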
Activation dashboards
We also collect activations for the dashboards.
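Dashboards mainly need, for each feature, the inputs that activate it most strongly. One common way to collect these while streaming batches is a fixed-size top-n buffer per feature; the sketch below assumes that design (the class and its fields are hypothetical, not fae's implementation). Non-positive activations are treated as "did not fire" and never enter the buffer.

```python
import numpy as np


class TopExamples:
    """Keeps the n_keep highest activations seen so far for every feature."""

    def __init__(self, n_features, n_keep):
        self.values = np.full((n_features, n_keep), -np.inf)  # -inf marks an empty slot
        self.example_ids = np.full((n_features, n_keep), -1, dtype=np.int64)

    def update(self, latents, ids):
        # latents: (batch, n_features); ids: (batch,) example identifiers.
        n_keep = self.values.shape[1]
        lat_t = np.where(latents.T > 0, latents.T, -np.inf)   # (n_features, batch)
        all_vals = np.concatenate([self.values, lat_t], axis=1)
        all_ids = np.concatenate(
            [self.example_ids, np.broadcast_to(ids, lat_t.shape)], axis=1)
        order = np.argsort(-all_vals, axis=1)[:, :n_keep]     # largest first
        self.values = np.take_along_axis(all_vals, order, axis=1)
        self.example_ids = np.take_along_axis(all_ids, order, axis=1)
```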
Future work
- model diffing between schnell and dev
- understand what CFG (classifier-free guidance) distillation does
- probe for 3D concepts
[1] Short for "Flux SAE".
[2] And a leaner SAE training codebase for TPUs, like saex.