Mixture of Experts
A trillion-parameter model that runs as cheaply as a 30B one. The routing trick that lets you decouple "how much a model knows" from "how much compute each token costs" — and the load-balancing problem that keeps everyone up at night.
1. The core idea
A standard Transformer feed-forward block takes a token's hidden state $\mathbf{h} \in \mathbb{R}^d$ and runs it through one big MLP:

$$\text{FFN}(\mathbf{h}) = \sigma(\mathbf{h} W_1)\, W_2$$

where $\sigma$ is the nonlinearity (GeLU, SiLU, etc.).
In a dense model like Llama-3-70B, $W_1$ and $W_2$ are roughly $d \times 4d$ and $4d \times d$. Every token touches every weight. Capacity and compute are the same number — if you want the model to know twice as much, it costs twice as much to run.
Mixture of Experts breaks that coupling. Replace the single FFN with $N$ parallel FFNs, each called an expert:

$$E_1, E_2, \ldots, E_N, \qquad E_i(\mathbf{h}) = \sigma(\mathbf{h} W_1^{(i)})\, W_2^{(i)}$$
Then a tiny router (also called a gate) decides, per token, which experts to send it to. Crucially, the token is sent to only $k \ll N$ of them — typically $k = 1$ or $2$. The output is a weighted sum of the active experts' outputs:

$$\mathbf{y} = \sum_{i \in \text{top-}k(\mathbf{g})} g_i(\mathbf{h})\, E_i(\mathbf{h})$$
The MoE equation explained
- $\mathbf{h}$: The token's hidden state arriving at this MoE block, a vector of dimension $d$ (say 4096 for a large model).
- $N$: Total number of experts in the layer. DeepSeek-V3 has 256, Mixtral has 8, GShard originally had thousands.
- $k$: How many experts each token gets routed to. Small: 1 (Switch Transformer), 2 (Mixtral), or 8 (DeepSeek-V3). The rest of the experts do nothing for this token.
- $E_i(\mathbf{h})$: Expert $i$'s FFN applied to the token. Architecturally identical to a dense FFN; the experts differ only in their learned weights.
- $g_i(\mathbf{h})$: The router's weight for expert $i$ on this token. A small positive number that comes from a softmax over the router logits.
- top-$k$: The indices of the $k$ experts with the largest router scores. Every other expert gets weight 0 and isn't computed.
Analogy: A hospital with 256 specialists, of whom only 8 see any given patient. The triage nurse (the router) looks at the incoming patient (the token) for a moment, then picks the 8 most relevant specialists. The other 248 are paid to exist but sit in their offices. The hospital has the knowledge of 256 specialists at the cost of 8 consults per patient.
Total parameters scale with $N$. Compute per token scales with $k$. Those are now two independent knobs, and you can crank parameters without cranking cost. A 671B-parameter DeepSeek-V3 activates only ~37B parameters per token. The "effective model size" is 671B; the serving cost is closer to a 37B dense model.
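To see the two knobs concretely, here is a back-of-envelope parameter count for one hypothetical MoE layer (all sizes are illustrative, not any specific model's configuration):

```python
# Back-of-envelope count for one MoE layer (illustrative sizes only).
d, d_ff = 4096, 14336       # hidden size, expert FFN inner size
N, k = 64, 2                # total experts, active experts per token

params_per_expert = 2 * d * d_ff        # W1 (d x d_ff) + W2 (d_ff x d)
total_params = N * params_per_expert    # what the checkpoint stores
active_params = k * params_per_expert   # what each token pays for

print(f"total:  {total_params / 1e9:.1f}B parameters")
print(f"active: {active_params / 1e9:.2f}B parameters")
print(f"ratio:  {total_params // active_params}x")
```

Doubling $N$ doubles the first number and leaves the second untouched; doubling $k$ does the opposite.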
2. The router (gating network)
The router is the simplest part of the system — it's a single linear layer followed by a softmax:

$$\mathbf{s} = W_r \mathbf{h}, \qquad \mathbf{g} = \text{softmax}(\mathbf{s})$$
The router network
- $W_r \in \mathbb{R}^{N \times d}$: The router's only learnable parameters, one row per expert. A tiny matrix — for a 256-expert, $d=4096$ layer, that's ~1M params vs. billions in the experts themselves.
- $\mathbf{s} \in \mathbb{R}^N$: The raw router logits — one score per expert, indicating how much the router "wants" each expert to handle this token.
- $\mathbf{g} \in \Delta^N$: The softmax-normalized probabilities, summing to 1. $g_i$ is the fraction of this token's output that expert $i$ contributes.
- $\text{softmax}$: $\text{softmax}(\mathbf{s})_i = \frac{e^{s_i}}{\sum_j e^{s_j}}$, which turns raw scores into a probability distribution.
Analogy: The router asks "how much does this patient look like a case for each of my 256 specialists?" and gets a percentage for each. We then pick the top $k$ percentages and send the patient to those specialists.
Before taking the top-$k$, the router logits are usually perturbed with noise during training. That noise is what lets gradients flow to experts that would otherwise never be selected — without it, the argmax in top-$k$ kills learning for everyone not currently winning.
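A minimal sketch of that noisy gate, assuming plain Gaussian noise on the logits (Shazeer et al. used a learned, per-expert noise scale; this simplified version only shows the mechanism):

```python
import numpy as np

def noisy_topk_gate(h, W_r, k, noise_std=1.0, training=True, rng=None):
    """Router logits perturbed with Gaussian noise before top-k.
    During training the noise occasionally promotes a losing expert
    into the top-k, so it receives gradient; at inference the noise
    is switched off and routing is deterministic."""
    rng = rng or np.random.default_rng(0)
    logits = W_r @ h                       # (N,) raw router scores
    if training:
        logits = logits + noise_std * rng.standard_normal(logits.shape)
    topk = np.argsort(-logits)[:k]         # winning expert indices
    probs = np.exp(logits - logits.max())  # stable softmax
    probs /= probs.sum()
    gates = np.zeros_like(probs)
    gates[topk] = probs[topk] / probs[topk].sum()  # renormalized weights
    return gates, topk
```

With `training=False` the gate reduces to the deterministic top-$k$ of section 3.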
3. Top-k sparsification
After softmax, we keep only the $k$ largest entries of $\mathbf{g}$ and zero out the rest:

$$\tilde{g}_i = \begin{cases} g_i & \text{if } i \in \text{top-}k(\mathbf{g}) \\ 0 & \text{otherwise} \end{cases}$$
You usually renormalize the kept weights so they sum to 1 again, then compute only those $k$ experts:

$$\mathbf{y} = \sum_{i \in \text{top-}k(\mathbf{g})} \frac{g_i}{\sum_{j \in \text{top-}k(\mathbf{g})} g_j}\, E_i(\mathbf{h})$$
The key insight: the sum looks like it's over all $N$ experts, but $N - k$ terms are zero. The GPU never touches those weights. Forward and backward cost scale with $k$, not $N$.
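A tiny worked example with made-up numbers ($N = 4$, $k = 2$):

```python
import numpy as np

g = np.array([0.5, 0.3, 0.15, 0.05])  # softmax output for N = 4 experts
k = 2
topk = np.argsort(-g)[:k]             # the two largest entries
kept = g[topk] / g[topk].sum()        # renormalize: 0.5/0.8, 0.3/0.8
print(topk)                           # [0 1]
print(kept)                           # [0.625 0.375]
# Experts 2 and 3 get weight 0 and are never computed.
```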
4. Interactive router demo
[Interactive demo: a live 4-expert MoE layer with $k = 2$. Clicking a token shows the router logits, the softmax probabilities, and which two experts wake up to process it — expert 0 prefers "warm" topics, 1 "cold", 2 "large numbers", 3 "small numbers". Active experts light up; dim experts are skipped (zero compute). Not reproducible in this static version.]
5. Load balancing
The dangerous failure mode of MoE is expert collapse: the router learns to send every token to the same 2 or 3 experts. Those experts get all the gradient updates, so they improve fastest, so the router sends them even more tokens, so they improve even faster. It's a winner-takes-all death spiral. The other $N - 3$ experts are dead weight in the checkpoint.
The fix is an auxiliary loss that punishes uneven assignment. Let $f_i$ be the fraction of tokens routed to expert $i$ across a batch, and $P_i$ the mean router probability for expert $i$. Then the load-balancing loss is:

$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i P_i$$
Balance loss explained
- $f_i$: The fraction of tokens in this batch that chose expert $i$ in their top-$k$ — a histogram over experts.
- $P_i$: The average of the router's softmax output for expert $i$ across the batch — how much the router wanted to use expert $i$ on average.
- $\alpha$: A small coefficient, typically 0.01 or 0.001. The main loss is the language-modeling loss; this is a gentle nudge on top.
- $N$: The scaling factor $N$ ensures the loss has magnitude ~1 (before multiplying by $\alpha$) when assignment is uniform (each $f_i = P_i = 1/N$), so it doesn't depend on the number of experts.
Why it works: The loss is minimized when both $f_i$ and $P_i$ are uniform across experts. If expert 3 is getting way more tokens than average, the loss pushes the router to lower its score for expert 3 on the next batch. It's a pressure valve: any expert that starts to dominate gets penalized until routing evens out.
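The loss takes only a few lines to compute. A sketch, assuming `probs` is the router softmax for a batch of $T$ tokens and `topk_idx` holds each token's chosen expert indices:

```python
import numpy as np

def load_balance_loss(probs, topk_idx, alpha=0.01):
    """Switch-Transformer-style auxiliary loss: alpha * N * sum_i f_i * P_i.
    probs:    (T, N) router softmax output for T tokens
    topk_idx: (T, k) chosen expert indices per token"""
    T, N = probs.shape
    k = topk_idx.shape[1]
    # f_i: fraction of routed slots that went to expert i
    counts = np.bincount(topk_idx.ravel(), minlength=N)
    f = counts / (T * k)
    # P_i: mean router probability for expert i
    P = probs.mean(axis=0)
    return alpha * N * np.sum(f * P)
```

When routing is perfectly uniform the loss equals $\alpha$; any skew in either $f$ or $P$ raises it.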
DeepSeek-V3 (2024) replaced this auxiliary loss with a bias term added to each expert's router score, adjusted step by step: overloaded experts get their bias lowered, underloaded experts get their bias raised. The bias affects only which experts are selected; the gating weights that combine expert outputs still come from the original, unbiased scores. Balance is maintained without an extra loss term distorting the main objective. This is called auxiliary-loss-free load balancing and is a notable 2024 refinement.
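The bias-adjustment idea can be sketched in a few lines (an illustrative sign-based update, not DeepSeek-V3's exact procedure; the name `update_expert_bias` is ours):

```python
import numpy as np

def update_expert_bias(bias, tokens_per_expert, lr=0.001):
    """After each batch, nudge each expert's routing bias down if it was
    overloaded and up if underloaded. The bias is added to router scores
    for top-k *selection* only; the gating weights that combine expert
    outputs still come from the unbiased scores."""
    mean_load = tokens_per_expert.mean()
    # +lr for underloaded experts, -lr for overloaded ones
    return bias + lr * np.sign(mean_load - tokens_per_expert)
```

Over many steps the biases drift until every expert receives roughly its fair share, with no gradient ever flowing through the balancing mechanism.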
6. Systems challenges
On paper, MoE is an algorithmic trick. In practice, it's a systems nightmare because experts are distributed across many GPUs, and routing means every forward pass has to shuffle tokens between them:
- All-to-all communication. Each GPU holds a handful of experts but sees tokens destined for all of them. Before the expert compute starts, every GPU has to send its non-local tokens to the GPUs that host the target experts. Then after compute, the reverse — send outputs back. This is two all_to_all collectives per MoE layer. On a fast interconnect (NVLink, InfiniBand) it's manageable; on slower networks it dominates.
- Expert capacity and dropping. If you don't know ahead of time how many tokens will land on each expert, you can't pre-allocate GPU memory for them. The standard fix is a capacity factor: each expert accepts at most $C = \text{capacity\_factor} \cdot k \cdot T / N$ tokens per batch, where $T$ is the batch-times-sequence token count. Overflow tokens are dropped (the expert output is zero, and they rely on the residual stream alone). Capacity factor is usually 1.25 or 2.0.
- Expert parallelism. The dimension along which experts are partitioned. A 64-expert layer split across 8 GPUs puts 8 experts on each. Combined with data, tensor, and pipeline parallelism, this is a 4-way partitioning scheme that DeepSpeed, Megatron, and JAX/Pallas all spend a lot of code handling.
- Memory vs. compute. Weights are large ($N \times$ a full FFN's worth) but compute is small. Inference-time MoE serving is bandwidth-bound, not compute-bound — exactly the opposite of dense models. This is why vLLM and SGLang added special MoE kernels (Marlin, ExLlamaV2-MoE).
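The capacity formula above, as a concrete computation (the batch sizes are made up):

```python
import math

def expert_capacity(tokens, N, k, capacity_factor=1.25):
    """Max tokens per expert per batch: C = cf * k * T / N, rounded up.
    tokens = batch_size * sequence_length."""
    return math.ceil(capacity_factor * k * tokens / N)

# e.g. 4 sequences of 2048 tokens, 64 experts, top-2 routing
T = 4 * 2048
print(expert_capacity(T, N=64, k=2, capacity_factor=1.25))  # 320
```

Perfectly balanced routing would put $kT/N = 256$ tokens on each expert; the 1.25 factor leaves headroom for moderate imbalance before drops begin.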
7. Source code
A minimal MoE layer in three frameworks. No load balancing, no all-to-all — just the core logic of "route, pick top-k, compute, combine."
NumPy:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(h, W_r, experts, k=2):
    # h: (B, d) — batch of tokens
    # W_r: (N, d) — router weights for N experts
    # experts: list of N callables h -> h
    B, d = h.shape
    N = W_r.shape[0]
    logits = h @ W_r.T                            # (B, N)
    probs = softmax(logits)                       # (B, N)
    # Top-k indices per token
    topk_idx = np.argsort(-probs, axis=1)[:, :k]  # (B, k)
    topk_prob = np.take_along_axis(probs, topk_idx, axis=1)
    topk_prob = topk_prob / topk_prob.sum(axis=1, keepdims=True)
    out = np.zeros_like(h)
    for b in range(B):
        for j in range(k):
            i = topk_idx[b, j]
            out[b] += topk_prob[b, j] * experts[i](h[b])
    return out
PyTorch:

import torch, torch.nn as nn, torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, d, dff):
        super().__init__()
        self.w1 = nn.Linear(d, dff, bias=False)
        self.w2 = nn.Linear(dff, d, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)))

class MoE(nn.Module):
    def __init__(self, d=4096, dff=11008, N=8, k=2):
        super().__init__()
        self.N, self.k = N, k
        self.router = nn.Linear(d, N, bias=False)
        self.experts = nn.ModuleList([Expert(d, dff) for _ in range(N)])

    def forward(self, x):  # x: (B, T, d)
        B, T, d = x.shape
        x_flat = x.reshape(-1, d)             # (BT, d)
        logits = self.router(x_flat)          # (BT, N)
        probs = F.softmax(logits, dim=-1)
        topk_p, topk_i = probs.topk(self.k, dim=-1)
        topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x_flat)
        for e in range(self.N):
            mask = (topk_i == e).any(dim=-1)  # (BT,)
            if not mask.any():
                continue
            w = (topk_p * (topk_i == e).float()).sum(-1)
            out[mask] += w[mask, None] * self.experts[e](x_flat[mask])
        return out.reshape(B, T, d)
JAX (Flax):

import jax, jax.numpy as jnp
from flax import linen as nn

class MoE(nn.Module):
    d: int
    dff: int
    N: int
    k: int

    @nn.compact
    def __call__(self, x):  # (B, T, d)
        logits = nn.Dense(self.N, use_bias=False, name="router")(x)
        probs = jax.nn.softmax(logits)
        topk_p, topk_i = jax.lax.top_k(probs, self.k)
        topk_p = topk_p / topk_p.sum(-1, keepdims=True)
        # Compute all experts then gather — naive but vectorizes cleanly.
        def expert(i, x):
            h = nn.Dense(self.dff, name=f"e{i}_w1")(x)
            return nn.Dense(self.d, name=f"e{i}_w2")(jax.nn.silu(h))
        stacked = jnp.stack([expert(i, x) for i in range(self.N)], -2)
        # stacked: (B, T, N, d) — pick top-k and weight them
        onehot = jax.nn.one_hot(topk_i, self.N)         # (B, T, k, N)
        weights = (topk_p[..., None] * onehot).sum(-2)  # (B, T, N)
        return (stacked * weights[..., None]).sum(-2)
A real production MoE layer adds: noise injection during training, load-balancing loss or DeepSeek-style bias adjustment, expert capacity limits with token dropping, all-to-all communication primitives for expert parallelism, and grouped matmul kernels that process tokens bound for the same expert together. The above is the conceptual kernel, not the shippable one.
8. A brief history
- 1991 — Jacobs, Jordan, Nowlan & Hinton publish Adaptive Mixtures of Local Experts. The original gating-network-over-subnets idea. Not sparse. Not for language.
- 2017 — Shazeer et al., Outrageously Large Neural Networks. The first sparse MoE that scales: 137B parameters trained on a 1B-dense compute budget. Showed the decoupling trick works.
- 2020 — GShard (Google): applied sparse MoE to a Transformer for machine translation. Introduced the capacity factor and top-2 routing.
- 2021 — Switch Transformer (Fedus, Zoph, Shazeer): went to top-1 routing. Simpler, faster, and actually stable. 1.6T params. Showed training MoE at scale was a real engineering challenge, not just a math trick.
- 2022 — GLaM, ST-MoE, Expert Choice: cleanup. Expert Choice flipped the problem: instead of each token picking $k$ experts, each expert picks $k$ tokens. Guarantees perfect load balance, but breaks causality for autoregressive decoding.
- 2023 — Mixtral 8×7B (Mistral): first mainstream open-weights MoE. 8 experts, top-2, each expert a full Mistral-7B FFN. Proved MoE works in the wild, at home, on consumer GPUs.
- 2024 — DeepSeek-V2 then V3: fine-grained MoE with 160 and 256 experts, shared "always-on" experts for common knowledge, and auxiliary-loss-free load balancing. The 671B-parameter V3 is the current open reference.
- 2025 — Research on dynamic-$k$ routing, hierarchical routing, and "upcycling" dense models into MoE. The design space is still rapidly expanding.
9. Summary
- MoE replaces one big FFN with $N$ expert FFNs and a lightweight router that sends each token to only $k$ of them.
- The result: total parameters scale with $N$, compute per token scales with $k$. You can have a trillion params and pay for a billion.
- The router is one small linear layer + softmax. The magic is the top-$k$ gate that turns a dense weighted sum into a sparse one.
- The core risk is expert collapse — all traffic going to a few experts. Solved by auxiliary load-balance loss, or 2024-era bias-adjusted routing.
- Systems-wise, MoE layers are dominated by all-to-all communication and capacity planning, not by the matmuls themselves.
- Every frontier open model of 2024–25 — Mixtral, DeepSeek-V2/V3, Qwen-MoE, Jamba — uses sparse MoE. It's the de-facto way to scale.
Further reading
- Jacobs et al. (1991) — Adaptive Mixtures of Local Experts.
- Shazeer et al. (2017) — Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.
- Fedus, Zoph & Shazeer (2021) — Switch Transformers: Scaling to Trillion Parameter Models.
- Zhou et al. (2022) — Mixture-of-Experts with Expert Choice Routing.
- DeepSeek-AI (2024) — DeepSeek-V3 Technical Report.