Intermediate · Deep Learning

Learn about knowledge distillation - training smaller models to mimic larger ones, achieving efficiency without sacrificing performance.

Tags: distillation, compression, efficiency, model-optimization

Knowledge Distillation

Knowledge distillation transfers knowledge from a large "teacher" model to a smaller "student" model. By mimicking the teacher's outputs rather than only the ground-truth labels, the student typically outperforms an identical model trained from scratch.

The Core Idea

[Large Teacher Model]
        ↓ soft targets (temperature-softened probabilities)
[Small Student Model] ← learns from teacher

Student learns from teacher's outputs, not just ground truth.

Why Soft Labels?

Hard Labels vs Soft Labels

Hard label: [0, 0, 1, 0]  (one-hot)
Soft label: [0.1, 0.15, 0.6, 0.15]  (teacher probabilities)

The Information in Soft Labels

Soft labels contain "dark knowledge":

  • Class similarities ("3" looks like "8")
  • Uncertainty (ambiguous cases)
  • Feature relationships

The Distillation Loss

Temperature Scaling

Soften the probability distribution:

p_i = exp(z_i / T) / Σ exp(z_j / T)

T = 1: Normal softmax
T > 1: Softer, more uniform distribution
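
To see the effect numerically (the logits here are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a 4-class problem
logits = torch.tensor([1.0, 2.0, 5.0, 2.0])

p_t1 = F.softmax(logits / 1.0, dim=-1)  # T = 1: standard softmax
p_t4 = F.softmax(logits / 4.0, dim=-1)  # T = 4: softened distribution

print(p_t1)  # peaked: most mass on the largest logit
print(p_t4)  # flatter: wrong classes receive more probability
```

Both distributions still sum to 1; raising T only redistributes mass toward the smaller logits.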

Combined Loss

L = α × L_hard + (1-α) × L_soft

L_hard = CrossEntropy(student, ground_truth)
L_soft = KL(student_soft, teacher_soft)

Full Formulation

L = α × CE(σ(z_s), y) + (1-α) × T² × KL(σ(z_s/T), σ(z_t/T))

z_s: student logits
z_t: teacher logits
T: temperature
T²: scaling factor for gradient magnitude

Implementation

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, 
                      temperature=4.0, alpha=0.7):
    # Soft targets from teacher
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
    soft_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean')
    soft_loss = soft_loss * (temperature ** 2)  # Scale gradients
    
    # Hard targets (ground truth)
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Combined loss
    return alpha * hard_loss + (1 - alpha) * soft_loss

# Training loop (teacher frozen, only the student is updated)
teacher.eval()
for batch in dataloader:
    inputs, labels = batch
    
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    
    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Types of Distillation

Response-Based

Match teacher's output predictions:

Student softmax → Teacher softmax

Feature-Based

Match intermediate representations:

Student features → Teacher features

L = MSE(student_hidden, teacher_hidden)

May need projection if dimensions differ.
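
A minimal sketch of feature matching with a learned projection; the hidden sizes and batch size here are placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

student_dim, teacher_dim = 256, 768  # hypothetical hidden sizes

# Learned projection maps student features into the teacher's space
proj = nn.Linear(student_dim, teacher_dim)

student_hidden = torch.randn(32, student_dim)  # batch of student features
teacher_hidden = torch.randn(32, teacher_dim)  # matching teacher features

feature_loss = F.mse_loss(proj(student_hidden), teacher_hidden)
```

The projection is trained jointly with the student and discarded after distillation.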

Relation-Based

Match relationships between samples:

Similarity(student[i], student[j]) ≈ Similarity(teacher[i], teacher[j])
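
One common way to implement this is to compare batch-level cosine-similarity matrices; a sketch with made-up embedding sizes:

```python
import torch
import torch.nn.functional as F

def pairwise_similarity(features):
    # Cosine similarity between every pair of samples in the batch
    normed = F.normalize(features, dim=-1)
    return normed @ normed.T

student_feats = torch.randn(16, 128)  # hypothetical student embeddings
teacher_feats = torch.randn(16, 512)  # dimensions need not match

# Match the relationships between samples, not the features themselves
relation_loss = F.mse_loss(pairwise_similarity(student_feats),
                           pairwise_similarity(teacher_feats))
```

Because only the 16×16 similarity matrices are compared, the student and teacher feature dimensions can differ freely, with no projection needed.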

Distillation Architectures

Offline Distillation

1. Train teacher fully
2. Generate teacher outputs
3. Train student with frozen teacher

Most common approach.

Online Distillation

Train teacher and student simultaneously
Mutual learning possible
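
The mutual-learning variant can be sketched as two peers that each match the ground truth plus the other's detached prediction; the tensors below are random stand-ins:

```python
import torch
import torch.nn.functional as F

def mutual_loss(logits_a, logits_b, labels):
    # Peer A's loss: ground truth plus KL toward peer B's (detached) prediction
    ce = F.cross_entropy(logits_a, labels)
    kl = F.kl_div(F.log_softmax(logits_a, dim=-1),
                  F.softmax(logits_b.detach(), dim=-1),
                  reduction='batchmean')
    return ce + kl

logits_a = torch.randn(8, 10)             # peer A's predictions (demo values)
logits_b = torch.randn(8, 10)             # peer B's predictions
labels = torch.randint(0, 10, (8,))

loss_a = mutual_loss(logits_a, logits_b, labels)  # used to update peer A
loss_b = mutual_loss(logits_b, logits_a, labels)  # used to update peer B
```

The `detach()` call keeps each peer from backpropagating through the other, so the two networks are updated independently.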

Self-Distillation

Deeper layers teach shallower layers
Or same architecture, different runs

Applications

Model Compression

BERT-Base (110M) → DistilBERT (66M)
40% smaller, 60% faster, retains ~97% of BERT's performance

LLM Distillation

GPT-4 → Smaller open models
Use teacher outputs as training data

Ensemble Distillation

Ensemble of models → Single model
Capture ensemble's knowledge
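
One simple way to form the ensemble teacher is to average the member logits; a sketch with tiny stand-in models:

```python
import torch
import torch.nn as nn

def ensemble_teacher_logits(models, inputs):
    # Average the logits of several teachers into one soft-target source
    with torch.no_grad():
        return torch.stack([m(inputs) for m in models]).mean(dim=0)

# Demo with linear stand-ins; real teachers would be trained networks
teachers = [nn.Linear(4, 3) for _ in range(3)]
x = torch.randn(5, 4)
avg_logits = ensemble_teacher_logits(teachers, x)  # shape: (5, 3)
```

The averaged logits then play the role of `teacher_logits` in the standard distillation loss, so a single student absorbs the whole ensemble.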

Cross-Modal

Teacher (images) → Student (text only)
Transfer visual knowledge to language model

Famous Examples

DistilBERT

Teacher: BERT-Base
Student: 6 layers (vs 12)
Method: Layer selection, triple loss
Result: 40% smaller, 60% faster

TinyBERT

Additional feature distillation
Attention transfer
Smaller and more efficient

MobileBERT

Inverted bottleneck structure
Progressive distillation
Optimized for mobile

Hyperparameters

Temperature (T)

T = 1: No softening
T = 2-5: Typical range
T = 20: Very soft

Higher T → More info from wrong classes

Alpha (α)

α = 0: Only soft labels
α = 1: Only hard labels
α = 0.5-0.9: Typical range

Guidelines

  • Larger teacher-student gap: higher T
  • More training data: can rely more on hard labels

Best Practices

Teacher Quality

  • Better teacher → better student
  • Ensemble teachers often help

Student Architecture

  • Similar architecture often works best
  • Can be completely different though

Data

  • Can use unlabeled data (teacher provides labels)
  • Augmentation helps
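
The unlabeled-data case reduces to a soft-target-only loss, since the teacher's distribution is the only training signal; a sketch with linear stand-in models (the T² factor matches the scaling used earlier):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_on_unlabeled(teacher, student, unlabeled_batch, temperature=4.0):
    # No ground-truth labels: the teacher's soft distribution is the target
    with torch.no_grad():
        soft_targets = F.softmax(teacher(unlabeled_batch) / temperature, dim=-1)
    log_probs = F.log_softmax(student(unlabeled_batch) / temperature, dim=-1)
    kl = F.kl_div(log_probs, soft_targets, reduction='batchmean')
    return kl * temperature ** 2

# Demo with stand-in models and random inputs
teacher = nn.Linear(10, 4)
student = nn.Linear(10, 4)
loss = distill_on_unlabeled(teacher, student, torch.randn(6, 10))
```

This corresponds to the combined loss with α = 0, which is why distillation can exploit large unlabeled corpora that plain supervised training cannot.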

Training

  • Often train longer than from scratch
  • Careful learning rate tuning

Limitations

  1. Teacher required: Need good teacher first
  2. Compute at training: Need teacher forward passes
  3. Performance ceiling: Student bounded by teacher
  4. Architecture mismatch: Can be challenging

Key Takeaways

  1. Distillation transfers knowledge from large to small model
  2. Soft labels contain "dark knowledge" about class relationships
  3. Temperature controls softness of probability distribution
  4. Combine soft loss (teacher) and hard loss (ground truth)
  5. Enables deploying efficient models without sacrificing performance
  6. Works across architectures, modalities, and model sizes

Practice Questions

Test your understanding with these related interview questions: