Knowledge distillation is one of the most practical techniques in applied ML. The goal: take a large, accurate model that is too slow or expensive to run in production, and use it to train a small model that approaches the large model's performance at a fraction of the inference cost. DistilBERT achieves 97% of BERT's performance with 40% fewer parameters. This is not magic -- it is the result of a training procedure that extracts more information from each training example.
The Core Idea: Hard Labels vs Soft Targets
Standard supervised training uses hard labels: a one-hot vector where the correct class is 1 and all others are 0. If you train an image classifier and the image is a cat, the target is [0, 0, 1, 0, 0] (assuming cat is index 2). This signal tells the model which class is correct, but it discards a lot of structure: every wrong class is treated as equally wrong.
Now consider what a well-trained teacher model outputs for the same cat image. It might produce something like:
cat: 0.72, tiger: 0.14, lion: 0.08, dog: 0.04, car: 0.02
This distribution is rich with information. It tells the student model that a cat looks somewhat like a tiger and somewhat like a lion, and very little like a car. The relative probabilities encode the similarity structure of the problem in a way that the one-hot hard label does not.
Knowledge distillation (Hinton, Vinyals, and Dean, 2015) formalizes this. Instead of training the student on hard labels alone, you also train it to match the teacher's soft probability outputs. The student loss becomes:
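Loss = alpha * CrossEntropy(student_logits, hard_labels)
     + (1 - alpha) * T^2 * KLDivergence(softmax(teacher_logits/T) || softmax(student_logits/T))
where T is the temperature parameter, alpha balances the two loss terms, and KLDivergence(P || Q) measures how far the student's softened distribution Q is from the teacher's softened distribution P. The T^2 factor matters: the gradients of the soft term shrink roughly as 1/T^2, so Hinton et al. multiply by T^2 to keep both terms on a comparable scale as you change T.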
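To make each term concrete, here is a minimal, framework-free sketch of the loss for a single example. The function names and the example logits are hypothetical; real training code would use a framework's vectorized operations on whole batches.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax; subtracts the max for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.5):
    # Hard term: standard cross-entropy against the true class, at T = 1
    hard = -math.log(softmax(student_logits)[label])
    # Soft term: KL from the teacher's softened distribution to the student's,
    # scaled by T**2 so its gradients stay comparable to the hard term
    soft = kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T)) * T ** 2
    return alpha * hard + (1 - alpha) * soft

teacher = [1.0, 4.5, 2.0, 0.5]  # hypothetical teacher logits; class 1 is correct
student = [0.2, 2.0, 1.5, 0.1]  # hypothetical student logits
print(distillation_loss(student, teacher, label=1))
```

When the student's logits equal the teacher's, the soft term is exactly zero and only the weighted cross-entropy remains; with alpha = 1 the loss reduces to ordinary supervised training.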
The Role of Temperature
Temperature controls how "soft" the teacher's probability distribution is. Concretely, the logits are divided by T before the softmax, so a larger T shrinks the gaps between them. At the standard temperature (T = 1), the teacher's output probabilities might be quite peaked: cat: 0.998, tiger: 0.001, dog: 0.001. In that case, the soft target carries little more information than the hard label.
At higher temperatures (T = 4 or T = 10), the distribution flattens into something like: cat: 0.55, tiger: 0.25, lion: 0.12, dog: 0.08. Now the relative probabilities are much more visible, and the student can learn from the full inter-class similarity structure.
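To see the flattening numerically, here is a small stdlib-only sketch that applies a temperature-scaled softmax to one set of hypothetical logits at T = 1, 4, and 10. The logits are made up for illustration and are not meant to reproduce the exact probabilities quoted above.

```python
import math

def softmax_with_temperature(logits, T):
    """Divide logits by T, then apply a numerically stable softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

classes = ["cat", "tiger", "lion", "dog"]
logits = [9.0, 3.0, 2.0, 1.0]  # hypothetical teacher logits for a cat image

for T in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(logits, T)
    formatted = ", ".join(f"{c}: {p:.3f}" for c, p in zip(classes, probs))
    print(f"T={T:>4}: {formatted}  (entropy {entropy(probs):.2f})")
```

Dividing by a positive T never changes the ranking of the classes; it only shrinks the gaps between their probabilities, which is what lets the student see the small entries.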
After training, the student uses T = 1 for inference (standard temperature).
The intuition: high temperature during distillation reveals "dark knowledge" -- the information the teacher has learned about how classes relate to each other, which is mostly hidden in standard low-temperature outputs.
DistilBERT: The Canonical Example
DistilBERT (Sanh, Debut, Chaumond, and Wolf, 2019) demonstrated that distillation could be applied to large pretrained language models, not just task-specific classifiers.
The distillation process for DistilBERT:
- Student architecture: same as BERT-base but with 6 transformer layers instead of 12, and with the token-type embeddings and the pooler removed
- Teacher: BERT-base (110M parameters)
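- Training: the student is initialized by taking one layer out of every two from the teacher, then trained on the same pretraining corpus with a combination of three losses: a distillation loss on the teacher's MLM output distributions, the standard MLM loss, and a cosine-embedding loss that aligns the directions of the student's and teacher's hidden states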
Results:
- 40% fewer parameters (66M vs 110M)
- 60% faster inference
- 97% of BERT-base performance on the GLUE benchmark
- 97% of BERT-base performance on SQuAD
The layer initialization trick (starting the student from the teacher's weights rather than from random initialization) matters. The DistilBERT authors note that finding the right initialization is important for the smaller network to converge, and their ablation shows a clear drop in GLUE score when the student is randomly initialized instead.
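Here is a sketch of the "one layer out of two" initialization, written against a plain dict of weights so it runs without any framework. The key names (`encoder.layer.{i}.…`) and the choice of teacher layers 0, 2, 4, … are assumptions for illustration; real checkpoints have their own parameter names, and the actual DistilBERT extraction may pick a slightly different subset of layers.

```python
import copy
import re

def init_student_from_teacher(teacher_state, n_teacher_layers=12, n_student_layers=6):
    """Copy non-layer weights and every other transformer layer from teacher to student.

    Student layer j is initialized from teacher layer j * stride
    (0, 2, 4, ... for a 12 -> 6 reduction). Key names are hypothetical.
    """
    stride = n_teacher_layers // n_student_layers
    layer_map = {j * stride: j for j in range(n_student_layers)}  # teacher idx -> student idx
    pattern = re.compile(r"^encoder\.layer\.(\d+)\.(.+)$")
    student_state = {}
    for key, value in teacher_state.items():
        match = pattern.match(key)
        if match is None:
            # Non-layer weights (e.g. embeddings) are copied unchanged
            student_state[key] = copy.deepcopy(value)
            continue
        teacher_idx, rest = int(match.group(1)), match.group(2)
        if teacher_idx in layer_map:
            new_key = f"encoder.layer.{layer_map[teacher_idx]}.{rest}"
            student_state[new_key] = copy.deepcopy(value)
    return student_state

# Toy "checkpoint": each weight is a list tagged with the layer it came from
teacher_state = {"embeddings.word": [0.1, 0.2]}
for i in range(12):
    teacher_state[f"encoder.layer.{i}.attention"] = [float(i)]
    teacher_state[f"encoder.layer.{i}.ffn"] = [float(i) * 10]

student_state = init_student_from_teacher(teacher_state)
for j in range(6):
    print(j, student_state[f"encoder.layer.{j}.attention"])
```

The weights are deep-copied so that training the student later cannot accidentally modify the teacher's tensors.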
Intermediate Layer Distillation
The original Hinton et al. distillation paper focused on matching the final output distributions. But you can also match intermediate representations:
Feature-based distillation (FitNets, Romero et al., 2015): Train the student's intermediate hidden states to match the teacher's hidden states. This gives the student more signal about the teacher's internal representations, not just its final outputs.
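Because the student is usually narrower than the teacher, FitNets maps the student's hidden state into the teacher's width with a learned projection (a "regressor") before comparing them. Here is a framework-free sketch of that hint loss using mean squared error. The projection here is a fixed, hypothetical matrix rather than a learned one, and the hidden sizes are toy values (real hidden sizes are in the hundreds).

```python
def project(weights, vector):
    """Matrix-vector product: maps a student hidden state (width d_s) to teacher width d_t."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]

def hint_loss(student_hidden, teacher_hidden, projection):
    """Mean squared error between the projected student state and the teacher state."""
    projected = project(projection, student_hidden)
    if len(projected) != len(teacher_hidden):
        raise ValueError("projection must map the student width to the teacher width")
    return sum((p - t) ** 2 for p, t in zip(projected, teacher_hidden)) / len(teacher_hidden)

# Toy sizes: student width 2, teacher width 3
student_hidden = [0.5, -1.0]
teacher_hidden = [1.0, 0.5, -2.0]
projection = [  # 3 x 2 matrix; learned in practice, fixed here for illustration
    [2.0, 0.0],
    [0.0, 0.0],
    [0.0, 2.0],
]
print(hint_loss(student_hidden, teacher_hidden, projection))
```

In a real setup, the gradient of this loss updates both the student's layers and the projection matrix.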
Attention-based distillation (TinyBERT, Jiao et al., 2020): Match both the attention maps and the hidden states at each layer. TinyBERT achieves competitive performance with BERT-base at 1/7th the size.
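PKD (Patient Knowledge Distillation, Sun et al., 2019): Match the student's intermediate representations to those of several teacher layers, not just the final layer. PKD-Skip learns from every k-th teacher layer, while PKD-Last learns from the teacher's last k layers. "Patient" refers to learning from multiple intermediate layers rather than rushing to look only at the final output.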
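Here is a sketch of PKD's layer selection and per-layer loss for a 6-layer student learning from a 12-layer teacher. As I understand the paper, PKD compares L2-normalized [CLS] hidden states with a squared-error loss and leaves the student's top layer to the ordinary output-distillation term. The function names, 1-indexed layer numbering, and toy vectors are assumptions for illustration.

```python
import math

def pkd_skip_layers(n_teacher, n_student):
    """PKD-Skip: align student layer j with teacher layer j * (n_teacher // n_student)."""
    stride = n_teacher // n_student
    return [j * stride for j in range(1, n_student)]

def pkd_last_layers(n_teacher, n_student):
    """PKD-Last: align the student's layers with the teacher's last layers below the top one."""
    offset = n_teacher - n_student
    return [offset + j for j in range(1, n_student)]

def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def pkd_loss(student_cls, teacher_cls, teacher_layers):
    """Sum of squared distances between normalized [CLS] vectors.

    student_cls[j - 1] is the student's layer-j [CLS] vector and teacher_cls[i - 1]
    is the teacher's layer-i vector (1-indexed layers, matching the selection lists).
    """
    total = 0.0
    for j, i in enumerate(teacher_layers, start=1):
        s, t = normalize(student_cls[j - 1]), normalize(teacher_cls[i - 1])
        total += sum((a - b) ** 2 for a, b in zip(s, t))
    return total

print(pkd_skip_layers(12, 6))  # [2, 4, 6, 8, 10]
print(pkd_last_layers(12, 6))  # [7, 8, 9, 10, 11]
```

Normalizing before comparing means the student only has to match the direction of the teacher's representation, not its magnitude.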
Each of these approaches adds training complexity but can improve the student's performance, particularly when the student is much smaller than the teacher.
Task-Specific vs Task-Agnostic Distillation
Task-agnostic distillation (like DistilBERT) happens at the pretraining stage. The student learns from the teacher on the same general pretraining task (masked language modeling). The resulting student is a general-purpose model that can be fine-tuned on downstream tasks, just like the teacher.
Task-specific distillation happens after fine-tuning. You fine-tune the teacher on task A, then distill the fine-tuned teacher into a student specifically for task A. For a given student size, this typically yields higher task-specific performance than task-agnostic distillation followed by fine-tuning.
For production deployments where you have a specific, stable task (sentiment classification, intent detection, named entity recognition), task-specific distillation is usually the right approach. You get maximum compression for your specific use case.
When to Use Knowledge Distillation in Production
Knowledge distillation is the right tool when:
Inference latency is a hard constraint. If you need sub-40ms inference for a text classifier and a full BERT-base model takes 50ms, distillation into a 6-layer model might get you to roughly 30ms (DistilBERT reports about a 60% speedup) with minimal accuracy loss.
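Deployment target is resource-constrained. Edge devices, mobile, embedded systems. Distillation combined with quantization (INT8) shrinks models dramatically: BERT-base is about 440MB in FP32, while DistilBERT quantized to INT8 is roughly 66MB (66M parameters at one byte each). Getting under 50MB requires an even smaller student, such as a 4-layer TinyBERT-style model.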
You have a fixed task and labeled training data. Task-specific distillation works best when you can distill the teacher's knowledge specifically for your use case.
You cannot afford API costs at scale. A distilled model running on your own hardware has predictable, low marginal cost per inference. At millions of calls per day, this matters.
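The size and latency numbers in these scenarios are back-of-envelope arithmetic, and a tiny helper makes the assumptions explicit: weights-only storage (parameter count times bytes per parameter, ignoring tokenizer files and format overhead) and latency scaled by a constant speedup factor. The 1.6x factor below is an assumption derived from the 60% speedup reported for DistilBERT; always measure on your own hardware.

```python
def weights_size_mb(n_params, bits_per_param):
    """Approximate size of the weights alone, in megabytes (1 MB = 1e6 bytes)."""
    return n_params * bits_per_param / 8 / 1e6

def estimated_latency_ms(baseline_ms, speedup):
    """Scale a measured baseline latency by an assumed speedup factor."""
    return baseline_ms / speedup

print(f"BERT-base FP32:  {weights_size_mb(110e6, 32):.0f} MB")  # 440 MB
print(f"DistilBERT FP32: {weights_size_mb(66e6, 32):.0f} MB")   # 264 MB
print(f"DistilBERT INT8: {weights_size_mb(66e6, 8):.0f} MB")    # 66 MB
print(f"6-layer student: {estimated_latency_ms(50, 1.6):.1f} ms from a 50 ms baseline")
```

Real latency gains depend on sequence length, batch size, and hardware, so treat the speedup factor as a starting estimate rather than a guarantee.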
When distillation is NOT worth the effort:
- Your task changes frequently (you would need to re-distill)
- Accuracy requirements are strict enough that even a small gap to the teacher is unacceptable
- You are in early prototyping and do not know yet if the model will be used at scale
The Practical Workflow
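import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, KLDivLoss
from transformers import AutoTokenizer, BertForSequenceClassification, DistilBertForSequenceClassification
# Step 1: Fine-tune teacher on your task
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# ... fine-tune teacher on your labeled data ...
# Both checkpoints share the same uncased vocabulary, so one tokenizer serves both.
# DistilBERT's forward() does not accept token_type_ids, so don't produce them.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer(["great product", "terrible support"],  # placeholder batch
                   padding=True, return_tensors="pt", return_token_type_ids=False)
hard_labels = torch.tensor([2, 0])  # placeholder gold labels for that batch
# Step 2: Use teacher to generate soft labels for your training data
teacher.eval()
with torch.no_grad():
    teacher_logits = teacher(**inputs).logits  # shape: (batch_size, num_labels)
# Step 3: Train student to match teacher logits
student = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3)
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
T = 4.0  # temperature
alpha = 0.5  # weight on the hard-label loss
student.train()
student_logits = student(**inputs).logits
hard_loss = CrossEntropyLoss()(student_logits, hard_labels)
# "batchmean" sums over classes and averages over the batch (the true KL per example);
# the default "mean" also divides by num_labels and silently shrinks the soft loss.
soft_loss = KLDivLoss(reduction="batchmean")(
    F.log_softmax(student_logits / T, dim=-1),  # input: student log-probabilities
    F.softmax(teacher_logits / T, dim=-1)       # target: teacher probabilities
) * (T ** 2)  # T^2 keeps soft-loss gradients on the same scale as the hard loss
loss = alpha * hard_loss + (1 - alpha) * soft_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()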
Keep Reading
- BERT Explained for Developers -- DistilBERT is BERT's distilled sibling; understand the parent model first
- ML Model Evaluation Metrics Guide -- how to measure whether your distilled model actually retained performance
- ML Serving Latency Guide -- distillation is one of several strategies for hitting latency targets in production