Knowledge Distillation
Knowledge distillation transfers knowledge from a large "teacher" model to a smaller "student" model. The student learns to mimic the teacher's behavior, typically achieving better performance than training the same small model from scratch.
The Core Idea
[Large Teacher Model]
↓ soft targets (softened probabilities derived from logits)
[Small Student Model] ← learns from teacher
Student learns from teacher's outputs, not just ground truth.
Why Soft Labels?
Hard Labels vs Soft Labels
Hard label: [0, 0, 1, 0] (one-hot)
Soft label: [0.1, 0.15, 0.6, 0.15] (teacher probabilities)
The Information in Soft Labels
Soft labels contain "dark knowledge":
- Class similarities ("3" looks like "8")
- Uncertainty (ambiguous cases)
- Feature relationships
The Distillation Loss
Temperature Scaling
Soften the probability distribution:
p_i = exp(z_i / T) / Σ exp(z_j / T)
T = 1: Normal softmax
T > 1: Softer, more uniform distribution
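The softening effect is easy to see in plain Python (the logits below are made up for illustration):

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax over logits divided by temperature T."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # hypothetical class logits

for T in (1.0, 4.0, 20.0):
    probs = softmax_with_temperature(logits, T)
    # At T=1 the top class dominates; at T=20 the distribution is nearly uniform
    print(T, [round(p, 3) for p in probs])
```

As T grows, probability mass shifts from the top class onto the others, exposing the teacher's relative preferences among the "wrong" classes.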
Combined Loss
L = α × L_hard + (1-α) × L_soft
L_hard = CrossEntropy(student, ground_truth)
L_soft = KL(teacher_soft || student_soft), computed on the temperature-softened distributions
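On a single toy example (all numbers below are hypothetical), both terms can be computed by hand:

```python
import math

def softmax(logits, T=1.0):
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

T, alpha = 4.0, 0.7            # hypothetical hyperparameters
student_logits = [1.0, 2.0, 0.5]
teacher_logits = [0.8, 2.5, 0.2]
true_class = 1                 # ground-truth label

# L_hard: cross-entropy of the student (at T=1) against the one-hot label
p_student = softmax(student_logits)
hard_loss = -math.log(p_student[true_class])

# L_soft: KL divergence between the temperature-softened distributions
p_t = softmax(teacher_logits, T)
p_s = softmax(student_logits, T)
soft_loss = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))

loss = alpha * hard_loss + (1 - alpha) * soft_loss
```

Note that the soft term compares the two models at the same temperature; the hard term uses the student's ordinary (T=1) softmax.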
Full Formulation
L = α × CE(σ(z_s), y) + (1-α) × T² × KL(σ(z_t/T) || σ(z_s/T))
z_s: student logits
z_t: teacher logits
T: temperature
T²: rescales the soft-loss term, whose gradients shrink roughly as 1/T², so both terms stay comparable in magnitude
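The T² factor can be checked numerically. The gradient of the soft loss with respect to a student logit has the closed form (σ(z_s/T)_i − σ(z_t/T)_i) / T, which shrinks roughly as 1/T² for large T. The sketch below (with made-up logits) shows that doubling T roughly quarters the raw gradient:

```python
import math

def softmax(logits, T):
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def soft_loss_grad(z_s, z_t, T):
    # Gradient of KL(teacher || student) w.r.t. the student logits
    p_s, p_t = softmax(z_s, T), softmax(z_t, T)
    return [(ps - pt) / T for ps, pt in zip(p_s, p_t)]

z_s = [1.0, 2.0, 3.0]   # hypothetical student logits
z_t = [1.2, 2.1, 2.7]   # hypothetical teacher logits

g10 = soft_loss_grad(z_s, z_t, 10.0)
g20 = soft_loss_grad(z_s, z_t, 20.0)
ratio = g10[2] / g20[2]   # close to 4 = (20/10)^2 for large T
```

Multiplying the soft loss by T² cancels this shrinkage, so the balance between hard and soft terms does not drift as you tune the temperature.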
Implementation
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.7):
    # Soft targets: teacher probabilities softened by temperature
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
    soft_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean')
    soft_loss = soft_loss * (temperature ** 2)  # Compensate for 1/T^2 gradient shrinkage
    # Hard targets (ground truth)
    hard_loss = F.cross_entropy(student_logits, labels)
    # Combined loss
    return alpha * hard_loss + (1 - alpha) * soft_loss

# Training loop
teacher.eval()  # Freeze teacher behavior (dropout off, batch norm in eval mode)
for batch in dataloader:
    inputs, labels = batch
    optimizer.zero_grad()
    with torch.no_grad():  # No gradients through the teacher
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels)
    loss.backward()
    optimizer.step()
Types of Distillation
Response-Based
Match teacher's output predictions:
Student softmax → Teacher softmax
Feature-Based
Match intermediate representations:
Student features → Teacher features
L = MSE(student_hidden, teacher_hidden)
May need projection if dimensions differ.
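A minimal sketch of feature matching with a projection, using made-up 2-d student and 3-d teacher features and a fixed projection matrix (in practice the projection is learned jointly with the student):

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def project(vec, W):
    """Apply a projection matrix W (one row per output dimension) to vec."""
    return [sum(w * v for w, v in zip(row, vec)) for row in W]

student_hidden = [0.5, -1.0]          # hypothetical 2-d student feature
teacher_hidden = [0.4, -0.9, 0.1]     # hypothetical 3-d teacher feature

# Hypothetical fixed projection; a real one would be trained
W = [[1.0, 0.0],
     [0.0, 1.0],
     [0.5, 0.5]]

projected = project(student_hidden, W)   # now 3-d, comparable to the teacher
feature_loss = mse(projected, teacher_hidden)
```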
Relation-Based
Match relationships between samples:
Similarity(student[i], student[j]) ≈ Similarity(teacher[i], teacher[j])
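A minimal sketch with made-up 2-d embeddings for three samples: compute each model's pairwise cosine similarities, then penalize the mismatch between the two similarity structures:

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

# Hypothetical embeddings for three samples
teacher_emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
student_emb = [[0.8, 0.1], [0.7, 0.2], [0.1, 0.9]]

# Pairwise similarity structure for each model
pairs = [(0, 1), (0, 2), (1, 2)]
t_sim = [cosine(teacher_emb[i], teacher_emb[j]) for i, j in pairs]
s_sim = [cosine(student_emb[i], student_emb[j]) for i, j in pairs]

# Relation loss: mean squared gap between the two similarity structures
relation_loss = sum((t - s) ** 2 for t, s in zip(t_sim, s_sim)) / len(pairs)
```

The student is never asked to reproduce the teacher's embeddings themselves, only how its samples relate to each other, which is why this variant tolerates very different architectures and dimensions.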
Distillation Architectures
Offline Distillation
1. Train teacher fully
2. Generate teacher outputs
3. Train student with frozen teacher
Most common approach.
Online Distillation
Train teacher and student simultaneously
Mutual learning possible
Self-Distillation
Deeper layers teach shallower layers
Or same architecture, different runs
Applications
Model Compression
BERT-Base (110M) → DistilBERT (66M)
40% smaller, 60% faster, retains ~97% of BERT's performance
LLM Distillation
GPT-4 → Smaller open models
Use teacher outputs as training data
Ensemble Distillation
Ensemble of models → Single model
Capture ensemble's knowledge
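A minimal sketch, assuming the common recipe of averaging the members' logits before softening (averaging probabilities is another option):

```python
import math

def softmax(logits):
    exps = [math.exp(z) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical logits from three ensemble members for one input
ensemble_logits = [
    [2.0, 0.5, 0.1],
    [1.8, 0.9, 0.0],
    [2.2, 0.4, 0.3],
]

# Average the members' logits, then soften into a single teacher target
avg_logits = [sum(col) / len(ensemble_logits) for col in zip(*ensemble_logits)]
soft_target = softmax(avg_logits)   # the single student trains against this
```

The student then pays one forward pass at inference time while approximating the averaged predictions of the whole ensemble.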
Cross-Modal
Teacher (images) → Student (text only)
Transfer visual knowledge to language model
Famous Examples
DistilBERT
Teacher: BERT-Base
Student: 6 layers (vs 12)
Method: Layer selection, triple loss
Result: 40% smaller, 60% faster
TinyBERT
Additional feature distillation
Attention transfer
Smaller and more efficient
MobileBERT
Inverted bottleneck structure
Progressive distillation
Optimized for mobile
Hyperparameters
Temperature (T)
T = 1: No softening
T = 2-5: Typical range
T = 20: Very soft
Higher T → more weight on the teacher's non-target class probabilities
Alpha (α)
α = 0: Only soft labels
α = 1: Only hard labels
α = 0.5-0.9: Typical range
Guidelines
- Larger teacher-student gap: higher T
- More training data: can rely more on hard labels
Best Practices
Teacher Quality
- Better teacher → better student
- Ensemble teachers often help
Student Architecture
- Similar architecture often works best
- Can be completely different though
Data
- Can use unlabeled data (teacher provides labels)
- Augmentation helps
Training
- Often train longer than from scratch
- Careful learning rate tuning
Limitations
- Teacher required: Need good teacher first
- Compute at training: Need teacher forward passes
- Performance ceiling: Student typically bounded by teacher quality
- Architecture mismatch: Can be challenging
Key Takeaways
- Distillation transfers knowledge from large to small model
- Soft labels contain "dark knowledge" about class relationships
- Temperature controls softness of probability distribution
- Combine soft loss (teacher) and hard loss (ground truth)
- Enables deploying efficient models without sacrificing performance
- Works across architectures, modalities, and model sizes