Level: Intermediate | Topic: Optimization

Learn why starting training with a low learning rate that gradually increases improves stability and final performance.

Tags: learning-rate, training, optimization, warmup, transformers

Learning Rate Warmup

Learning rate warmup is a technique where training starts with a very low learning rate that gradually increases to the target rate, helping stabilize early training and often improving final model performance.

Why Warmup?

The Problem

At the start of training:

  • Weights are randomly initialized
  • Gradients can be very large and noisy
  • High learning rate → unstable updates → divergence

Without Warmup:
Loss
  │ ╱╲  ╱╲
  │╱  ╲╱  ╲      Unstable start
  │        ╲
  │         ╲___
  └──────────────────► Steps

With Warmup:
Loss
  │
  │╲
  │ ╲
  │  ╲___________   Smooth descent
  └──────────────────► Steps

Especially Important For

  • Transformers: Large models with attention mechanisms
  • Large batch training: Gradients averaged over many samples
  • Adam/adaptive optimizers: Running averages need time to stabilize
  • Transfer learning: A gentle start avoids disrupting already-good pre-trained weights

Warmup Strategies

Linear Warmup

LR
 │
 │         ___________
 │        /
 │       /
 │      /
 │     /
 │____/
 └─────────────────────► Steps
    Warmup   Main Training

def linear_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        return target_lr * (step / warmup_steps)
    return target_lr

Exponential Warmup

import math

def exponential_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        # Approaches target_lr asymptotically; within ~1% of target by the end of warmup
        return target_lr * (1 - math.exp(-5 * step / warmup_steps))
    return target_lr
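As a quick sanity check, halfway through warmup the linear schedule sits at 50% of the target while the exponential one is already above 90%. The snippet below repeats both definitions so it runs standalone:

```python
import math

def linear_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        return target_lr * (step / warmup_steps)
    return target_lr

def exponential_warmup(step, warmup_steps, target_lr):
    if step < warmup_steps:
        return target_lr * (1 - math.exp(-5 * step / warmup_steps))
    return target_lr

# Halfway through a 1000-step warmup toward lr=1e-3:
print(linear_warmup(500, 1000, 1e-3))       # 0.0005  (50% of target)
print(exponential_warmup(500, 1000, 1e-3))  # ~0.000918 (~92% of target)
```

The exponential curve front-loads the increase, so it reaches usable learning rates sooner; the linear curve is simpler and the de facto default.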

Warmup + Decay

LR
 │
 │     ╱╲
 │    /  \
 │   /    \___
 │  /          \___
 │_/                \_
 └──────────────────────► Steps
  Warmup    Constant   Decay
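The three-phase shape above can be sketched as a single function (a hypothetical helper, not from any library; cosine decay is assumed for the final phase):

```python
import math

def warmup_hold_decay(step, warmup_steps, hold_steps, total_steps,
                      max_lr, min_lr=0.0):
    """Linear warmup -> constant hold -> cosine decay to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps  # warmup phase
    if step < warmup_steps + hold_steps:
        return max_lr                        # constant phase
    # Decay phase: cosine from max_lr down to min_lr
    decay_steps = total_steps - warmup_steps - hold_steps
    progress = (step - warmup_steps - hold_steps) / decay_steps
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
```

Call it once per step and write the result into the optimizer's param groups, as shown in the training loop below.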

Implementation

PyTorch Manual

import math

import torch

def get_lr(step, warmup_steps, max_lr, total_steps):
    """Linear warmup then cosine decay."""
    if step < warmup_steps:
        # Linear warmup
        return max_lr * step / warmup_steps
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(total_steps):
    lr = get_lr(step, warmup_steps=1000, max_lr=1e-4, total_steps=total_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Training step (train_step and batch are placeholders for your forward pass)
    loss = train_step(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

PyTorch Built-in Schedulers

from torch.optim.lr_scheduler import (
    CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Option 1: LambdaLR (multiplies the base lr by the returned factor)
warmup_steps = 1000

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda)

# Option 2: LinearLR for warmup only
warmup_scheduler = LinearLR(
    optimizer, 
    start_factor=0.1,  # Start at 10% of target LR
    total_iters=warmup_steps
)

# Option 3: Combined warmup + decay
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
main = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
scheduler = SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])

Hugging Face Transformers

from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

# Linear warmup, linear decay
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=10000
)

# Linear warmup, cosine decay
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=10000
)

# Training loop
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # Update LR
    optimizer.zero_grad()

How Many Warmup Steps?

Rules of Thumb

Scenario               Warmup Steps
---------------------  --------------------
Small dataset          100-500 steps
BERT fine-tuning       6-10% of total steps
Large model training   1000-5000 steps
Very large batch       proportionally more

Based on Total Training

# Common: 5-10% of total steps
warmup_steps = int(0.1 * total_steps)

# Or based on epochs
steps_per_epoch = len(dataloader)
warmup_steps = 1 * steps_per_epoch  # 1 epoch warmup

Warmup with Different Optimizers

Adam/AdamW

# Adam benefits from warmup because:
# - Running averages (m, v) start at 0
# - Early updates are based on limited history
# - Warmup gives time for statistics to stabilize

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01
)
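A small numeric illustration of the "limited history" point: Adam's moment estimates are exponential moving averages, and the bias-correction factor 1 - beta2**t is exactly the fraction of their stationary weight accumulated after t steps. With the default beta2 = 0.999:

```python
# Fraction of Adam's second-moment statistics accumulated after t steps
# (this is the bias-correction denominator 1 - beta2**t).
beta2 = 0.999
for t in (1, 10, 100, 1000):
    print(f"step {t:>4}: {1 - beta2 ** t:.3f}")
```

After one step only 0.1% of the stationary statistics exist, and roughly 1/(1 - beta2) = 1000 steps pass before they are ~63% formed, which is one motivation for warmup lengths in the hundreds-to-thousands range.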

SGD with Momentum

# Less critical for SGD, but still helpful
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9
)

Warmup in Large Batch Training

Large batches = more stable gradients
              = can use higher learning rate
              = BUT need longer warmup

Linear scaling rule:
  target_lr = base_lr × (batch_size / base_batch_size)
  warmup_steps = base_warmup × (batch_size / base_batch_size)

# Example: Scaling from batch 256 to 8192
base_lr = 1e-4
base_batch = 256
large_batch = 8192

scaled_lr = base_lr * (large_batch / base_batch)  # 32x higher
warmup_steps = 1000 * (large_batch / base_batch)  # 32x longer warmup

Visualizing Learning Rate

import matplotlib.pyplot as plt

def visualize_lr_schedule(scheduler, total_steps, warmup_steps):
    lrs = []
    for step in range(total_steps):
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()

    plt.figure(figsize=(10, 4))
    plt.plot(lrs)
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.axvline(warmup_steps, color='r', linestyle='--', label='End of warmup')
    plt.legend()
    plt.show()

Common Patterns

BERT/Transformer Fine-tuning

total_steps = len(dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)  # 10% warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

Vision Transformer (ViT)

# ViT often uses longer warmup
warmup_epochs = 5
warmup_steps = warmup_epochs * steps_per_epoch

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

Key Takeaways

  1. Warmup prevents instability at training start
  2. Especially important for Transformers and large batches
  3. Typical warmup: 5-10% of total training steps
  4. Linear warmup is the most common and the easiest to implement
  5. Combine with decay schedule for best results
  6. Large batch training requires proportionally longer warmup