State Space Models & Mamba
Attention is $O(L^2)$ in the sequence length $L$. For long contexts — genomes, audio, hour-long documents — that's a wall. State Space Models are a family of linear-time sequence mixers that, as of 2024, finally match or beat same-size Transformers on language modeling while scaling gracefully to million-token contexts.
1. Why not just Transformers?
Self-attention is the most important architecture of the last decade, but it has a brutal scaling problem. Attention cost is $O(L^2)$ in sequence length $L$: every token attends to every other token. At $L = 2{,}000$ this is fine. At $L = 100{,}000$ it eats all your memory. At $L = 1{,}000{,}000$ (a DNA strand, a long video, a codebase) it's completely impossible.
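To make the wall concrete, here is a back-of-envelope sketch of how big a single $L \times L$ attention score matrix gets, assuming fp16 scores, one head, one layer (the exact numbers depend on precision and implementation — fused kernels like FlashAttention avoid materializing the matrix, but the quadratic compute remains):

```python
# Back-of-envelope: memory to materialize one L x L attention score matrix,
# fp16 (2 bytes per entry), for a single head in a single layer.
sizes_gib = {}
for L in (2_000, 100_000, 1_000_000):
    bytes_needed = L * L * 2
    sizes_gib[L] = bytes_needed / 2**30
    print(f"L={L:>9,}: {sizes_gib[L]:10.2f} GiB")
```

At $L = 2{,}000$ the matrix is a few megabytes; at $L = 1{,}000{,}000$ it is well over a terabyte — per head, per layer.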
A lot of approaches tried to patch this — sparse attention, linear attention, Longformer, Performer, RWKV. Most worked on small tasks but lost to full attention on language. Then in 2023, Mamba (Gu & Dao) landed. It's a state space model with a trick called selectivity, and it was the first sub-quadratic architecture to seriously compete with Transformers on language modeling loss, all while running in $O(L)$ time and $O(1)$ memory per step at inference.
Mamba is the culmination of a longer story that runs through classical control theory, HIPPO, and S4. Let's walk through it.
2. The classical state space
A continuous-time linear state space model represents a system with a hidden state $\mathbf{h}(t)$ that evolves according to a simple linear ODE:

$$\mathbf{h}'(t) = \mathbf{A}\,\mathbf{h}(t) + \mathbf{B}\,u(t), \qquad y(t) = \mathbf{C}\,\mathbf{h}(t) + \mathbf{D}\,u(t)$$

The continuous state space
- $u(t)$
- The input signal at time $t$. Scalar for 1-D, vector for multi-channel (e.g., a token embedding).
- $\mathbf{h}(t) \in \mathbb{R}^N$
- The hidden state — an $N$-dimensional memory that summarizes everything the system has seen so far. $N$ is typically 16 or 64.
- $\mathbf{A} \in \mathbb{R}^{N \times N}$
- The state transition matrix: how the state evolves on its own, without new input. If the eigenvalues of $\mathbf{A}$ have negative real part, the state decays and the system forgets; positive real part makes it blow up; nonzero imaginary parts make it oscillate.
- $\mathbf{B} \in \mathbb{R}^{N \times 1}$
- The input-to-state matrix: how much new input pushes into each state dimension.
- $\mathbf{C} \in \mathbb{R}^{1 \times N}$
- The state-to-output matrix: a linear projection from state to observable output.
- $\mathbf{D} \in \mathbb{R}$
- A direct feedthrough (skip connection) from input to output. Often set to 0 in deep learning contexts.
Analogy: Picture a spring-mass-damper system. The "state" is [position, velocity]. The ODE says: velocity changes based on position and the external force; position changes based on velocity. $\mathbf{A}$ encodes the spring constant and damping; $\mathbf{B}$ encodes how much an external push accelerates the mass. Any linear dynamical system — a circuit, a thermostat, a cruise-control loop — looks like this. SSMs bring this century-old formalism into deep learning.
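To ground the analogy, here is a tiny NumPy sketch of that spring-mass-damper written as a state space (the mass, spring, and damping values are made up for illustration, and the simulation is a crude forward-Euler step):

```python
import numpy as np

# Illustrative parameters: mass m, spring constant k, damping c.
m, k, c = 1.0, 4.0, 0.5

# State h = [position, velocity].  h'(t) = A h(t) + B u(t),  y(t) = C h(t).
A = np.array([[0.0, 1.0],
              [-k / m, -c / m]])   # position' = velocity; velocity' = (-k*x - c*v + u) / m
B = np.array([[0.0], [1.0 / m]])   # an external force pushes on the velocity
C = np.array([[1.0, 0.0]])         # we observe position

# Forward-Euler simulation of the ODE with no external force (u = 0).
dt, T = 0.01, 1000
h = np.array([1.0, 0.0])           # start displaced by 1, at rest
trajectory = []
for _ in range(T):
    h = h + dt * (A @ h)
    trajectory.append(h[0])

# A has eigenvalues with negative real part and nonzero imaginary part,
# so the position traces a damped oscillation decaying toward 0.
print(trajectory[0], trajectory[-1])
```

Swapping the sign of the damping term flips the eigenvalues' real part positive, and the same loop blows up — the three regimes from the $\mathbf{A}$ definition above.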
3. Discretization
Neural networks work in discrete steps, not continuous time. To use an SSM in a network, you discretize the ODE with a step size $\Delta$. Using the zero-order hold rule:

$$\bar{\mathbf{A}} = \exp(\Delta \mathbf{A}), \qquad \bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}\left(\exp(\Delta \mathbf{A}) - \mathbf{I}\right)\Delta \mathbf{B}$$

The discrete SSM is then a linear recurrence:

$$\mathbf{h}_t = \bar{\mathbf{A}}\,\mathbf{h}_{t-1} + \bar{\mathbf{B}}\,u_t, \qquad y_t = \mathbf{C}\,\mathbf{h}_t$$

Discrete SSM
- $\Delta$
- The time step — how much "real time" each discrete step represents. In a language SSM this is a learned, data-dependent parameter.
- $\bar{\mathbf{A}}, \bar{\mathbf{B}}$
- The discretized versions of $\mathbf{A}$ and $\mathbf{B}$. The bar distinguishes them from their continuous-time originals.
- $\mathbf{h}_t$
- The discrete hidden state at step $t$. Now the formula reads exactly like an RNN.
- $y_t$
- The output at step $t$ — a linear readout of the state.
The dual view: This recurrence is algebraically identical to a very specific linear RNN — no nonlinearity between steps, and with structured transition matrices. That sounds like a weakness, but it unlocks two superpowers: (1) you can unroll the recurrence into a global convolution and compute all outputs in parallel in $O(L \log L)$ via FFT; (2) at inference time you can switch back to the recurrence and decode one step at a time in $O(1)$ memory.
This dual view — "same model, two compute modes" — is the whole point of SSMs. During training you train as a convolution on the whole sequence (fast, parallel). During autoregressive inference you run as a recurrence (tiny state, no KV cache). Transformers can't do this: their attention is fundamentally non-recurrent and forces you to keep an ever-growing KV cache at inference.
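The equivalence can be checked numerically. A sketch with a toy diagonal SSM (state size and coefficients chosen arbitrarily): run the same model once as a step-by-step recurrence and once as a global convolution, and confirm the outputs agree:

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny diagonal SSM: N = 2 states, scalar input and output.
N, L = 2, 64
A = np.array([-1.0, -0.5])           # diagonal of A: negative => stable decay
B = np.array([1.0, 1.0])
C = np.array([0.4, 0.6])
delta = 0.1

# Zero-order-hold discretization (elementwise, since A is diagonal).
A_bar = np.exp(delta * A)
B_bar = (A_bar - 1.0) / A * B

u = rng.standard_normal(L)

# Mode 1: recurrence — one step at a time, O(1) state. (Inference mode.)
h = np.zeros(N)
y_rec = np.zeros(L)
for t in range(L):
    h = A_bar * h + B_bar * u[t]
    y_rec[t] = C @ h

# Mode 2: convolution — unroll the kernel K[t] = C · A_bar^t · B_bar,
# then convolve the whole sequence at once. (Training mode.)
K = np.array([C @ (A_bar ** t * B_bar) for t in range(L)])
y_conv = np.convolve(u, K)[:L]

# Same model, two compute modes: identical outputs to float precision.
assert np.allclose(y_rec, y_conv)
```

The convolution here is done naively with `np.convolve`; in practice the kernel is applied via FFT for the $O(L \log L)$ cost quoted above.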
4. S4 — HIPPO and structured state spaces
Before Mamba there was S4 (Gu, Goel, Ré 2021). S4's contribution was two-fold:
- It chose $\mathbf{A}$ to be a specific structured matrix — the HIPPO matrix, derived from a theory of optimal polynomial basis projection. HIPPO-LegS guarantees that the state summarizes past inputs as Legendre polynomial coefficients, giving provably good long-range memory.
- It factored $\mathbf{A}$ into a normal-plus-low-rank (NPLR) form so that the convolution kernel can be computed efficiently without ever materializing $\bar{\mathbf{A}}^L$ directly.
With a HIPPO-initialized S4 layer, a model could recall information from tens of thousands of steps ago — something vanilla RNNs and even Transformers with full attention struggled with on the Long Range Arena benchmark.
S4's limitation: $\mathbf{A}$, $\mathbf{B}$, $\mathbf{C}$, $\Delta$ were all input-independent. The model applied the same dynamics to every token. For language, where different tokens need the model to remember or forget different things, that's too rigid.
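For reference, the HIPPO-LegS transition matrix itself is only a few lines to write down — a sketch based on the LegS formula from the HiPPO paper (check the paper for the exact normalization conventions used in any given S4 implementation):

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS transition matrix, per Gu et al. (2020):

    A[n, k] = -sqrt((2n+1)(2k+1))  if n > k
              -(n + 1)             if n == k
              0                    if n < k
    """
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    A = -np.sqrt((2 * n + 1) * (2 * k + 1)) * (n > k)
    A -= np.diag(np.arange(1, N + 1)).astype(float)
    return A

A = hippo_legs(4)
# Lower-triangular with a negative diagonal: stable dynamics, and structured
# enough for S4's efficient kernel computation after the NPLR correction.
print(A)
```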
5. Mamba's selectivity
Mamba's move (Gu & Dao, 2023) is deceptively small: make $\mathbf{B}$, $\mathbf{C}$, and $\Delta$ functions of the input. The $\mathbf{A}$ matrix stays structured and shared, but at each time step:

$$\mathbf{B}_t = W_B\, u_t, \qquad \mathbf{C}_t = W_C\, u_t, \qquad \Delta_t = \text{softplus}(W_\Delta\, u_t)$$

Selective SSM
- $\mathbf{B}_t, \mathbf{C}_t$
- Now depend on the current input $u_t$ through learned linear projections. Different tokens push different information into the hidden state and read different information out.
- $\Delta_t$
- A per-token, learned step size. Large $\Delta_t$ means "this token matters — update the state a lot"; small $\Delta_t$ means "this token is filler — barely update." This is how selection happens.
- $\text{softplus}$
- A smooth positive function that keeps $\Delta_t > 0$ (step size must be positive for the ODE interpretation).
- $\mathbf{A}$
- Still structured and input-independent. Selectivity runs through $\mathbf{B}, \mathbf{C}, \Delta$, not $\mathbf{A}$.
Why it's a big deal: With input-dependent $\Delta_t$, the model can effectively "pay attention" — it learns to take big steps on important tokens and small steps on irrelevant ones, compressing or expanding time. Structurally, the resulting layer is still a recurrence, so inference is $O(L)$ and the KV cache is gone. But because $\mathbf{B}, \mathbf{C}$ now vary per step, the layer can copy/delete specific pieces of the state — recovering much of the flexibility that pure SSMs lacked.
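A few numbers make the gating concrete. For a single diagonal state dimension with $a = -1$ and ZOH discretization, varying $\Delta_t$ trades off keeping the old state against writing the new input:

```python
import numpy as np

# How Δ_t gates the update, for one diagonal state dimension with a = -1.
# A_bar = exp(Δ·a): near 1 => keep the old state, near 0 => forget it.
a = -1.0

for dt in (0.01, 0.5, 5.0):
    a_bar = np.exp(dt * a)           # decay applied to the previous state
    b_bar = (a_bar - 1.0) / a        # weight on the new input (ZOH, with B = 1)
    print(f"Δ={dt:>4}: keep old state ×{a_bar:.3f}, write new input ×{b_bar:.3f}")
```

A tiny $\Delta_t$ leaves the state almost untouched (the token is skipped over); a large $\Delta_t$ nearly resets the state onto the current input (the token dominates). Selection is exactly this per-token choice.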
The cost of selectivity: the recurrence is no longer a pure convolution, so the nice FFT trick doesn't apply directly. Mamba solves this with a hardware-aware parallel scan implemented as a custom CUDA kernel. The scan computes the recurrence in $O(L)$ work and $O(\log L)$ depth, streaming state through GPU SRAM so it never touches slow HBM. This implementation detail is what made Mamba actually fast in practice, not just fast on paper.
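The reason a parallel scan applies at all is that the selective recurrence is associative. A minimal sketch of that structure (sequential here for clarity — the `combine` operator is what the CUDA kernel evaluates in $O(\log L)$ parallel depth):

```python
import numpy as np

# The selective recurrence h_t = a_t * h_{t-1} + b_t (diagonal A; here
# a_t stands in for exp(Δ_t·A) and b_t for B_bar_t · u_t) is associative:
# composing (a1, b1) then (a2, b2) gives (a1*a2, a2*b1 + b2).

def combine(left, right):
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def scan(pairs):
    """Inclusive scan over (a_t, b_t) pairs; returns every prefix state h_t."""
    out, acc = [], None
    for p in pairs:
        acc = p if acc is None else combine(acc, p)
        out.append(acc[1])           # h_t lives in the 'b' slot of the running pair
    return np.array(out)

rng = np.random.default_rng(0)
L = 32
a = rng.uniform(0.5, 1.0, L)         # per-step decays
b = rng.standard_normal(L)           # per-step inputs

h_scan = scan(list(zip(a, b)))

# Reference: the plain sequential recurrence gives the same states.
h, h_ref = 0.0, []
for t in range(L):
    h = a[t] * h + b[t]
    h_ref.append(h)

assert np.allclose(h_scan, np.array(h_ref))
```

Because `combine` is associative, the pairs can be reduced in a balanced tree rather than left to right — that, plus keeping the running state in SRAM, is the hardware-aware scan.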
By 2024, Mamba-2 had tightened the connection between SSMs and linear attention (showing they're two views of the same object), and hybrids like Jamba (AI21) and Zamba landed, interleaving Mamba blocks with a few attention layers to get Transformer-quality recall plus SSM-quality speed.
6. Interactive SSM demo
(Interactive demo) A 1-D linear SSM running on an input pulse train, with a 2-dimensional hidden state. Adjusting $\mathbf{A}$'s eigenvalue and the step $\Delta$ changes the state response: stable decay vs. oscillation vs. blow-up. Small $\Delta$ = slow integration; large $\Delta$ = fast response.
7. Source code
Three minimal sketches, scaled down: the plain recurrence, its convolutional unrolling, and a selective SSM. Not Mamba — just the core recurrence with per-step B, C, Δ.
import numpy as np

def ssm_recurrent(u, A, B, C, delta):
    # Discrete SSM: h_t = A_bar * h_{t-1} + B_bar * u_t,  y_t = C · h_t
    # u: (T,) input sequence; A, B, C: (N,) — A is the diagonal of the
    # transition matrix, so all dynamics are elementwise.
    A_bar = np.exp(delta * A)          # ZOH: exp is elementwise for diagonal A
    B_bar = ((A_bar - 1) / A) * B      # ZOH input matrix, elementwise
    N = A.shape[0]
    h = np.zeros(N)
    y = np.zeros_like(u)
    for t in range(len(u)):
        h = A_bar * h + B_bar * u[t]
        y[t] = C @ h
    return y  # O(T * N) — inference-friendly
import numpy as np

def ssm_convolutional(u, A, B, C, delta, L):
    # Unroll the recurrence into a single global convolution kernel K:
    # K[t] = C · A_bar^t · B_bar, for t = 0..L-1 (diagonal A: powers are elementwise)
    A_bar = np.exp(delta * A)
    B_bar = ((A_bar - 1) / A) * B
    K = np.zeros(L)
    Ak = np.ones_like(A_bar)           # running elementwise power A_bar^t
    for t in range(L):
        K[t] = C @ (Ak * B_bar)
        Ak = A_bar * Ak
    # y = causal conv(u, K). In practice use FFT for O(L log L).
    return np.convolve(u, K)[:L]  # training-friendly, parallel
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d, self.N = d_model, d_state
        # Structured diagonal A — log-parameterized so eigenvalues stay negative.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))
        # Input-dependent B, C, Δ come from linear projections of u:
        self.to_B = nn.Linear(d_model, d_state, bias=False)
        self.to_C = nn.Linear(d_model, d_state, bias=False)
        self.to_dt = nn.Linear(d_model, 1, bias=True)

    def forward(self, u):  # u: (B, T, d)
        B, T, d = u.shape
        A = -torch.exp(self.A_log)             # (N,) negative eigenvalues
        Bp = self.to_B(u)                      # (B, T, N)
        Cp = self.to_C(u)                      # (B, T, N)
        dt = F.softplus(self.to_dt(u))         # (B, T, 1) — the selective Δ
        # Discretize per timestep (diagonal A makes this elementwise)
        A_bar = torch.exp(dt * A)              # (B, T, N)
        B_bar = dt * Bp                        # (B, T, N) — Euler approximation of ZOH
        # Sequential scan. Real Mamba uses a parallel associative scan in CUDA.
        h = torch.zeros(B, self.N, device=u.device)
        outs = []
        for t in range(T):
            # Toy simplification: collapse the d channels to a scalar input per
            # step; real Mamba keeps a separate state per channel.
            h = A_bar[:, t] * h + B_bar[:, t] * u[:, t].mean(-1, keepdim=True)
            outs.append((Cp[:, t] * h).sum(-1))
        return torch.stack(outs, dim=1)        # (B, T)
8. Summary
- State Space Models represent sequences through a continuous-time linear ODE discretized into a linear recurrence. Old control theory, new wrapping paper.
- The dual view — recurrence for inference, convolution for training — is the defining advantage: $O(L)$ generation without a growing KV cache.
- S4 (2021) made SSMs practical with the HIPPO initialization and a structured $\mathbf{A}$ that captures long-range dependencies.
- Mamba (2023) added selectivity: $\mathbf{B}$, $\mathbf{C}$, and the step $\Delta$ become input-dependent, letting the model "pay attention" without attention.
- Hybrids such as Jamba, Samba, and Zamba mix a small number of attention layers with many SSM (Mamba/Mamba-2) layers to combine recall quality with long-context efficiency.
- The main reason to care: sub-quadratic scaling. If you want million-token context (DNA, video, codebase), SSM-based models are the current practical answer.
Further reading
- Gu, Goel & Ré (2021) — Efficiently Modeling Long Sequences with Structured State Spaces (S4).
- Gu, Johnson et al. (2020) — HiPPO: Recurrent Memory with Optimal Polynomial Projections.
- Gu & Dao (2023) — Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
- Dao & Gu (2024) — Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (Mamba-2).
- Lieber et al. (2024) — Jamba: A Hybrid Transformer-Mamba Language Model.