Edge & On-Device AI
How you fit a language model onto a phone, a laptop, a microcontroller. The four main tricks — quantization, distillation, pruning, speculative decoding — and the math that makes each of them lose less accuracy than you'd expect.
1. Why bring the model to you?
For most of deep learning's history, inference ran on servers. You sent your input to a GPU cluster, a big model crunched on it, and you got an answer back. That worked for vision APIs and chatbots but breaks down for anything where you can't afford the round trip: voice assistants, live camera effects, keyboards, cars, medical devices.
By 2024, the trend has reversed. Phones and laptops are running multi-billion-parameter models locally, and even microcontrollers run small networks on-device. Apple Intelligence runs a ~3B on-device model. Gemini Nano ships with Pixel phones. Llama-3.2-1B and Phi-3-mini are designed to be fast on consumer hardware. Copilot+ PCs have NPUs rated at 40 TOPS for exactly this workload.
Three reasons this matters:
- Latency. A local model responds in ~10 ms; a server round trip is 100–1000 ms. For voice, typing, camera, this difference is the difference between "useful" and "broken."
- Privacy. Your photos, messages, health data never leave the device. Some applications (EU healthcare, financial compliance, enterprise) can't legally use cloud inference.
- Cost. The marginal cost of one more inference on a user's device is zero. At a few billion queries a day, this turns into serious savings.
The challenge is that consumer hardware has ~10× less memory and ~100× less compute than server GPUs. You need aggressive compression that doesn't destroy accuracy. Four techniques carry the load.
2. Quantization
Models are trained in 16- or 32-bit floats. Inference doesn't need that precision. Quantization reduces the number of bits per parameter — from 16 down to 8, 4, or even 2 — and in exchange you get smaller models and faster matmuls (int8 and int4 kernels run 2–4× faster than fp16 on GPUs, and many-fold faster on NPUs).
The basic int8 mapping for a weight tensor $\mathbf{W}$ with range $[w_\text{min}, w_\text{max}]$ is:
Affine int8 quantization

$$q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{w}{s}\right) + z,\; 0,\; 255\right), \qquad \hat{w} = s\,(q - z), \qquad s = \frac{w_\text{max} - w_\text{min}}{255}, \qquad z = \operatorname{round}\!\left(-\frac{w_\text{min}}{s}\right)$$
- $w$
- The original full-precision weight (a 16- or 32-bit float).
- $s$
- The scale — one float per tensor (or per channel). It maps the integer range $[0, 255]$ back to the original weight range.
- $z$
- The zero point — an integer offset so that the int value $z$ represents the float 0. Matters if your weight distribution isn't symmetric around zero.
- $q \in [0, 255]$
- The quantized 8-bit unsigned integer actually stored in memory. Each weight is one byte.
- $\hat{w}$
- The dequantized weight — the approximation we get back from the integer. Typically within ~0.5% of $w$ for int8.
- $255$
- The number of quantization steps spanning the unsigned int8 range — 256 representable levels, so 255 intervals between $w_\text{min}$ and $w_\text{max}$. Signed int8 likewise has 256 levels; int4 has only 16.
Analogy A ruler with 256 tick marks instead of 4 billion. You lose precision, but if your measurement is "about 3.7 cm," 256 ticks is still more than enough. For most weights in a neural net, this is true — the extra bits of fp16 weren't actually carrying information. For some outliers, it isn't, and that's where modern techniques (AWQ, GPTQ, SmoothQuant) earn their keep.
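To make the mapping concrete, here is a minimal round-trip sketch of affine quantization in plain PyTorch. The function names `affine_quantize` and `affine_dequantize` are illustrative, not a library API:

```python
import torch

def affine_quantize(w):
    # One scale and zero point for the whole tensor, targeting uint8 [0, 255]
    w_min, w_max = w.min(), w.max()
    s = (w_max - w_min).clamp(min=1e-8) / 255
    z = torch.round(-w_min / s)          # the int value that represents float 0
    q = torch.round(w / s + z).clamp(0, 255).to(torch.uint8)
    return q, s, z

def affine_dequantize(q, s, z):
    return s * (q.float() - z)

w = torch.randn(4096) * 0.02             # weights at a typical scale
q, s, z = affine_quantize(w)
w_hat = affine_dequantize(q, s, z)
# reconstruction error is bounded by roughly one quantization step (≤ s)
```

Each weight now occupies one byte plus a shared scale and zero point, which is where the 2–4× memory saving over fp16/fp32 comes from.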
Going further — int4 quantization, as in QLoRA, GPTQ, and AWQ — gives you 4× compression over fp16 at a cost of 1–3% accuracy for well-designed schemes. Group quantization (one scale per 64 or 128 weights instead of per tensor) is what makes int4 work at all; with a single per-tensor scale, 16 levels are far too coarse to cover the full dynamic range.
The 2024 frontier: int2 and 1.58-bit (BitNet b1.58, Ma et al.). BitNet replaces weights with $\{-1, 0, 1\}$ and gets surprisingly close to full-precision accuracy — if you train from scratch with quantization in the loop (Quantization-Aware Training) rather than post-hoc quantizing a pretrained model.
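The BitNet b1.58 weight mapping itself is simple enough to sketch — this is the paper's "absmean" quantizer for the forward pass only; the from-scratch quantization-aware training loop is the hard part and is not shown:

```python
import torch

def absmean_ternary(w, eps=1e-8):
    # BitNet b1.58's absmean scheme: scale by the mean absolute value,
    # then round every weight into {-1, 0, 1}
    gamma = w.abs().mean().clamp(min=eps)
    q = torch.round(w / gamma).clamp(-1, 1)
    return q, gamma                      # dequantize as q * gamma

q, gamma = absmean_ternary(torch.randn(8, 8))
# every entry of q is -1, 0, or 1 — storable in ~1.58 bits (log2 of 3 states)
```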
3. Interactive quantization
A histogram of a single layer's weights. Drag the slider to change the bit width; the bins coarsen, and the reconstruction error rises. Notice how int8 is visually indistinguishable from fp32, while int4 starts to show rounding.
Blue: the original fp32 weight distribution. Orange: the quantized approximation. The RMS error grows as bit width shrinks.
4. Knowledge distillation
Instead of compressing a big model, you train a small one to imitate the big one's output distribution. The canonical recipe (Hinton, Vinyals, Dean 2015):
Distillation loss

$$\mathcal{L} = (1 - \lambda)\,\text{CE}(y,\, p_S) + \lambda\,\tau^2\,\text{KL}\!\left(p_T \,\|\, p_S\right), \qquad p_T,\, p_S \text{ computed from logits divided by } \tau$$
- $p_T, p_S$
- The teacher's and student's output distributions (softmax over the logits). The teacher is big and frozen; the student is small and being trained.
- $\text{CE}(y, p_S)$
- Standard cross-entropy on the ground-truth label $y$. Keeps the student grounded in real labels.
- $\text{KL}(p_T \| p_S)$
- KL divergence between teacher and student distributions. This is the "imitation" term: the student learns to match the teacher's full probability vector, not just its top-1 answer.
- $\tau$
- Temperature. Both models' logits are divided by $\tau$ before the softmax. Larger $\tau$ softens the distributions, making the "wrong class" probabilities more visible. Typically 2–4.
- $\lambda$
- The mixing coefficient. Often 0.5 — half real labels, half teacher.
- $\tau^2$
- A scaling factor on the KL term. Compensates for the fact that softening the distributions with $\tau$ shrinks the gradient magnitude by $1/\tau^2$.
Why it helps The teacher's distribution contains "dark knowledge" — information about relative similarities between classes that a one-hot label throws away. "This is a cat" is less informative than "this is 70% cat, 20% lynx, 5% dog, 5% other." The student learns these similarities and generalizes better. Distilled models often match their teachers to within 1–2% while being 4–10× smaller.
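You can see the temperature's role directly. A toy example with made-up cat/lynx/dog/other logits (the numbers are purely illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([5.0, 3.5, 1.0, 0.5])   # cat, lynx, dog, other (made up)
p1 = F.softmax(logits / 1.0, dim=-1)          # ≈ [0.80, 0.18, 0.01, 0.01]
p3 = F.softmax(logits / 3.0, dim=-1)          # ≈ [0.48, 0.29, 0.13, 0.11]
# At τ=1 the distribution is nearly one-hot; at τ=3 the near-miss classes
# (lynx!) carry visible probability — that's the signal the student imitates.
```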
Modern on-device LLMs — Gemini Nano, Phi-3-mini, Llama-3.2-1B — are all distilled from larger siblings. The distillation happens on a mix of the teacher's outputs on real data and synthetic data generated by the teacher. For a 1B student trained on a 70B teacher, the results are within 3–5 points on most benchmarks at a fraction of the compute.
5. Pruning
A surprising fact about trained neural networks: most of their weights are near zero and contribute little to the output. Pruning zeroes out the small ones, turning a dense matrix into a sparse one. Two regimes:
- Unstructured pruning. Individually set the smallest 50–90% of weights to zero. You get big model-size savings, but most hardware can't actually run sparse matmul faster than dense, so you need specialized kernels (cuSPARSE, NVIDIA's 2:4 sparsity) to see speedups.
- Structured pruning. Remove whole rows, channels, attention heads, or layers. Keeps the computation dense, so every mainstream GPU benefits. More aggressive accuracy loss, but easier to deploy.
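Both regimes reduce to a few lines for magnitude-based pruning of a single weight matrix. A minimal sketch — `unstructured_prune` and `structured_prune_rows` are illustrative names, and real pipelines re-train after pruning:

```python
import torch

def unstructured_prune(w, sparsity=0.7):
    # Zero out the smallest 70% of weights by absolute value.
    # Same shape afterwards — the savings need sparse storage/kernels.
    k = int(w.numel() * sparsity)
    threshold = w.abs().flatten().kthvalue(k).values
    return w * (w.abs() > threshold)

def structured_prune_rows(w, keep=0.5):
    # Drop whole output rows with the smallest L2 norm. The matrix
    # actually shrinks, so dense kernels get the speedup directly.
    n_keep = int(w.shape[0] * keep)
    idx = w.norm(dim=1).topk(n_keep).indices.sort().values
    return w[idx]

w = torch.randn(64, 64)
sparse = unstructured_prune(w)     # shape (64, 64), ~70% zeros
small = structured_prune_rows(w)   # shape (32, 64): half the rows gone
```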
The Lottery Ticket Hypothesis (Frankle & Carbin, 2019) gave pruning theoretical weight: a randomly initialized network contains, with high probability, a subnetwork that — if you pruned the rest and reset the remaining weights to their original random values — can be trained to match the full network's accuracy. This suggests that big networks are "over-parameterized" primarily to ease optimization, not because the extra capacity is needed for representation.
In 2024, structured pruning is used heavily to make Llama- and Mistral-sized models fit edge budgets (NVIDIA's Minitron line starts from Llama-3.1-8B and prunes it down to 4B or smaller with light re-training). Unstructured pruning is more niche.
6. Speculative decoding
A different kind of inference speedup, orthogonal to compression: speculative decoding (Leviathan et al. 2023; independently Chen et al. 2023). Use a small, fast "draft" model to propose the next several tokens, then verify them all in one parallel pass of the big model. If the big model agrees with the draft, you got several tokens for the price of one forward pass.
The expected speedup is governed by the acceptance rate $\alpha$ — the probability that the draft's token matches what the big model would have sampled. For a budget of $\gamma$ draft tokens per step:
Speculative decoding math

$$\mathbb{E}[\text{accepted}] = \sum_{i=0}^{\gamma} \alpha^i = \frac{1 - \alpha^{\gamma + 1}}{1 - \alpha}$$
- $\alpha$
- Acceptance probability — chance the draft model's proposed token matches the big model's sampled token. Typically 0.6–0.9 for a well-paired draft/target.
- $\gamma$
- The draft budget — how many tokens the small model speculates ahead per step. Usually 4–8.
- $\mathbb{E}[\text{accepted}]$
- The expected number of tokens produced per big-model pass: the draft tokens accepted before the first rejection, plus the one token the target's verification pass always contributes. A truncated geometric series in $\alpha$.
In practice With $\alpha = 0.8$ and $\gamma = 5$, you get $\sim 3.7$ tokens per big-model forward pass on average — roughly a 3× throughput win once the draft model's own cost is counted. The output distribution is provably identical to sampling from the big model alone; only the cost is amortized. Llama.cpp, vLLM, and transformers all ship this now. It's a pure-systems win.
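Plugging numbers into the acceptance math is a one-liner — this is the commonly quoted formula from Leviathan et al., which counts the accepted drafts plus the one token the target's own pass always yields:

```python
def expected_tokens(alpha, gamma):
    # Truncated geometric series: sum of alpha^i for i = 0 .. gamma
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

print(round(expected_tokens(0.8, 5), 2))   # → 3.69
print(round(expected_tokens(0.6, 5), 2))   # → 2.38  (weaker draft, smaller win)
```

The speedup saturates quickly in $\gamma$: once $\alpha^\gamma$ is small, speculating further ahead buys almost nothing.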
7. Source code
Three compression techniques as one-pagers.
import torch

def quantize_int8(w):
    # Symmetric per-tensor int8: map [-|w|_max, |w|_max] → [-127, 127]
    scale = w.abs().max().clamp(min=1e-8) / 127
    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    return q.float() * scale

def int4_group(w, group=128):
    # Group quantization: one scale per group of 128 weights. The trick that
    # keeps a few outliers from blowing up the scale for the whole tensor.
    w = w.reshape(-1, group)                                         # (G, 128)
    scale = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) / 7
    q = torch.round(w / scale).clamp(-7, 7)                          # 4-bit signed
    return q, scale

# To run a quantized linear layer:
#   out = x @ dequantize(q_w, s).T + b   ← compute is still fp, but int4 weights
#                                          are 4× smaller than fp16 (int8: 2×)
# With fused kernels (bitsandbytes, exllama) the matmul runs directly on ints.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      T=3.0, lam=0.5):
    # Hard-label cross-entropy on the student's raw logits
    ce = F.cross_entropy(student_logits, labels)
    # Soft-label KL divergence at temperature T
    s_soft = F.log_softmax(student_logits / T, dim=-1)
    t_soft = F.softmax(teacher_logits / T, dim=-1)
    kl = F.kl_div(s_soft, t_soft, reduction="batchmean") * (T ** 2)
    return (1 - lam) * ce + lam * kl

# Training loop
for batch in loader:
    with torch.no_grad():
        t_logits = teacher(batch["x"])    # frozen teacher
    s_logits = student(batch["x"])
    loss = distillation_loss(s_logits, t_logits, batch["y"])
    optimizer.zero_grad(); loss.backward(); optimizer.step()
import torch

def speculative_decode(target, draft, prompt, gamma=5, max_new=128):
    # target: the big, accurate model; draft: a small fast model.
    # Both map a (1, seq_len) token tensor to (1, seq_len, vocab) logits.
    x = prompt
    while x.shape[-1] - prompt.shape[-1] < max_new:
        # 1. Draft model speculates γ tokens ahead
        draft_tokens, draft_probs = [], []
        dx = x.clone()
        for _ in range(gamma):
            p_d = torch.softmax(draft(dx)[0, -1], -1)
            tok = torch.multinomial(p_d, 1)
            draft_tokens.append(tok)
            draft_probs.append(p_d[tok])
            dx = torch.cat([dx, tok[None]], -1)
        # 2. Verify all γ tokens in one big-model pass; the last row of p_t
        #    is the target's distribution for the position after the drafts
        logits = target(dx)
        p_t = torch.softmax(logits[0, -(gamma + 1):], -1)   # (γ+1, vocab)
        # 3. Accept the longest prefix of drafts that passes rejection sampling
        accepted = 0
        for i, tok in enumerate(draft_tokens):
            if torch.rand(1) < (p_t[i, tok] / draft_probs[i]).clamp(max=1):
                accepted += 1
            else:
                break
        x = dx[..., :x.shape[-1] + accepted]
        # 4. Sample one token from the target at the first unverified position,
        #    so every step yields at least one token. (The exact algorithm
        #    samples the normalized residual max(0, p_t − p_d) on rejection;
        #    using p_t directly is a common simplification in sketches.)
        next_tok = torch.multinomial(p_t[accepted], 1)
        x = torch.cat([x, next_tok[None]], -1)
    return x
8. Summary
- Edge AI shrinks models to fit phones, laptops, and microcontrollers — trading a little accuracy for huge wins in latency, privacy, and cost.
- Quantization drops weights from fp16 to int8, int4, or even int2, with group-wise scales and outlier handling keeping accuracy high.
- Distillation trains a small student to imitate a big teacher's probability distribution. "Dark knowledge" transfers generalization.
- Pruning zeroes out or removes unimportant weights. Structured pruning works with standard kernels; unstructured pruning needs sparse matmul.
- Speculative decoding uses a small draft model to pre-generate tokens that the big model verifies in parallel — a 2–4× throughput win with zero accuracy loss.
- Apple Intelligence, Gemini Nano, Phi-3, Llama-3.2-1B all live on this stack. Expect 2026 consumer devices to routinely run 8B+ models.
Further reading
- Hinton, Vinyals & Dean (2015) — Distilling the Knowledge in a Neural Network.
- Frantar et al. (2022) — GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
- Lin et al. (2023) — AWQ: Activation-aware Weight Quantization for LLM Compression.
- Ma et al. (2024) — The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits.
- Leviathan, Kalman & Matias (2023) — Fast Inference from Transformers via Speculative Decoding.