Standard sparse autoencoders (SAEs) decompose a model's hidden state into a sparse set of interpretable features. They work well for transformers, where the hidden state at each token position is a flat vector of length hidden_dim (stored as a [batch, seq_len, hidden_dim] tensor).
But a new class of language models — state-space and recurrent models like Mamba-2, RWKV-7, and Gated DeltaNet —
use a matrix-valued recurrent state instead of a flat vector. The recurrent state is shaped like
[key_dim, val_dim] per head, and each token update is a rank-1 outer product:
# Gated DeltaNet update
S_t = S_{t-1} + gate_t * (key_t.T @ value_t)
# ^^^^^^^^ matrix state ^^^^^^^^^^^^^ rank-1 write
If you apply a standard SAE to this matrix state, you flatten it into a vector, losing the rank-1 structure that is fundamental to how these models work. WriteSAE fixes this by factoring each decoder atom into a rank-1 outer product, matching the native write shape.
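To make the rank-1 claim concrete, here is a minimal NumPy sketch of the simplified update written above. The toy sizes, the random keys and values, and the scalar gate are all illustrative assumptions, not values from any real model.

```python
import numpy as np

rng = np.random.default_rng(0)
key_dim, val_dim = 4, 3  # toy sizes (assumption), far smaller than a real head

S = np.zeros((key_dim, val_dim))  # matrix-valued recurrent state
for _ in range(5):
    key = rng.standard_normal((1, key_dim))    # row vector, as in the update above
    value = rng.standard_normal((1, val_dim))  # row vector
    gate = rng.uniform(0.1, 1.0)               # scalar gate (assumption)
    write = gate * (key.T @ value)             # [key_dim, val_dim] outer product
    assert np.linalg.matrix_rank(write) == 1   # every single write has rank 1
    S = S + write

x_flat = S.reshape(-1)  # the view a standard SAE would be trained on
print(S.shape, x_flat.shape)
```

Each write is rank 1 by construction, but once the state is flattened nothing in the vector's layout tells the SAE which entries share a key row or a value column.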
# Flatten the matrix state
x_flat = state.reshape(-1) # [key_dim * val_dim]
# Encode
acts = ReLU((x_flat - b_dec) @ W_enc.T + b_enc)
# Decode: each feature is a flat vector
recon = acts @ W_dec + b_dec # W_dec: [n_features, key_dim * val_dim]
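The flat SAE forward pass above can be run end to end as follows. This is a sketch with untrained, randomly initialized weights and toy sizes (all assumptions), meant only to check shapes and to show that a flat decoder atom, reshaped back to a matrix, is generally not rank 1.

```python
import numpy as np

rng = np.random.default_rng(0)
key_dim, val_dim, n_features = 4, 3, 16  # toy sizes (assumption)
d = key_dim * val_dim

state = rng.standard_normal((key_dim, val_dim))

# Untrained parameters, for shape-checking only
W_enc = rng.standard_normal((n_features, d)) * 0.1
W_dec = rng.standard_normal((n_features, d)) * 0.1
b_enc = np.zeros(n_features)
b_dec = np.zeros(d)

x_flat = state.reshape(-1)                                   # [key_dim * val_dim]
acts = np.maximum((x_flat - b_dec) @ W_enc.T + b_enc, 0.0)   # ReLU encoder
recon_flat = acts @ W_dec + b_dec                            # flat reconstruction
recon = recon_flat.reshape(key_dim, val_dim)                 # back to matrix shape

# A flat decoder atom viewed as a matrix has no rank constraint
atom0_rank = np.linalg.matrix_rank(W_dec[0].reshape(key_dim, val_dim))
print(acts.shape, recon.shape, atom0_rank)
```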
# Keep the matrix shape
x = state # [key_dim, val_dim]
# Each decoder atom is a rank-1 outer product:
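# atom_i = outer(V_i, W_i), shape [key_dim, val_dim]
# where V_i and W_i are vectors of length key_dim and val_dim
# Decode (as a formula):
# recon = sum_i acts[i] * outer(V_i, W_i) + b_dec
# In code, with V: [n_features, rank, key_dim] and W: [n_features, rank, val_dim], rank = 1:
recon = einsum("bi,irk,irv->bkv", acts, V, W) + b_dec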
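As a sanity check on the decode step, the sketch below compares the einsum above against an explicit sum of outer products. It uses NumPy rather than whatever framework the repository uses, and the shapes [n_features, rank, key_dim] and [n_features, rank, val_dim] are read off the einsum index string, so treat them as an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_features, rank, key_dim, val_dim = 2, 16, 1, 4, 3  # toy sizes (assumption)

acts = np.maximum(rng.standard_normal((batch, n_features)), 0.0)  # non-negative activations
V = rng.standard_normal((n_features, rank, key_dim))   # key-side factors
W = rng.standard_normal((n_features, rank, val_dim))   # value-side factors
b_dec = np.zeros((key_dim, val_dim))

# Decode in one call: sum over features i and rank index r
recon = np.einsum("bi,irk,irv->bkv", acts, V, W) + b_dec

# The same thing written as an explicit sum of rank-1 atoms
recon_loop = np.zeros((batch, key_dim, val_dim))
for b in range(batch):
    for i in range(n_features):
        recon_loop[b] += acts[b, i] * np.outer(V[i, 0], W[i, 0])
recon_loop += b_dec

print(np.allclose(recon, recon_loop))
```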
The decoder matrix W_dec in a standard SAE has shape [n_features, key_dim * val_dim].
In WriteSAE, the equivalent factorization uses two tensors: V with shape [n_features, rank, key_dim] and W with shape [n_features, rank, val_dim], with rank = 1 (the i, r, k and i, r, v axes of the einsum above).
Because each feature is a rank-1 outer product, the decoder holds n_features * (key_dim + val_dim) parameters, which is fewer than the flat n_features * key_dim * val_dim for any realistic key_dim and val_dim.
With key_dim=128 and val_dim=64, that is 192 parameters per feature instead of 8,192.
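The decoder parameter counts can be worked out directly. This sketch assumes a rank-1 factorization in which each feature stores one key-side vector and one value-side vector, uses key_dim=128 and val_dim=64 as reported for the authors' setup, and reads "2K-16K features" as 2048 and 16384 (an assumption).

```python
key_dim, val_dim = 128, 64

def flat_decoder_params(n_features: int) -> int:
    # One flat atom of length key_dim * val_dim per feature
    return n_features * key_dim * val_dim

def rank1_decoder_params(n_features: int) -> int:
    # One key-side vector plus one value-side vector per feature
    return n_features * (key_dim + val_dim)

for n_features in (2048, 16384):
    flat = flat_decoder_params(n_features)
    fact = rank1_decoder_params(n_features)
    print(n_features, flat, fact, round(flat / fact, 1))
# 2048 16777216 393216 42.7
# 16384 134217728 3145728 42.7
```

Under these assumptions the factored decoder is roughly 43 times smaller than the flat one, independent of the number of features.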
WriteSAE comes in two variants:
| Variant | Encoder | Decoder | Use case |
|---|---|---|---|
| MatrixSAE | Flat projection: einsum("bkv,vik,vwk->bi", x-b_dec, V, W) | Rank-1 reconstruction: einsum("bi,vik,vwk->bkv", acts, V, W) + b_dec | General use |
| BilinearMatrixSAE | Bilinear form: einsum("vik,bkv,vwk->bi", V_enc, x, W_enc) | Same as MatrixSAE but with separate encoder/decoder factors | Better fidelity when encoder/decoder need different structures |
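One way to see how the two encoder descriptions relate: projecting the flattened state onto a flattened rank-1 atom gives the same number as the bilinear form V_i^T x W_i. The sketch below checks this in NumPy. Its index layout ([n_features, key_dim] and [n_features, val_dim]) is chosen for readability and is not the layout used by the einsum strings in the table, so it should not be read as the repository's code.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_features, key_dim, val_dim = 2, 16, 4, 3  # toy sizes (assumption)

x = rng.standard_normal((batch, key_dim, val_dim))   # matrix states (bias already subtracted)
V = rng.standard_normal((n_features, key_dim))       # key-side encoder factors (hypothetical layout)
W = rng.standard_normal((n_features, val_dim))       # value-side encoder factors (hypothetical layout)

# Bilinear form: pre[b, i] = V_i^T @ x_b @ W_i
pre_bilinear = np.einsum("ik,bkv,iv->bi", V, x, W)

# Flat projection: build each rank-1 atom, flatten it, and take a dot product
atoms_flat = np.einsum("ik,iv->ikv", V, W).reshape(n_features, -1)
pre_flat = x.reshape(batch, -1) @ atoms_flat.T

print(np.allclose(pre_bilinear, pre_flat))
```

The bilinear form never materializes the [key_dim, val_dim] atoms, which is the practical reason to prefer it when the number of features is large.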
Both variants use:
| Model | Architecture | Parameters | Training tokens |
|---|---|---|---|
| Qwen3.5-0.8B | Gated DeltaNet | 0.8B | 3.5T |
| Qwen3.5-4B | Gated DeltaNet | 4B | 12T |
The authors train SAEs on the recurrent state (not the residual stream) of each GDN layer. They use 2K-16K features per layer, key_dim=128, val_dim=64.
We cloned the WriteSAE repository and read the core implementation. Key observations:
- The code is organized into core/ (SAE classes, training loop, loss functions), experiments/ (extraction, analysis, interpretability), and scripts/ (training configs).
- The extraction code reads model.model.layers[i].mlp.forward.delta to get the recurrent state. This won't work for transformers.

Matron currently uses Qwen 3.5 35B-A3B, which is a Mixture-of-Experts transformer with standard attention. WriteSAE is designed for recurrent state-space models (Gated DeltaNet, Mamba-2, RWKV-7). These are fundamentally different architectures.
However, there are two reasons to care:
| Requirement | Status |
|---|---|
| Gated DeltaNet model (Qwen 3.5 0.8B or 4B) | Available on HuggingFace |
| A100 GPU for training | We have access (vast.ai) |
| Modal account for orchestration | Not set up; can train locally instead |
| Training data (FineWeb, etc.) | Available via HuggingFace |
| Time (~2-4 hours per SAE layer) | Doable but time-intensive |