Learning Rate Warmup
Learning rate warmup is a technique where training starts with a very low learning rate that gradually increases to the target rate, helping stabilize early training and often improving final model performance.
Why Warmup?
The Problem
At the start of training:
- Weights are randomly initialized
- Gradients can be very large and noisy
- High learning rate → unstable updates → divergence
Without Warmup:
Loss
│ ╱╲ ╱╲
│╱ ╲╱ ╲ Unstable start
│ ╲
│ ╲___
└──────────────────► Steps
With Warmup:
Loss
│
│╲
│ ╲
│ ╲___________ Smooth descent
└──────────────────► Steps
Especially Important For
- Transformers: Large models with attention mechanisms
- Large batch training: Gradients averaged over many samples
- Adam/adaptive optimizers: Running averages need time to stabilize
- Transfer learning: Pre-trained weights are already good, and a large early update can destroy them
Warmup Strategies
Linear Warmup
LR
│
│ ___________
│ /
│ /
│ /
│ /
│____/
└─────────────────────► Steps
Warmup Main Training
def linear_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        return target_lr * (step / warmup_steps)
    return target_lr
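A quick sanity check of the linear schedule (the function is redefined here so the snippet runs standalone; the step counts and target LR are illustrative):

```python
def linear_warmup(step, warmup_steps, target_lr):
    # Ramp linearly from 0 to target_lr over warmup_steps, then hold.
    if step < warmup_steps:
        return target_lr * (step / warmup_steps)
    return target_lr

print(linear_warmup(0, 100, 1e-3))    # start of training: LR is 0
print(linear_warmup(50, 100, 1e-3))   # halfway through warmup: half the target
print(linear_warmup(100, 100, 1e-3))  # warmup done: full target LR
```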
Exponential Warmup
import math

def exponential_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        return target_lr * (1 - math.exp(-5 * step / warmup_steps))
    return target_lr
Warmup + Decay
LR
│
│ ╱╲
│ / \
│ / \___
│ / \___
│_/ \_
└──────────────────────► Steps
Warmup Constant Decay
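The warmup → constant → decay shape in the diagram can be sketched as a piecewise function; the phase lengths and the linear decay used here are illustrative choices, not a fixed recipe:

```python
def warmup_hold_decay(step, warmup_steps, hold_steps, total_steps, max_lr, min_lr=0.0):
    # Phase 1: linear warmup from 0 to max_lr.
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Phase 2: hold at max_lr.
    if step < warmup_steps + hold_steps:
        return max_lr
    # Phase 3: linear decay from max_lr down to min_lr.
    decay_steps = total_steps - warmup_steps - hold_steps
    progress = (step - warmup_steps - hold_steps) / decay_steps
    return max_lr + (min_lr - max_lr) * progress
```

With `warmup_steps=10`, `hold_steps=20`, `total_steps=100`, the LR ramps for 10 steps, plateaus for 20, then decays linearly to `min_lr` over the remaining 70.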
Implementation
PyTorch Manual
import math
import torch

def get_lr(step, warmup_steps, max_lr, total_steps):
    """Linear warmup then cosine decay."""
    if step < warmup_steps:
        # Linear warmup
        return max_lr * step / warmup_steps
    # Cosine decay
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Training loop (model, batch, and train_step assumed defined elsewhere)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(total_steps):
    lr = get_lr(step, warmup_steps=1000, max_lr=1e-4, total_steps=total_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Training step
    loss = train_step(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
PyTorch Built-in Schedulers
from torch.optim.lr_scheduler import (
    LambdaLR, LinearLR, SequentialLR, CosineAnnealingLR
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Option 1: LambdaLR
def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda)

# Option 2: LinearLR for warmup only
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.1,  # Start at 10% of target LR
    total_iters=warmup_steps
)

# Option 3: Combined warmup + decay
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
main = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
scheduler = SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
Hugging Face Transformers
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
# Linear warmup, linear decay
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=10000
)

# Linear warmup, cosine decay
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=10000
)

# Training loop
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # Update LR after each optimizer step
    optimizer.zero_grad()
How Many Warmup Steps?
Rules of Thumb
| Scenario | Warmup Steps |
|---|---|
| Small dataset | 100-500 steps |
| BERT fine-tuning | 6-10% of total steps |
| Large model training | 1000-5000 steps |
| Very large batch | More warmup needed |
Based on Total Training
# Common: 5-10% of total steps
warmup_steps = int(0.1 * total_steps)
# Or based on epochs
steps_per_epoch = len(dataloader)
warmup_steps = 1 * steps_per_epoch # 1 epoch warmup
Warmup with Different Optimizers
Adam/AdamW
# Adam benefits from warmup because:
# - Running averages (m, v) start at 0
# - Early updates are based on limited history
# - Warmup gives time for statistics to stabilize
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01
)
SGD with Momentum
# Less critical for SGD, but still helpful
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9
)
Warmup in Large Batch Training
Large batches = more stable gradients
              = can use higher learning rate
              = BUT need longer warmup
Linear scaling rule:
target_lr = base_lr × (batch_size / base_batch_size)
warmup_steps = base_warmup × (batch_size / base_batch_size)
# Example: Scaling from batch 256 to 8192
base_lr = 1e-4
base_batch = 256
large_batch = 8192
scaled_lr = base_lr * (large_batch / base_batch) # 32x higher
warmup_steps = 1000 * (large_batch / base_batch) # 32x longer warmup
Visualizing Learning Rate
import matplotlib.pyplot as plt

def visualize_lr_schedule(scheduler, total_steps, warmup_steps=1000):
    # Note: stepping the scheduler here consumes it; build a fresh one for training.
    lrs = []
    for step in range(total_steps):
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()

    plt.figure(figsize=(10, 4))
    plt.plot(lrs)
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.axvline(warmup_steps, color='r', linestyle='--', label='End of warmup')
    plt.legend()
    plt.show()
Common Patterns
BERT/Transformer Fine-tuning
total_steps = len(dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)  # 10% warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
Vision Transformer (ViT)
# ViT often uses longer warmup
warmup_epochs = 5
warmup_steps = warmup_epochs * steps_per_epoch

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
Key Takeaways
- Warmup prevents instability at training start
- Especially important for Transformers and large batches
- Typical warmup: 5-10% of total training steps
- Linear warmup is most common, easy to implement
- Combine with decay schedule for best results
- Large batch training requires proportionally longer warmup