Mixture of Experts

A trillion-parameter model that runs as cheaply as a 30B one. The routing trick that lets you decouple "how much a model knows" from "how much compute each token costs" — and the load-balancing problem that keeps everyone up at night.

Prereq: MLPs, softmax · Time to read: ~22 min · Interactive figures: 1 · Code: NumPy, PyTorch, JAX

1. The core idea

A standard Transformer feed-forward block takes a token's hidden state $\mathbf{h} \in \mathbb{R}^d$ and runs it through one big MLP:

$$\text{FFN}(\mathbf{h}) = W_2 \, \phi(W_1 \mathbf{h})$$

In a dense model like Llama-3-70B, $W_1$ and $W_2$ are roughly $4d \times d$ and $d \times 4d$. Every token touches every weight. Capacity and compute are the same number — if you want the model to know twice as much, it costs twice as much to run.

Mixture of Experts breaks that coupling. Replace the single FFN with $N$ parallel FFNs, each called an expert:

$$E_1(\mathbf{h}),\, E_2(\mathbf{h}),\, \dots,\, E_N(\mathbf{h})$$

Then a tiny router (also called a gate) decides, per token, which experts to send it to. Crucially, the token is sent to only $k \ll N$ of them — typically $k = 1$ or $2$. The output is a weighted sum of the active experts' outputs:

$$\text{MoE}(\mathbf{h}) = \sum_{i \in \text{top-}k} g_i(\mathbf{h}) \, E_i(\mathbf{h})$$

The MoE equation explained

$\mathbf{h}$
The token's hidden state arriving at this MoE block — a vector of dimension $d$ (say 4096 for a large model).
$N$
Total number of experts in the layer. DeepSeek-V3 has 256, Mixtral has 8, GShard originally had thousands.
$k$
How many experts each token gets routed to. Small: 1 (Switch Transformer), 2 (Mixtral), or 8 (DeepSeek). The rest of the experts do nothing for this token.
$E_i(\mathbf{h})$
Expert $i$'s FFN applied to the token. Architecturally identical to a dense FFN; the experts differ only in their learned weights.
$g_i(\mathbf{h})$
The router's weight for expert $i$ on this token. A small positive number. Comes from a softmax over the router logits.
top-$k$
The indices of the $k$ experts with the largest router scores. Everyone else gets weight 0 and isn't computed.

Analogy A hospital with 256 specialists but only 8 doctors on duty per patient. The triage nurse (the router) looks at the incoming patient (token) for a moment, then picks the 8 most relevant specialists. The other 248 are paid to exist but sit in their offices. The hospital has the knowledge of 256 specialists at the cost of 8 consults per patient.

Total parameters scale with $N$. Compute per token scales with $k$. Those are now two independent knobs, and you can crank parameters without cranking cost. A 671B-parameter DeepSeek-V3 activates only ~37B parameters per token. The "effective model size" is 671B; the serving cost is closer to a 37B dense model.
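A quick back-of-envelope calculation makes the two knobs concrete. The sizes below are illustrative (a hypothetical 8-expert layer with $d = 4096$ and an inner dimension of 11008), not any particular model's:

```python
# Total vs. active FFN parameters for a hypothetical MoE layer.
# d, dff, N, k are illustrative values, not a real model config.
d, dff = 4096, 11008        # hidden size, expert inner size
N, k = 8, 2                 # total experts, active experts per token

per_expert = 2 * d * dff    # W1 (dff x d) + W2 (d x dff)
total_params  = N * per_expert   # scales with N (capacity)
active_params = k * per_expert   # scales with k (compute per token)

print(f"total:  {total_params / 1e9:.2f}B")   # 0.72B parameters stored
print(f"active: {active_params / 1e9:.2f}B")  # 0.18B touched per token
```

Double the experts and `total_params` doubles while `active_params` stays put — that's the whole decoupling in two lines of arithmetic.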

2. The router (gating network)

The router is the simplest part of the system — it's a single linear layer followed by a softmax:

$$\mathbf{s} = W_r \mathbf{h}, \quad \mathbf{g} = \text{softmax}(\mathbf{s})$$

The router network

$W_r \in \mathbb{R}^{N \times d}$
The router's only learnable parameters. One row per expert. A tiny matrix — for a 256-expert, $d=4096$ layer, that's ~1M params vs. billions in the experts themselves.
$\mathbf{s} \in \mathbb{R}^N$
The raw router logits — one score per expert, indicating how much the router "wants" each expert to handle this token.
$\mathbf{g} \in \Delta^N$
The softmax-normalized probabilities, summing to 1. $g_i$ is the fraction of this token's output that expert $i$ contributes.
$\text{softmax}$
$\text{softmax}(\mathbf{s})_i = \frac{e^{s_i}}{\sum_j e^{s_j}}$. Turns raw scores into a probability distribution.

Analogy The router is like asking "how much does this patient look like something for each of my 256 specialists?" and getting a percentage answer. We then pick the top $k$ percentages and send the patient to those.

Before taking the top-$k$, the router logits are usually perturbed with noise during training. That noise is what lets gradients flow to experts that would otherwise never be selected — without it, the argmax in top-$k$ kills learning for everyone not currently winning.
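A minimal sketch of that noisy selection, loosely in the style of Shazeer et al. (2017)'s noisy top-k gating. The fixed noise scale here is a simplification (the original learns a per-expert noise width), and the helper name is made up:

```python
import numpy as np

rng = np.random.default_rng(0)

def noisy_topk_indices(logits, k, noise_std=1.0, training=True):
    """Pick top-k experts from router logits, with exploration noise.

    During training, Gaussian noise perturbs the logits so experts
    just below the top-k cutoff occasionally get selected (and thus
    receive gradients). At inference the noise is switched off.
    """
    if training:
        logits = logits + noise_std * rng.standard_normal(logits.shape)
    return np.argsort(-logits, axis=-1)[..., :k]

logits = np.array([2.0, 1.9, 1.5, -3.0])   # experts 0, 1, 2 nearly tied
picks = [tuple(sorted(noisy_topk_indices(logits, k=2)))
         for _ in range(1000)]
# The winning pair varies across draws, so the near-tied experts all
# see gradients; expert 3 almost never wins and stays cold.
```

Without the noise, the same call returns `(0, 1)` every single time and experts 2 and 3 never get a gradient signal.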

3. Top-k sparsification

After softmax, we keep only the $k$ largest entries of $\mathbf{g}$ and zero out the rest:

$$\tilde{g}_i = \begin{cases} g_i & \text{if } i \in \text{top-}k(\mathbf{g}) \\ 0 & \text{otherwise} \end{cases}$$

You usually renormalize the kept weights so they sum to 1 again, then compute only those $k$ experts:

$$\text{MoE}(\mathbf{h}) = \sum_{i=1}^{N} \tilde{g}_i \, E_i(\mathbf{h}) = \sum_{i \in \text{top-}k} \frac{g_i}{\sum_{j \in \text{top-}k} g_j} \, E_i(\mathbf{h})$$

The key insight: the sum looks like it's over all $N$ experts, but $N - k$ terms are zero. The GPU never touches those weights. Forward and backward cost scale with $k$, not $N$.
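A tiny worked example of the sparsify-and-renormalize step, with made-up numbers for a 4-expert layer:

```python
import numpy as np

g = np.array([0.50, 0.30, 0.15, 0.05])   # softmax output, N = 4
k = 2

keep = np.argsort(-g)[:k]                 # top-k indices: {0, 1}
g_sparse = np.zeros_like(g)
g_sparse[keep] = g[keep] / g[keep].sum()  # renormalize the kept mass
# g_sparse == [0.625, 0.375, 0.0, 0.0]: sums to 1 again; experts 2
# and 3 have weight zero and are never computed.
```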

4. Interactive router demo

Below is a live 4-expert MoE layer with $k = 2$. Click on one of the tokens to "send" it through the router. You'll see the router logits, the softmax probabilities, and which two experts wake up to process it. The four experts are colored differently — expert 0 likes "warm" topics, 1 likes "cold", 2 likes "large numbers", 3 likes "small numbers". They've been pretrained in this simulation.

Click a token:

Pick a token above. Active experts light up; dim experts are skipped (zero compute).

5. Load balancing

The dangerous failure mode of MoE is expert collapse: the router learns to send every token to the same 2 or 3 experts. Those experts get all the gradient updates, so they improve fastest, so the router sends them even more tokens, so they improve even faster. It's a winner-takes-all death spiral. The other $N - 3$ experts are dead weight in the checkpoint.

The fix is an auxiliary loss that punishes uneven assignment. Let $f_i$ be the fraction of tokens routed to expert $i$ across a batch, and $P_i$ the mean router probability for expert $i$. Then the load-balancing loss is:

$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \sum_{i=1}^{N} f_i \cdot P_i$$

Balance loss explained

$f_i$
The fraction of tokens in this batch that chose expert $i$ in their top-$k$. A histogram over experts.
$P_i$
The average of the router's softmax output for expert $i$ across the batch — how much the router wanted to use expert $i$ on average.
$\alpha$
A small coefficient, typically 0.01 or 0.001. The main loss is the language-modeling loss; this is a gentle nudge on top.
$N$
The scaling factor $N$ ensures the loss has magnitude ~1 when assignment is uniform (each $f_i = P_i = 1/N$), so it doesn't depend on the number of experts.

Why it works The loss is minimized when both $f_i$ and $P_i$ are uniform across experts. If expert 3 is getting way more tokens than average, the loss pushes the router to lower its score for expert 3 on the next batch. It's a pressure valve: any expert that starts to dominate gets penalized until routing evens out.
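A minimal sketch of computing this loss from a batch of routing decisions. One detail to flag: with $k > 1$ this version counts each of a token's $k$ assignments, normalizing by $B \cdot k$ (for $k = 1$ it reduces exactly to "fraction of tokens"):

```python
import numpy as np

def balance_loss(probs, topk_idx, alpha=0.01):
    """Switch-style load-balancing loss.

    probs:    (B, N) router softmax outputs for a batch of tokens
    topk_idx: (B, k) expert indices each token was routed to
    """
    B, N = probs.shape
    k = topk_idx.shape[1]
    # f_i: fraction of (token, slot) assignments expert i received
    counts = np.bincount(topk_idx.ravel(), minlength=N)
    f = counts / (B * k)
    # P_i: mean router probability for expert i across the batch
    P = probs.mean(axis=0)
    return alpha * N * np.sum(f * P)

probs = np.full((8, 4), 0.25)               # router perfectly uniform
topk_idx = np.array([[0, 1], [2, 3]] * 4)   # routing perfectly uniform
loss = balance_loss(probs, topk_idx)        # equals alpha at balance
```

At perfect balance every $f_i = P_i = 1/N$, so the loss is exactly $\alpha$ — the $N$ scaling at work.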

DeepSeek-V3 (2024) replaced this auxiliary loss with a bias term added to each expert's router score, adjusted step-by-step: overloaded experts get their bias lowered, underloaded experts get their bias raised. The bias affects only which experts are selected; the gating weights in the output sum still come from the unbiased scores, so balance is maintained without an extra loss term tugging on the model's gradients. This is called auxiliary-loss-free load balancing and is a notable 2024 refinement.
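A sketch of the bias-adjustment idea. The sign-based update below is a simplification of DeepSeek-V3's rule, and the variable names (including the step size `gamma`) are my own:

```python
import numpy as np

def update_bias(bias, counts, gamma=0.001):
    """Auxiliary-loss-free balancing: nudge per-expert routing biases.

    bias:   (N,) additive bias on router scores, used only when
            selecting the top-k (the output weights ignore it)
    counts: (N,) how many tokens each expert received this step
    """
    mean_load = counts.mean()
    # Overloaded experts get pushed down, underloaded pushed up.
    return bias - gamma * np.sign(counts - mean_load)

bias = np.zeros(4)
counts = np.array([10, 2, 4, 4])   # expert 0 is hogging tokens
bias = update_bias(bias, counts)
# bias -> [-0.001, 0.001, 0.001, 0.001]: expert 0 is now slightly
# harder to select, the others slightly easier.
```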

6. Systems challenges

On paper, MoE is an algorithmic trick. In practice, it's a systems problem, because experts are sharded across many GPUs and routing means every forward pass has to shuffle tokens between them: all-to-all communication to dispatch each token to its experts and collect the results, capacity limits so no single expert (and its GPU) gets swamped, and grouped kernels that batch the tokens bound for the same expert into one efficient matmul.
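To see the shape of the problem, here's the dispatch step in miniature: sort tokens by destination expert so each expert sees one contiguous batch. In a real system this grouping happens across GPUs via all-to-all; this is a single-device sketch with $k = 1$ for simplicity:

```python
import numpy as np

# 6 tokens, each routed to one expert (k = 1 for simplicity)
token_ids = np.arange(6)
expert_of = np.array([2, 0, 2, 1, 0, 1])     # router's choice per token

order   = np.argsort(expert_of, kind="stable")  # group by expert
grouped = token_ids[order]                      # [1, 4, 3, 5, 0, 2]
# Tokens for expert 0 come first, then 1, then 2: each expert now
# gets a dense, contiguous batch for one big matmul. The inverse
# permutation np.argsort(order) restores the original token order.
```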

7. Source code

A minimal MoE layer in three frameworks. No load balancing, no all-to-all — just the core logic of "route, pick top-k, compute, combine."

mixture of experts · forward pass
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(h, W_r, experts, k=2):
    # h: (B, d) — batch of tokens
    # W_r: (N, d) — router weights for N experts
    # experts: list of N callables h -> h
    B, d = h.shape
    N = W_r.shape[0]

    logits = h @ W_r.T                      # (B, N)
    probs  = softmax(logits)               # (B, N)

    # Top-k indices per token
    topk_idx = np.argsort(-probs, axis=1)[:, :k]   # (B, k)
    topk_prob = np.take_along_axis(probs, topk_idx, axis=1)
    topk_prob = topk_prob / topk_prob.sum(axis=1, keepdims=True)

    out = np.zeros_like(h)
    for b in range(B):
        for j in range(k):
            i = topk_idx[b, j]
            out[b] += topk_prob[b, j] * experts[i](h[b])
    return out
mixture of experts · PyTorch
import torch, torch.nn as nn, torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, d, dff):
        super().__init__()
        self.w1 = nn.Linear(d, dff, bias=False)
        self.w2 = nn.Linear(dff, d, bias=False)
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)))

class MoE(nn.Module):
    def __init__(self, d=4096, dff=11008, N=8, k=2):
        super().__init__()
        self.N, self.k = N, k
        self.router  = nn.Linear(d, N, bias=False)
        self.experts = nn.ModuleList([Expert(d, dff) for _ in range(N)])

    def forward(self, x):                 # x: (B, T, d)
        B, T, d = x.shape
        x_flat  = x.reshape(-1, d)          # (BT, d)
        logits  = self.router(x_flat)        # (BT, N)
        probs   = F.softmax(logits, dim=-1)
        topk_p, topk_i = probs.topk(self.k, dim=-1)
        topk_p  = topk_p / topk_p.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(x_flat)
        for e in range(self.N):
            mask = (topk_i == e).any(dim=-1)     # (BT,)
            if not mask.any(): continue
            w = (topk_p * (topk_i == e).float()).sum(-1)
            out[mask] += w[mask, None] * self.experts[e](x_flat[mask])
        return out.reshape(B, T, d)
mixture of experts · JAX (Flax)
import jax, jax.numpy as jnp
from flax import linen as nn

class MoE(nn.Module):
    d: int; dff: int; N: int; k: int

    @nn.compact
    def __call__(self, x):            # (B, T, d)
        logits = nn.Dense(self.N, use_bias=False, name="router")(x)
        probs  = jax.nn.softmax(logits)

        topk_p, topk_i = jax.lax.top_k(probs, self.k)
        topk_p = topk_p / topk_p.sum(-1, keepdims=True)

        # Compute all experts then gather — naive but vectorizes cleanly.
        def expert(i, x):
            h = nn.Dense(self.dff, name=f"e{i}_w1")(x)
            return nn.Dense(self.d, name=f"e{i}_w2")(jax.nn.silu(h))
        stacked = jnp.stack([expert(i, x) for i in range(self.N)], -2)
        # stacked: (B, T, N, d) — pick top-k and weight them

        onehot = jax.nn.one_hot(topk_i, self.N)          # (B, T, k, N)
        weights = (topk_p[..., None] * onehot).sum(-2) # (B, T, N)
        return (stacked * weights[..., None]).sum(-2)

A real production MoE layer adds: noise injection during training, load-balancing loss or DeepSeek-style bias adjustment, expert capacity limits with token dropping, all-to-all communication primitives for expert parallelism, and grouped matmul kernels that process tokens bound for the same expert together. The above is the conceptual kernel, not the shippable one.

8. A brief history

  • Jacobs et al. (1991) introduce adaptive mixtures of local experts: several small networks plus a gating network, trained jointly, each expert specializing on part of the input space.
  • Shazeer et al. (2017) revive the idea at modern scale, inserting sparsely-gated MoE layers into LSTMs with noisy top-k gating and auxiliary load-balancing losses.
  • GShard (2020) brings sparse MoE to Transformers, sharding thousands of experts across accelerators.
  • Fedus, Zoph & Shazeer (2021) simplify to $k = 1$ with the Switch Transformer and reach trillion-parameter scale.
  • Mixtral 8x7B (2023) popularizes open-weight MoE with 8 experts and $k = 2$.
  • DeepSeek-V3 (2024) pushes to 256 fine-grained experts per layer with auxiliary-loss-free load balancing.

9. Summary

An MoE layer replaces the single dense FFN with $N$ expert FFNs and a tiny linear router that sends each token to only the top-$k$ of them. Total parameters scale with $N$, compute per token scales with $k$, so capacity and cost become independent knobs. The catch is keeping the router honest: without load balancing (an auxiliary loss, or DeepSeek-style bias adjustment) routing collapses onto a few experts, and without careful systems work (all-to-all dispatch, capacity limits, grouped matmuls) the sparsity on paper never turns into speed in practice.

Further reading

  • Jacobs et al. (1991) — Adaptive Mixtures of Local Experts.
  • Shazeer et al. (2017) — Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.
  • Fedus, Zoph & Shazeer (2021) — Switch Transformers: Scaling to Trillion Parameter Models.
  • Zhou et al. (2022) — Mixture-of-Experts with Expert Choice Routing.
  • DeepSeek-AI (2024) — DeepSeek-V3 Technical Report.
NEXT UP
→ State Space Models & Mamba

MoE makes width cheap. State Space Models make length cheap. Together they're the two main axes on which 2024–26 models escape Transformer limits.