WriteSAE: SAEs for Recurrent State-Space Models

Matron Labs-3 · May 2026 · arXiv:2605.12770 · github.com/JackYoung27/writesae
Status: Code review complete · Not yet run on hardware

What is WriteSAE?

Standard sparse autoencoders (SAEs) decompose a model's hidden state into a sparse set of interpretable features. They work well for transformers, where the hidden state at each token position is a flat vector (activations of shape [batch, seq_len, hidden_dim]).

But a new class of language models, state-space and recurrent models such as Mamba-2, RWKV-7, and Gated DeltaNet, uses a matrix-valued recurrent state instead of a flat vector. The recurrent state has shape [key_dim, val_dim] per head, and each token writes a rank-1 outer product of its key and value into that state:

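import numpy as np

# Gated DeltaNet update, per head. S: [key_dim, val_dim]; alpha_t: decay gate; beta_t: write strength
erase_t = beta_t * np.outer(key_t, key_t @ S_prev)   # delta rule: remove what is stored along key_t
write_t = beta_t * np.outer(key_t, value_t)          # rank-1 write, [key_dim, val_dim]
S_t = alpha_t * (S_prev - erase_t) + write_t         # updated matrix state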
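A minimal NumPy sketch of one per-head step in this [key_dim, val_dim] convention, assuming the standard gated delta rule (scalar decay gate alpha, scalar write strength beta, L2-normalized keys). The function name gdn_step and the toy sizes are illustrative assumptions, not taken from the WriteSAE repository or any model implementation.

```python
import numpy as np

def gdn_step(S_prev, key, value, alpha, beta):
    """One gated-delta-rule update for a single head (illustrative sketch).

    S_prev: [key_dim, val_dim] recurrent state
    key:    [key_dim], assumed L2-normalized
    value:  [val_dim]
    alpha:  scalar decay gate in (0, 1]
    beta:   scalar write strength in (0, 1]
    """
    erase = beta * np.outer(key, key @ S_prev)  # delta rule: remove content stored along `key`
    write = beta * np.outer(key, value)         # the rank-1 write, [key_dim, val_dim]
    return alpha * (S_prev - erase) + write, write

rng = np.random.default_rng(0)
key_dim, val_dim = 8, 4  # toy sizes
S = np.zeros((key_dim, val_dim))
for _ in range(5):
    k = rng.normal(size=key_dim)
    k /= np.linalg.norm(k)  # unit-norm key
    v = rng.normal(size=val_dim)
    S, write = gdn_step(S, k, v, alpha=0.9, beta=0.5)

print(S.shape, np.linalg.matrix_rank(write))  # (8, 4) 1
```

With alpha = beta = 1 and a unit-norm key, writing the same (key, value) pair twice leaves the state at outer(key, value) instead of doubling it: the erase term removes what was stored along that key before the new rank-1 write lands.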

If you apply a standard SAE to this matrix state, you flatten it into a vector, losing the rank-1 structure that is fundamental to how these models work. WriteSAE fixes this by factoring each decoder atom into a rank-1 outer product, matching the native write shape.
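The mismatch is easy to see numerically. In this toy NumPy sketch (sizes and random values are arbitrary assumptions), a flat decoder row reshaped back to [key_dim, val_dim] is generically full rank, whereas an atom built as an outer product has rank 1, the same shape as a single write.

```python
import numpy as np

rng = np.random.default_rng(0)
key_dim, val_dim = 8, 4  # toy sizes (the paper's setup uses 128 and 64)

# A flat SAE decoder row is an arbitrary vector of length key_dim * val_dim.
flat_atom = rng.normal(size=key_dim * val_dim)
# Reshaped back into state space it is generically full rank, i.e. not write-shaped.
flat_rank = np.linalg.matrix_rank(flat_atom.reshape(key_dim, val_dim))

# A WriteSAE-style atom is an outer product, so it is rank 1 by construction.
rank1_atom = np.outer(rng.normal(size=key_dim), rng.normal(size=val_dim))
write_rank = np.linalg.matrix_rank(rank1_atom)

print(flat_rank, write_rank)  # 4 1
```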

Core claim: WriteSAE achieves 10× higher fidelity than standard SAEs on recurrent state-space models, because it respects the native structure of the recurrent state.

The Key Innovation

Standard SAE

import numpy as np

# Flatten the matrix state
x_flat = state.reshape(-1)                                  # [key_dim * val_dim]

# Encode (ReLU)
acts = np.maximum(0, (x_flat - b_dec) @ W_enc.T + b_enc)    # W_enc: [n_features, key_dim * val_dim]

# Decode: each feature is a flat vector
recon = acts @ W_dec + b_dec                                # W_dec: [n_features, key_dim * val_dim]
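For reference, a self-contained version of the flat baseline with toy sizes. The parameters are random placeholders rather than trained weights, and the function name flat_sae is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
key_dim, val_dim, n_features = 8, 4, 16  # toy sizes
d = key_dim * val_dim

# Randomly initialized placeholder parameters (a trained SAE would learn these)
W_enc = rng.normal(scale=0.1, size=(n_features, d))
b_enc = np.zeros(n_features)
W_dec = rng.normal(scale=0.1, size=(n_features, d))
b_dec = np.zeros(d)

def flat_sae(state):
    """Encode and decode one [key_dim, val_dim] state through a flat SAE."""
    x_flat = state.reshape(-1)                                     # [d]
    acts = np.maximum(0.0, (x_flat - b_dec) @ W_enc.T + b_enc)     # [n_features], ReLU
    recon = acts @ W_dec + b_dec                                   # [d]
    return acts, recon.reshape(key_dim, val_dim)                   # back to state shape

state = rng.normal(size=(key_dim, val_dim))
acts, recon = flat_sae(state)
print(acts.shape, recon.shape)  # (16,) (8, 4)
```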

WriteSAE

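import numpy as np

# Keep the matrix shape
x = state                             # [key_dim, val_dim]

# Each decoder atom is a rank-1 outer product:
#   atom_i = np.outer(V_i, W_i)       # [key_dim, val_dim]
# where V_i and W_i are vectors of length key_dim and val_dim

# Decode: recon = sum_i acts[i] * np.outer(V_i, W_i) + b_dec

# Vectorized over a batch, with a rank axis r (r = 1 gives rank-1 atoms):
#   acts: [batch, n_features], V: [n_features, r, key_dim], W: [n_features, r, val_dim]
recon = np.einsum("bi,irk,irv->bkv", acts, V, W) + b_dec   # [batch, key_dim, val_dim]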
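A quick NumPy check that the vectorized einsum matches the explicit sum of outer products. It assumes V and W carry the rank axis r implied by the einsum string "bi,irk,irv->bkv", with r = 1; sizes and values are random placeholders, not the repository's.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_features, rank, key_dim, val_dim = 2, 16, 1, 8, 4  # toy sizes; rank 1 atoms

V = rng.normal(size=(n_features, rank, key_dim))  # key-side factors
W = rng.normal(size=(n_features, rank, val_dim))  # value-side factors
b_dec = np.zeros((key_dim, val_dim))
acts = np.maximum(0.0, rng.normal(size=(batch, n_features)))  # stand-in non-negative codes

# Vectorized decode
recon = np.einsum("bi,irk,irv->bkv", acts, V, W) + b_dec

# Explicit loop: sum_i acts[b, i] * outer(V_i, W_i) for each batch element
loop = np.zeros((batch, key_dim, val_dim))
for b in range(batch):
    for i in range(n_features):
        loop[b] += acts[b, i] * np.outer(V[i, 0], W[i, 0])
loop += b_dec

print(np.allclose(recon, loop))  # True
```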

The decoder matrix W_dec in a standard SAE has shape [n_features, key_dim * val_dim], so each atom costs key_dim * val_dim parameters. WriteSAE instead stores each atom as two factor vectors, V_i of length key_dim and W_i of length val_dim, so the decoder costs n_features * (key_dim + val_dim) parameters. That is fewer than the flat n_features * key_dim * val_dim whenever key_dim and val_dim are both at least 3 (in practice, always), and the gap widens as the state grows: with key_dim = 128 and val_dim = 64 (the values in the experimental setup), each atom takes 192 parameters instead of 8,192, roughly 43× fewer.
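The parameter arithmetic as a small helper (the name decoder_params is hypothetical), evaluated at the feature counts from the experimental setup, assuming 2K = 2048 and 16K = 16384. The rank argument generalizes to rank-r atoms; WriteSAE's atoms correspond to rank = 1. Decoder biases are excluded.

```python
def decoder_params(n_features, key_dim, val_dim, rank=1):
    """Decoder parameter counts (bias excluded): flat SAE vs rank-r factored atoms."""
    flat = n_features * key_dim * val_dim
    factored = n_features * rank * (key_dim + val_dim)
    return flat, factored

for n_features in (2048, 16384):  # assumed reading of "2K" and "16K"
    flat, factored = decoder_params(n_features, key_dim=128, val_dim=64)
    # ratio is 8192 / 192, about 42.7, for every n_features
    print(n_features, flat, factored, round(flat / factored, 1))
```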

Architecture Details

WriteSAE comes in two variants:

Variant | Encoder | Decoder | Use case
MatrixSAE | Flat projection onto each rank-1 atom: einsum("bkv,irk,irv->bi", x - b_dec, V, W) | Rank-1 reconstruction: einsum("bi,irk,irv->bkv", acts, V, W) + b_dec | General use
BilinearMatrixSAE | Bilinear form with its own encoder factors: einsum("irk,bkv,irv->bi", V_enc, x, W_enc) | Same rank-1 reconstruction as MatrixSAE, with decoder factors separate from the encoder's | Better fidelity when the encoder and decoder need different structures
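Why a projection onto rank-1 atoms can be computed as a bilinear form: the dot product of the flattened state with a flattened atom outer(V_i, W_i) equals V_i^T x W_i. A NumPy sketch checking this identity, assuming the factor shapes [n_features, r, key_dim] and [n_features, r, val_dim] implied by the decoder einsum earlier; as we read the variant description, a BilinearMatrixSAE-style encoder would apply the same contraction with separate encoder factors. Values are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_features, rank, key_dim, val_dim = 2, 16, 1, 8, 4  # toy sizes
x = rng.normal(size=(batch, key_dim, val_dim))
V = rng.normal(size=(n_features, rank, key_dim))
W = rng.normal(size=(n_features, rank, val_dim))

# Bilinear form per feature: sum_r V_ir^T x W_ir
bilinear = np.einsum("bkv,irk,irv->bi", x, V, W)

# Same quantity computed the "flat" way: dot product with each flattened atom
atoms = np.einsum("irk,irv->ikv", V, W).reshape(n_features, -1)  # [n_features, key_dim * val_dim]
flat = x.reshape(batch, -1) @ atoms.T                             # [batch, n_features]

print(np.allclose(bilinear, flat))  # True
```

The bilinear route never materializes the [key_dim, val_dim] atoms, which is what keeps the factored encoder cheap.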

Both variants use:

Experimental Setup (from paper)

Model | Architecture | Parameters | Training tokens
Qwen3.5-0.8B | Gated DeltaNet | 0.8B | 3.5T
Qwen3.5-4B | Gated DeltaNet | 4B | 12T

The authors train SAEs on the recurrent state (not the residual stream) of each Gated DeltaNet (GDN) layer. They use 2K to 16K features per layer, with key_dim = 128 and val_dim = 64.

What We Learned from the Code

We cloned the WriteSAE repository and read the core implementation. Key observations:

Relevance for Matron

Verdict: Interesting but not directly applicable today.

Matron currently uses Qwen 3.5 35B-A3B, which is a Mixture-of-Experts transformer with standard attention. WriteSAE is designed for recurrent state-space models (Gated DeltaNet, Mamba-2, RWKV-7). These are fundamentally different architectures.

However, there are two reasons to care:

  1. Architecture diversity: Recurrent models (Mamba-2, RWKV-7, Gated DeltaNet) are gaining traction because their fixed-size state makes inference linear-time in sequence length, with constant memory per generated token, which makes long contexts cheaper to serve. If Matron ever switches to or experiments with these architectures, WriteSAE is the natural tool for interpreting them.
  2. Conceptual insight: The rank-1 factorization idea could inspire similar techniques for transformers. For example, could attention patterns (QK^T) or FFN weight matrices be factorized into interpretable rank-1 components? This is an open research question.
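As one concrete reading of this question: a single head's pre-softmax score matrix QK^T has rank at most d_head, and the singular value decomposition writes it exactly as a sum of that many rank-1 outer products. This NumPy sketch uses random Q and K and toy sizes; SVD components are orthogonal rather than sparse or learned, so it only illustrates the factorization, not an interpretability method.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_head = 6, 3  # toy sizes
Q = rng.normal(size=(seq_len, d_head))
K = rng.normal(size=(seq_len, d_head))
scores = Q @ K.T  # [seq_len, seq_len] pre-softmax scores, rank <= d_head

# SVD expresses the score matrix as a sum of rank-1 outer products
U, s, Vt = np.linalg.svd(scores)
rank1_terms = [s[j] * np.outer(U[:, j], Vt[j]) for j in range(d_head)]

print(np.allclose(sum(rank1_terms), scores))  # True
```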

What would it take to test WriteSAE?

Requirement | Status
Gated DeltaNet model (Qwen 3.5 0.8B or 4B) | Available on HuggingFace
A100 GPU for training | We have access (vast.ai)
Modal account for orchestration | Not set up; can train locally instead
Training data (FineWeb, etc.) | Available via HuggingFace
Time (~2-4 hours per SAE layer) | Doable but time-intensive

Next Steps

Related Explorations

References
