Speculative Decoding: How to Get 3x LLM Speed With a Smaller Draft Model
Speculative decoding uses a small, fast model to draft multiple tokens and a large model to verify them in parallel, achieving 1.5-3x speedups without changing the output distribution.
LLMs generate text one token at a time. Each step requires a full forward pass through all model layers; for a 70B model, that means hundreds of large matrix multiplications over roughly 140 GB of fp16 weights. At small batch sizes, the time per step is dominated by streaming those weights from GPU memory rather than by arithmetic, so scoring one token takes roughly as long as scoring several tokens in parallel. This is memory-bandwidth-bound inference: the bottleneck is loading model weights, not compute.
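A back-of-the-envelope calculation makes the bottleneck concrete. The sketch below assumes fp16 weights (2 bytes per parameter) and an aggregate memory bandwidth of 2 TB/s. Both figures are illustrative assumptions, not measurements, and the estimate ignores the KV cache and any multi-GPU communication.

```python
def min_decode_latency_ms(n_params, bytes_per_param, bandwidth_bytes_per_s):
    """Lower bound on per-step latency when every weight is read once per step."""
    weight_bytes = n_params * bytes_per_param
    return weight_bytes / bandwidth_bytes_per_s * 1000.0


# Assumed values for illustration only: 70B parameters, fp16, 2 TB/s.
latency = min_decode_latency_ms(70e9, 2, 2e12)
print(f"about {latency:.0f} ms per decoding step")
```

Under these assumptions the weights are read once per step whether the step scores 1 token or 6, which is the slack that speculative decoding exploits.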
The Speculative Decoding Insight
The key observation in the speculative decoding paper (arXiv:2211.17192) is that a large model and a small model often agree on easy tokens (common words, punctuation, predictable continuations). If you could use the small model to propose several tokens at once, then verify them with one parallel forward pass of the large model, you could generate multiple tokens per large-model call.
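The algorithm:
1. Run a small "draft" model autoregressively to generate k tokens (e.g., k=5).
2. Run the large "target" model on the context plus all k draft tokens in one parallel forward pass.
3. Going left to right, accept each draft token with probability min(1, p_target / p_draft).
4. At the first rejection, discard the remaining drafts and resample that position from the corrected distribution, max(0, p_target - p_draft) renormalized. If all k drafts are accepted, sample one extra token from the target's output at the final position.
5. Repeat from the new context.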
The acceptance criterion is derived from rejection sampling theory. When draft and target agree (p_draft ≈ p_target), tokens are almost always accepted. When they disagree, a token is more likely to be rejected, and its replacement is drawn from the residual distribution proportional to max(0, p_target - p_draft), not from p_target directly. The mathematical guarantee is that the final output distribution is identical to what the large model would have produced alone - no approximation.
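The guarantee can be checked by hand on a toy vocabulary. The sketch below computes the exact distribution of the token emitted by one accept/reject step; the function name and the two three-token distributions are made up for illustration.

```python
def speculative_output_dist(p, q):
    """Exact distribution of the token emitted by one accept/reject step.

    p: target probabilities, q: draft probabilities (same vocabulary order).
    """
    # P(draft proposes x and x is accepted) = q[x] * min(1, p[x] / q[x]) = min(p[x], q[x])
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accepted)
    # Residual distribution used after a rejection: max(0, p - q), renormalized.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]
    return [a + reject_mass * r for a, r in zip(accepted, residual)]


p = [0.6, 0.3, 0.1]  # toy target distribution (assumed)
q = [0.3, 0.5, 0.2]  # toy draft distribution (assumed)
out = speculative_output_dist(p, q)
print(out)
```

The accepted mass min(p, q) plus the resampled mass max(0, p - q) adds up to p for every token, which is why the draft model affects speed but not the output distribution.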
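import torch
import torch.nn.functional as F


def speculative_sample(draft_model, target_model, input_ids, k=5):
    """Run one speculative decoding step (batch size 1)."""
    # Generate k draft tokens with the small model, keeping its full
    # distribution at each step (needed for resampling after a rejection).
    draft_tokens = []
    draft_dists = []
    current_ids = input_ids.clone()
    for _ in range(k):
        with torch.no_grad():
            logits = draft_model(current_ids).logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, 1)
        draft_tokens.append(token)
        draft_dists.append(probs[0])
        current_ids = torch.cat([current_ids, token], dim=1)

    # Verify with the large model in one pass. Row i of target_probs is the
    # target distribution for draft token i; row k is for the bonus token.
    with torch.no_grad():
        all_logits = target_model(current_ids).logits
    target_probs = F.softmax(all_logits[0, input_ids.shape[1] - 1:, :], dim=-1)

    # Accept/reject each draft token, left to right
    accepted = []
    for i, (tok, q) in enumerate(zip(draft_tokens, draft_dists)):
        p = target_probs[i]
        t = tok.item()
        accept_prob = min(1.0, (p[t] / q[t]).item())
        if torch.rand(1).item() < accept_prob:
            accepted.append(tok)
        else:
            # Rejected: resample from max(0, p - q), renormalized, then stop
            residual = torch.clamp(p - q, min=0.0)
            residual = residual / residual.sum()
            accepted.append(torch.multinomial(residual, 1).view(1, 1))
            break
    else:
        # All k drafts accepted: sample one bonus token from the target
        accepted.append(torch.multinomial(target_probs[k], 1).view(1, 1))
    return torch.cat([input_ids] + accepted, dim=1)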
Medusa: Multi-Head Parallel Prediction
Medusa (arXiv:2401.10774) adds multiple decoding heads to the target model itself, each predicting the token at a different future position. This eliminates the need for a separate draft model, although the extra heads must be trained. A tree-based verification scheme accepts the best valid continuation from a structured set of candidates. The Medusa paper reports speedups of roughly 2-3x.
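A simplified sketch of the candidate-and-verify idea follows. It is not Medusa's implementation: the function names, the toy head predictions, and the `target_greedy_next` stand-in are all hypothetical, acceptance is reduced to greedy matching, and the target is queried once per position here, whereas Medusa scores the whole candidate tree in a single forward pass using tree attention.

```python
from itertools import product


def build_candidates(head_topk):
    """Cartesian product of each head's top predictions -> candidate continuations."""
    return [list(c) for c in product(*head_topk)]


def longest_accepted(candidates, target_greedy_next, prefix):
    """Return the longest candidate prefix that matches the target's greedy choices."""
    best = []
    for cand in candidates:
        matched = []
        for tok in cand:
            if target_greedy_next(prefix + matched) != tok:
                break
            matched.append(tok)
        if len(matched) > len(best):
            best = matched
    return best


# Toy stand-in for the target model: it always continues a fixed sentence.
reference = ["the", "cat", "sat", "down"]
target_greedy_next = lambda ctx: reference[len(ctx)]

# Hypothetical top-2 predictions from three heads for the next three positions.
head_topk = [["cat", "dog"], ["ran", "sat"], ["up", "away"]]
candidates = build_candidates(head_topk)
best = longest_accepted(candidates, target_greedy_next, ["the"])
print(len(candidates), best)
```

In this toy run the first two positions are accepted and the third is not, so the step yields two tokens from one round of verification.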
EAGLE: Context-Aware Drafting
EAGLE improves over Medusa by feeding the target model's hidden states to the draft head, making the draft context-aware. This increases acceptance rates and pushes speedups toward 3x on Llama 2 70B.
When Speculative Decoding Wins
Speculative decoding helps most when inference is memory-bandwidth-bound (batch size 1 or small batches), the draft model is fast and cheap relative to the target (for example, a 7B draft for a 70B target), and the draft acceptance rate is high (>70%). It is less helpful for large batches, where the target model is already compute-bound and verifying extra draft tokens is no longer nearly free.
Written by
Mahmudul Haque Qudrati
CEO & ML Engineer
CEO and ML Engineer at Pristren. Builds AI-powered software for teams and writes about machine learning, LLMs, developer tools, and practical AI applications.