State Space Models & Mamba

Attention is O(n²) in the sequence length. For long contexts (genomes, audio, hour-long documents) that's a wall. State Space Models are a family of linear-time sequence mixers that, as of 2024, match or beat similarly sized Transformers on language modeling while scaling gracefully to million-token contexts.

Prereq: RNNs, linear algebra · Time to read: ~24 min · Interactive figures: 1 · Code: NumPy, PyTorch

1. Why not just Transformers?

The Transformer is the most important architecture of the last decade, but its self-attention has a brutal scaling problem. Attention cost is $O(L^2)$ in sequence length $L$: every token attends to every other token. At $L = 2{,}000$ this is fine. At $L = 100{,}000$ a naively materialized attention matrix eats all your GPU memory. At $L = 1{,}000{,}000$ (a DNA strand, a long video, a codebase) it is impractical on current hardware.
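A quick back-of-envelope check makes the quadratic wall concrete. The sketch below assumes the full $L \times L$ score matrix is stored in fp16 (2 bytes per entry) for a single head in a single layer; `score_matrix_bytes` is a hypothetical helper. Memory-efficient kernels such as FlashAttention avoid storing this matrix, but the compute still grows as $L^2$.

```python
# Back-of-envelope: bytes needed to store one dense L x L attention score matrix.
# Assumption: fp16/bf16 scores (2 bytes each), one head, one layer, no tiling tricks.
BYTES_PER_ENTRY = 2

def score_matrix_bytes(L: int, bytes_per_entry: int = BYTES_PER_ENTRY) -> int:
    """Memory for a dense L x L matrix of attention scores (hypothetical helper)."""
    return L * L * bytes_per_entry

for L in (2_000, 100_000, 1_000_000):
    gb = score_matrix_bytes(L) / 1e9
    print(f"L = {L:>9,}: {gb:>12,.3f} GB per head per layer")
```

For $L = 100{,}000$ that is $2 \times 10^{10}$ bytes (about 20 GB) for one head in one layer; at $L = 10^6$ it is about 2 TB.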

A lot of approaches tried to patch this — sparse attention, linear attention, Longformer, Performer, RWKV. Most worked on small tasks but lost to full attention on language. Then in 2023, Mamba (Gu & Dao) landed. It's a state space model with a trick called selectivity, and it was the first sub-quadratic architecture to seriously compete with Transformers on language modeling loss, all while running in $O(L)$ time and $O(1)$ memory per step at inference.

Mamba is the end of a longer story that runs through S4, HiPPO, and classical control theory. Let's walk through it.

2. The classical state space

A continuous-time linear state space model represents a system with a hidden state $\mathbf{h}(t)$ that evolves according to a simple linear ODE:

$$\frac{d\mathbf{h}(t)}{dt} = \mathbf{A} \, \mathbf{h}(t) + \mathbf{B} \, u(t)$$
$$y(t) = \mathbf{C} \, \mathbf{h}(t) + \mathbf{D} \, u(t)$$

The continuous state space

$u(t)$
The input signal at time $t$. A scalar here; a multi-channel input (e.g., a token embedding) is handled by running one SSM per channel, as S4 and Mamba do.
$\mathbf{h}(t) \in \mathbb{R}^N$
The hidden state — an $N$-dimensional memory that summarizes everything the system has seen so far. $N$ is typically 16 or 64.
$\mathbf{A} \in \mathbb{R}^{N \times N}$
The state transition matrix: how the state evolves on its own, without new input. If every eigenvalue of $\mathbf{A}$ has a negative real part, the state decays and the system forgets; a positive real part makes it explode; a nonzero imaginary part makes it oscillate.
$\mathbf{B} \in \mathbb{R}^{N \times 1}$
The input-to-state matrix: how much new input pushes into each state dimension.
$\mathbf{C} \in \mathbb{R}^{1 \times N}$
The state-to-output matrix: a linear projection from state to observable output.
$\mathbf{D} \in \mathbb{R}$
A direct feedthrough (skip connection) from input to output. Often omitted from the analysis, since $\mathbf{D}\,u$ is just a skip connection that is easy to compute.

Analogy: Picture a spring-mass-damper system. The "state" is [position, velocity]. The ODE says: position changes based on velocity; velocity changes based on position (the spring), velocity (the damper), and the external force. $\mathbf{A}$ encodes the spring constant and damping; $\mathbf{B}$ encodes how much an external push accelerates the mass. Any linear dynamical system (a circuit, a thermostat, a cruise-control loop) looks like this. SSMs bring this decades-old control-theory formalism into deep learning.
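Writing the analogy out as matrices shows how the eigenvalue rule for $\mathbf{A}$ plays out. This is a minimal sketch assuming mass $m$, spring constant $k$, damping $c$, and Newton's law $m\ddot{x} = -kx - c\dot{x} + F$; the function name `spring_mass_damper` is hypothetical.

```python
import numpy as np

def spring_mass_damper(m: float, k: float, c: float):
    """Continuous-time (A, B, C) for m x'' = -k x - c x' + F, with state h = [x, v]."""
    A = np.array([[0.0, 1.0],
                  [-k / m, -c / m]])   # dx/dt = v;  dv/dt = (-k x - c v) / m
    B = np.array([[0.0],
                  [1.0 / m]])          # the external force F only accelerates the mass
    C = np.array([[1.0, 0.0]])         # observe the position x
    return A, B, C

# Damped spring: eigenvalues have negative real part (decay) and nonzero imaginary part (oscillation).
A, B, C = spring_mass_damper(m=1.0, k=4.0, c=0.4)
print(np.linalg.eigvals(A))

# Undamped spring: purely imaginary eigenvalues, so it oscillates forever without decaying.
A0, _, _ = spring_mass_damper(m=1.0, k=4.0, c=0.0)
print(np.linalg.eigvals(A0))
```

With $c = 0.4$ the characteristic polynomial is $\lambda^2 + 0.4\lambda + 4$, so $\lambda = -0.2 \pm 1.99i$: a decaying oscillation. Setting $c = 0$ gives $\lambda = \pm 2i$.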

3. Discretization

Neural networks work in discrete steps, not continuous time. To use an SSM in a network, you discretize the ODE with a step size $\Delta$. Using the zero-order hold (ZOH) rule, which assumes the input is held constant over each step:

$$\bar{\mathbf{A}} = \exp(\Delta \mathbf{A}), \quad \bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}$$

The discrete SSM is then a linear recurrence:

$$\mathbf{h}_t = \bar{\mathbf{A}} \, \mathbf{h}_{t-1} + \bar{\mathbf{B}} \, u_t, \qquad y_t = \mathbf{C} \, \mathbf{h}_t$$

Discrete SSM

$\Delta$
The time step: how much "real time" each discrete step represents. In S4 it is a learned parameter; in Mamba (Section 5) it becomes data-dependent.
$\bar{\mathbf{A}}, \bar{\mathbf{B}}$
The discretized versions of $\mathbf{A}$ and $\mathbf{B}$. The bar distinguishes them from their continuous-time originals.
$\mathbf{h}_t$
The discrete hidden state at step $t$. Now the formula reads exactly like an RNN.
$y_t$
The output at step $t$ — a linear readout of the state.
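When $\mathbf{A}$ is diagonal, the ZOH formulas reduce to elementwise operations, since $(\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I})\,\Delta \mathbf{B}$ becomes $(\bar{A}_n - 1)/A_n \cdot B_n$ per state dimension. The sketch below assumes a diagonal $\mathbf{A}$ passed as a vector of its diagonal entries and compares ZOH with forward Euler; `zoh_diagonal` and `euler_diagonal` are hypothetical names.

```python
import numpy as np

def zoh_diagonal(A_diag, B, delta):
    """Zero-order hold for diagonal A (given as its diagonal entries), elementwise:
    A_bar = exp(delta * A),  B_bar = (A_bar - 1) / A * B."""
    A_bar = np.exp(delta * A_diag)
    B_bar = (A_bar - 1.0) / A_diag * B
    return A_bar, B_bar

def euler_diagonal(A_diag, B, delta):
    """Forward Euler, for comparison: A_bar = 1 + delta * A,  B_bar = delta * B."""
    return 1.0 + delta * A_diag, delta * B

A_diag = np.array([-1.0, -4.0])   # negative entries: a stable continuous system
B = np.array([1.0, 1.0])

A_bar_zoh, B_bar_zoh = zoh_diagonal(A_diag, B, 0.5)
A_bar_eul, B_bar_eul = euler_diagonal(A_diag, B, 0.5)
print("ZOH  :", A_bar_zoh, B_bar_zoh)
print("Euler:", A_bar_eul, B_bar_eul)
```

For tiny $\Delta$ the two rules agree, but at $\Delta = 0.5$ the fast mode $A = -4$ gives an Euler $\bar{A} = -1$ (a sign-flipping state that never decays), while ZOH gives $e^{-2} \approx 0.135$. ZOH keeps $|\bar{A}| < 1$ for any negative $A$ and any positive $\Delta$, which matters once $\Delta$ is learned.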

The dual view: This recurrence is algebraically identical to a very specific linear RNN, with no nonlinearity between steps and a structured transition matrix. That sounds like a weakness, but because $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ are the same at every step (the system is time-invariant), it unlocks two superpowers: (1) you can unroll the recurrence into a global convolution and compute all outputs in parallel in $O(L \log L)$ via FFT; (2) at inference time you can switch back to the recurrence and decode one step at a time with $O(1)$ memory.
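Unrolling the recurrence from a zero initial state makes the convolution explicit, provided $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ do not change with $t$:

$$\mathbf{h}_t = \sum_{k=0}^{t} \bar{\mathbf{A}}^{k} \bar{\mathbf{B}} \, u_{t-k}, \qquad y_t = \mathbf{C}\,\mathbf{h}_t = \sum_{k=0}^{t} \bar{K}_k \, u_{t-k}, \qquad \bar{K}_k = \mathbf{C} \bar{\mathbf{A}}^{k} \bar{\mathbf{B}}$$

$$\bar{\mathbf{K}} = \left(\mathbf{C}\bar{\mathbf{B}},\; \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}},\; \dots,\; \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}}\right), \qquad y = \bar{\mathbf{K}} * u$$

The kernel $\bar{\mathbf{K}}$ depends only on the parameters, never on the input, so it can be computed once and applied to the whole sequence with a single causal convolution.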

This dual view ("same model, two compute modes") is the whole point of SSMs. During training you compute the layer as a convolution over the whole sequence (fast, parallel). During autoregressive inference you run it as a recurrence (tiny state, no KV cache). Transformers can't do this: softmax attention has no fixed-size recurrent form, so inference must keep an ever-growing KV cache.
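The two compute modes can be checked against each other numerically. This sketch assumes a diagonal $\mathbf{A}$ with ZOH discretization and random parameters; `causal_conv_fft`, `ssm_kernel`, and `ssm_scan` are hypothetical names. The FFT path zero-pads to at least $2L - 1$ so that circular convolution equals linear convolution, then keeps the first $L$ (causal) outputs.

```python
import numpy as np

def causal_conv_fft(u: np.ndarray, K: np.ndarray) -> np.ndarray:
    """y[t] = sum_{k=0}^{t} K[k] * u[t-k], computed with FFTs in O(L log L)."""
    L = len(u)
    n = 2 * L  # zero-pad so circular convolution equals linear convolution
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(K, n=n), n=n)
    return y[:L]  # keep the causal part

def ssm_kernel(A_diag, B, C, delta, L):
    """K[k] = C . (A_bar^k * B_bar) for diagonal A with ZOH discretization."""
    A_bar = np.exp(delta * A_diag)
    B_bar = (A_bar - 1.0) / A_diag * B
    powers = A_bar[None, :] ** np.arange(L)[:, None]   # (L, N): row k holds A_bar^k
    return powers @ (B_bar * C)                         # (L,)

def ssm_scan(u, A_diag, B, C, delta):
    """Reference: the same SSM run as a step-by-step recurrence."""
    A_bar = np.exp(delta * A_diag)
    B_bar = (A_bar - 1.0) / A_diag * B
    h, y = np.zeros_like(A_diag), np.empty(len(u))
    for t, u_t in enumerate(u):
        h = A_bar * h + B_bar * u_t
        y[t] = C @ h
    return y

rng = np.random.default_rng(0)
N, L = 4, 256
A_diag = -rng.uniform(0.5, 2.0, N)          # negative diagonal: stable
B, C = rng.normal(size=N), rng.normal(size=N)
u = rng.normal(size=L)

y_conv = causal_conv_fft(u, ssm_kernel(A_diag, B, C, 0.1, L))   # training mode
y_rec = ssm_scan(u, A_diag, B, C, 0.1)                          # inference mode
print(np.max(np.abs(y_conv - y_rec)))
```

The printed difference should be at floating-point round-off level: same model, two compute modes.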

4. S4 — HIPPO and structured state spaces

Before Mamba there was S4 (Gu, Goel, Ré 2021). S4's contribution was two-fold:

  1. It chose $\mathbf{A}$ to be a specific structured matrix: the HiPPO matrix, derived from a theory of optimal polynomial projection. With HiPPO-LegS, the state holds the coefficients of a Legendre-polynomial approximation of the entire input history, giving provably good long-range memory.
  2. It rewrote $\mathbf{A}$ in normal-plus-low-rank (NPLR) form so that the convolution kernel can be computed efficiently without explicitly computing the powers $\bar{\mathbf{A}}^k$.
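Both points can be inspected directly. The sketch below builds the HiPPO-LegS matrices as written in the S4 paper (with the minus sign that makes the state decay) and checks the normal-plus-low-rank claim using $P_n = \sqrt{n + 1/2}$; the helper `hippo_legs` is a hypothetical name, and $N = 8$ is an arbitrary choice.

```python
import numpy as np

def hippo_legs(N: int):
    """HiPPO-LegS (A, B):
    A[n, k] = -sqrt(2n+1) sqrt(2k+1) if n > k,  -(n+1) if n == k,  0 if n < k
    B[n]    =  sqrt(2n+1)
    """
    q = np.sqrt(2 * np.arange(N) + 1.0)
    A = -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, N + 1, dtype=float))
    B = q.copy()
    return A, B

N = 8
A, B = hippo_legs(N)

# A is lower triangular, so its eigenvalues are its diagonal entries -1, -2, ..., -N.
print(np.diag(A))

# Normal plus low-rank: A = S - P P^T with P[n] = sqrt(n + 1/2).
P = np.sqrt(np.arange(N) + 0.5)
S = A + np.outer(P, P)
skew = S + 0.5 * np.eye(N)            # S = -I/2 + (skew-symmetric matrix)
print(np.allclose(skew, -skew.T))     # True: so S is normal
print(np.allclose(S @ S.T, S.T @ S))  # True: S commutes with its transpose
```

Adding the rank-1 term $PP^\top$ turns the badly non-normal HiPPO matrix into $-\tfrac{1}{2}\mathbf{I}$ plus a skew-symmetric matrix, which is normal and therefore diagonalizable by a well-conditioned unitary matrix. The diagonal $-(n+1)$ is also the initialization used for $\mathbf{A}$ in the `SelectiveSSM` sketch in Section 7.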

With a HiPPO-initialized S4 layer, a model could recall information from more than 16,000 steps ago: S4 was the first model to solve the Long Range Arena's Path-X task (sequence length 16,384), on which vanilla RNNs and even Transformers with full attention performed no better than chance.

S4's limitation: $\mathbf{A}$, $\mathbf{B}$, $\mathbf{C}$, $\Delta$ were all input-independent. The model applied the same dynamics to every token. For language, where different tokens need the model to remember or forget different things, that's too rigid.

5. Mamba's selectivity

Mamba's move (Gu & Dao, 2023) is deceptively small: make $\mathbf{B}$, $\mathbf{C}$, and $\Delta$ functions of the input. The $\mathbf{A}$ matrix stays structured and shared, but at each time step:

$$\mathbf{B}_t = \text{Linear}_B(u_t), \quad \mathbf{C}_t = \text{Linear}_C(u_t), \quad \Delta_t = \text{softplus}(\text{Linear}_\Delta(u_t))$$

Selective SSM

$\mathbf{B}_t, \mathbf{C}_t$
Now depend on the current input $u_t$ through learned linear projections. Different tokens push different information into the hidden state and read different information out.
$\Delta_t$
A per-token, learned step size. Large $\Delta_t$ means "this token matters — update the state a lot"; small $\Delta_t$ means "this token is filler — barely update." This is how selection happens.
$\text{softplus}$
A smooth positive function that keeps $\Delta_t > 0$ (step size must be positive for the ODE interpretation).
$\mathbf{A}$
Still structured and input-independent. Selectivity runs through $\mathbf{B}, \mathbf{C}, \Delta$, not $\mathbf{A}$.
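A one-dimensional example shows why $\Delta_t$ acts as a gate. With $A = -1$ and $B = 1$, the ZOH update becomes a convex combination, $h_t = e^{-\Delta_t} h_{t-1} + (1 - e^{-\Delta_t})\, u_t$: as $\Delta_t \to 0$ the old state is kept and the token is ignored, and as $\Delta_t \to \infty$ the state is overwritten by the current token. This is a minimal sketch under those scalar assumptions; `selective_step` is a hypothetical name.

```python
import numpy as np

def selective_step(h_prev: float, u_t: float, delta_t: float, A: float = -1.0) -> float:
    """One ZOH step of a scalar SSM with B = 1: h_t = exp(dt*A) h_{t-1} + (exp(dt*A) - 1)/A * u_t."""
    a_bar = np.exp(delta_t * A)
    b_bar = (a_bar - 1.0) / A        # ZOH input weight with B = 1
    return a_bar * h_prev + b_bar * u_t

h_prev, u_t = 5.0, -3.0
for delta_t in (1e-3, 0.5, 10.0):
    # small delta: h stays near 5.0 (filler token); large delta: h jumps to about -3.0 (important token)
    print(f"delta = {delta_t:>6}: h_t = {selective_step(h_prev, u_t, delta_t):.4f}")
```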

Why it's a big deal: With input-dependent $\Delta_t$, the model can effectively "pay attention": it learns to take big steps on important tokens and small steps on irrelevant ones, compressing or expanding time. Structurally, the layer is still a recurrence, so inference costs $O(1)$ time and memory per token ($O(L)$ overall) and the KV cache is gone. And because $\mathbf{B}, \mathbf{C}$ now vary per step, the layer can copy specific inputs into the state or erase them, recovering much of the flexibility that time-invariant SSMs like S4 lacked.

The cost of selectivity: once $\bar{\mathbf{A}}_t$, $\bar{\mathbf{B}}_t$, and $\mathbf{C}_t$ change with every step, the model is no longer time-invariant, so there is no single convolution kernel and the FFT trick doesn't apply. Mamba solves this with a hardware-aware parallel scan implemented as a custom CUDA kernel. The scan computes the recurrence in $O(L)$ work and $O(\log L)$ depth, and it keeps the expanded hidden state in fast on-chip SRAM so that state is never written to slower HBM. This implementation detail is what made Mamba actually fast in practice, not just fast on paper.
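The scan works because each step $h_t = a_t h_{t-1} + b_t$ is an affine map, and composing affine maps is associative. For a diagonal selective SSM, each state dimension is such a scalar recurrence with $a_t = \bar{A}_t$ and $b_t = \bar{B}_t u_t$. The sketch below is a simple Hillis-Steele scan in NumPy, not Mamba's CUDA kernel: it has $O(\log L)$ rounds, but it does $O(L \log L)$ total work, whereas a work-efficient (Blelloch-style) scan brings that down to $O(L)$. The names `combine` and `parallel_scan` are hypothetical.

```python
import numpy as np

def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def parallel_scan(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Inclusive scan computing h_t = a_t * h_{t-1} + b_t with h_{-1} = 0.
    Hillis-Steele style: ceil(log2 L) rounds; each round is one vectorized combine,
    i.e. every position could be updated in parallel."""
    acc_a = np.array(a, dtype=float)   # product of a over the window each position covers
    acc_b = np.array(b, dtype=float)   # affine offset over that window
    offset = 1
    while offset < len(acc_a):
        # Position t absorbs the window ending at t - offset (positions < offset are already final).
        new_a, new_b = combine((acc_a[:-offset], acc_b[:-offset]), (acc_a[offset:], acc_b[offset:]))
        acc_a[offset:], acc_b[offset:] = new_a, new_b
        offset *= 2
    return acc_b  # with h_{-1} = 0, the accumulated offset equals h_t

rng = np.random.default_rng(0)
a, b = rng.uniform(0.5, 1.0, size=1000), rng.normal(size=1000)
h, h_seq = 0.0, np.empty_like(b)
for t in range(len(b)):                          # sequential reference
    h = a[t] * h + b[t]
    h_seq[t] = h
print(np.allclose(parallel_scan(a, b), h_seq))   # True
```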

In 2024, Mamba-2 (Dao & Gu) tightened the connection between SSMs and linear attention, showing they're two views of the same object. The same year brought hybrids such as Jamba (AI21) and Zamba, which interleave Mamba blocks with a few attention layers to get attention-quality recall with most of the SSM's speed.
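The "two views" claim can be made concrete under Mamba-2's restriction that each $\mathbf{A}_t$ is a scalar $a_t$ times the identity. Unrolling the recurrence then gives $y = M x$ with a lower-triangular $T \times T$ matrix $M_{ts} = (\mathbf{C}_t \cdot \mathbf{B}_s) \prod_{k=s+1}^{t} a_k$: causal linear attention with "queries" $\mathbf{C}$, "keys" $\mathbf{B}$, and "values" $x$, plus a decay mask. This is a sketch for a single channel with $a_t > 0$ (assumed so the cumulative product can be taken in log space); the function names are hypothetical.

```python
import numpy as np

def ssm_recurrence(x, a, B, C):
    """h_t = a_t h_{t-1} + B_t x_t (a_t scalar, h_t in R^N);  y_t = C_t . h_t."""
    h, y = np.zeros(B.shape[1]), np.empty(len(x))
    for t in range(len(x)):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssm_as_masked_attention(x, a, B, C):
    """Same map as one T x T matrix: M[t, s] = (C_t . B_s) * prod_{k=s+1..t} a_k for s <= t."""
    T = len(x)
    log_cum = np.cumsum(np.log(a))                       # assumes a_t > 0
    mask = np.tril(np.ones((T, T), dtype=bool))          # causal: only s <= t
    diff = log_cum[:, None] - log_cum[None, :]           # sum_{k=s+1..t} log a_k
    decay = np.exp(np.where(mask, diff, -np.inf))        # 0 above the diagonal
    M = (C @ B.T) * decay                                # "attention scores" C_t . B_s, decayed
    return M @ x

rng = np.random.default_rng(0)
T, N = 64, 4
x = rng.normal(size=T)
a = rng.uniform(0.8, 1.0, size=T)                        # input-dependent decay per step
B, C = rng.normal(size=(T, N)), rng.normal(size=(T, N))  # input-dependent B_t, C_t
print(np.allclose(ssm_recurrence(x, a, B, C), ssm_as_masked_attention(x, a, B, C)))  # True
```

The recurrent form costs $O(TN)$ and the matrix form costs $O(T^2 N)$; the quadratic form maps well onto matrix-multiply hardware for short chunks, which is the trade-off Mamba-2 exploits.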

6. Interactive SSM demo

A single-input, single-output linear SSM with a 2-dimensional hidden state, driven by an input pulse train. Adjust the eigenvalue of $\mathbf{A}$ and the step $\Delta$ to see the state response change: stable decay vs. oscillation vs. blow-up. Small $\Delta$ = slow integration; large $\Delta$ = fast response.

[Interactive figure: sliders for A (decay rate, default -1.00) and Δ (default 0.50) shape the state response to an input pulse train.]
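A static version of the same experiment can be run without the sliders. This sketch assumes a 2-D continuous $\mathbf{A}$ with eigenvalues $a \pm i\omega$, written as $\begin{pmatrix} a & -\omega \\ \omega & a \end{pmatrix}$, so that $\exp(\Delta \mathbf{A})$ has a closed form (a growth or decay factor times a rotation); the input enters the first state coordinate, which is also what we read out. The function names are hypothetical.

```python
import numpy as np

def rotation_decay_ssm(a: float, w: float, delta: float):
    """ZOH-discretized 2-D SSM whose continuous A has eigenvalues a +/- i*w."""
    A = np.array([[a, -w], [w, a]])
    B = np.array([1.0, 0.0])
    # exp(delta * A) = exp(a * delta) * (rotation by w * delta), since A = a*I + w*J with J^2 = -I
    c, s = np.cos(w * delta), np.sin(w * delta)
    A_bar = np.exp(a * delta) * np.array([[c, -s], [s, c]])
    # ZOH: B_bar = A^{-1} (exp(delta*A) - I) B; needs a and w not both zero
    B_bar = np.linalg.solve(A, (A_bar - np.eye(2)) @ B)
    return A_bar, B_bar

def run_pulse_train(a, w, delta, T=200, period=50):
    A_bar, B_bar = rotation_decay_ssm(a, w, delta)
    C = np.array([1.0, 0.0])                 # read out the first state coordinate
    u = np.zeros(T)
    u[::period] = 1.0                        # unit pulse every `period` steps
    h, y = np.zeros(2), np.empty(T)
    for t in range(T):
        h = A_bar @ h + B_bar * u[t]
        y[t] = C @ h
    return y

for name, (a, w) in {"decay": (-1.0, 0.0), "oscillate": (-0.1, 2.0), "blow-up": (0.2, 0.0)}.items():
    y = run_pulse_train(a, w, delta=0.5)
    print(f"{name:>9}: max |y| = {np.abs(y).max():.3g}, final y = {y[-1]:.3g}")
```

Negative real part: each pulse decays before the next arrives. Nonzero imaginary part: the response swings through positive and negative values. Positive real part: the response grows without bound.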

7. Source code

Three sketches, scaled down: the same time-invariant SSM in its recurrent and convolutional views, then a minimal selective SSM with per-step B, C, Δ. Not Mamba itself: just the core recurrence.
import numpy as np

def ssm_recurrent(u, A, B, C, delta):
    # Discrete SSM with diagonal A: h_t = A_bar * h_{t-1} + B_bar * u_t,  y_t = C . h_t
    # u: (T,) input; A: (N,) diagonal of A (negative entries); B: (N,) input map; C: (N,) readout
    A_bar = np.exp(delta * A)                        # ZOH: exp(ΔA), elementwise because A is diagonal
    B_bar = (A_bar - 1) / A * B                      # ZOH: (ΔA)^{-1}(exp(ΔA) - I) ΔB, elementwise

    h = np.zeros(A.shape[0])
    y = np.zeros(len(u))
    for t in range(len(u)):
        h = A_bar * h + B_bar * u[t]                 # elementwise update (diagonal A_bar)
        y[t] = C @ h
    return y                                         # O(T * N) time, O(N) state: inference-friendly

import numpy as np

def ssm_convolutional(u, A, B, C, delta, L):
    # Unroll the recurrence into a single global convolution kernel K.
    # K[k] = C · A_bar^k · B_bar, for k = 0..L-1 (A diagonal, so A_bar^k is elementwise)
    A_bar = np.exp(delta * A)
    B_bar = (A_bar - 1) / A * B

    K = np.zeros(L)
    Ak = np.ones(A.shape[0])                         # diagonal of A_bar^0 = I
    for k in range(L):
        K[k] = C @ (Ak * B_bar)
        Ak = A_bar * Ak                              # diagonal of A_bar^(k+1)
    # y = causal conv(u, K): y[t] = sum_k K[k] u[t-k]. In practice use FFT for O(L log L).
    return np.convolve(u, K)[:L]                     # training-friendly, parallel

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    # One diagonal SSM per channel (as in S4/Mamba), with B, C, Δ computed from the input.
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d, self.N = d_model, d_state
        # Structured diagonal A per channel, log-parameterized so A = -exp(A_log) stays negative.
        # Initialized to A_n = -(n+1), n = 0..N-1.
        A_init = torch.arange(1, d_state + 1).float().repeat(d_model, 1)   # (d, N)
        self.A_log = nn.Parameter(torch.log(A_init))
        # Input-dependent B, C, Δ come from linear projections of u:
        self.to_B  = nn.Linear(d_model, d_state, bias=False)   # B_t, shared by all channels
        self.to_C  = nn.Linear(d_model, d_state, bias=False)   # C_t, shared by all channels
        self.to_dt = nn.Linear(d_model, d_model, bias=True)    # Δ_t, one step size per channel

    def forward(self, u):                              # u: (batch, T, d)
        batch, T, d = u.shape
        A = -torch.exp(self.A_log)                     # (d, N), negative real eigenvalues
        Bp = self.to_B(u)                              # (batch, T, N)
        Cp = self.to_C(u)                              # (batch, T, N)
        dt = F.softplus(self.to_dt(u))                 # (batch, T, d), the selective Δ > 0

        # Discretize per timestep and channel (diagonal A makes this elementwise)
        A_bar = torch.exp(dt.unsqueeze(-1) * A)        # (batch, T, d, N), ZOH for A
        B_bar = dt.unsqueeze(-1) * Bp.unsqueeze(2)     # (batch, T, d, N), simplified (Euler) rule for B

        # Sequential scan. Real Mamba uses a hardware-aware parallel scan in CUDA.
        h = torch.zeros(batch, d, self.N, device=u.device, dtype=u.dtype)
        outs = []
        for t in range(T):
            h = A_bar[:, t] * h + B_bar[:, t] * u[:, t].unsqueeze(-1)   # (batch, d, N)
            outs.append((h * Cp[:, t].unsqueeze(1)).sum(-1))            # y_t = C_t · h_t: (batch, d)
        return torch.stack(outs, dim=1)                # (batch, T, d)

8. Summary

Further reading

  • Gu, Goel & Ré (2021). Efficiently Modeling Long Sequences with Structured State Spaces (S4).
  • Gu, Dao, Ermon, Rudra & Ré (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections.
  • Gu & Dao (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
  • Dao & Gu (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (Mamba-2).
  • Lieber et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model.