Learn how gradient accumulation enables training with larger effective batch sizes when GPU memory is limited.

Gradient Accumulation

Gradient accumulation is a technique that simulates larger batch sizes by accumulating gradients over multiple mini-batches before performing a weight update, enabling training of large models on limited hardware.

The Problem

Desired batch size: 64
GPU memory: Can only fit batch size 8

Solution: Accumulate gradients over 8 mini-batches of size 8
          Effective batch size: 8 × 8 = 64

How It Works

Without Accumulation (batch_size=8):
┌─────────┐   ┌─────────┐   ┌─────────┐
│ Batch 1 │ → │ Update  │ → │ Batch 2 │ → Update → ...
└─────────┘   └─────────┘   └─────────┘

With Accumulation (effective batch_size=32, accumulate=4):
┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐
│ Batch 1 │ → │ Batch 2 │ → │ Batch 3 │ → │ Batch 4 │ → │ Update  │
└─────────┘   └─────────┘   └─────────┘   └─────────┘   └─────────┘
  Accumulate   Accumulate    Accumulate   Accumulate      Step

Mathematical Equivalence

Large batch gradient:
  g = (1/N) × Σᵢ ∇L(xᵢ)

Accumulated gradient (k mini-batches of size n):
  g = (1/k) × Σⱼ (1/n) × Σᵢ∈batch_j ∇L(xᵢ)
    = (1/N) × Σᵢ ∇L(xᵢ)  (same!)
  
where N = k × n
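
The identity above can be checked numerically. The sketch below uses a hypothetical one-parameter linear model with squared-error loss, written in plain Python so it runs without PyTorch; the data values and the helper name `mean_grad` are illustrative assumptions.

```python
# Toy check: averaging k mini-batch gradients equals the full-batch gradient.
# Model: y_hat = w * x, loss = (y_hat - y)^2, so dL/dw = 2 * (w * x - y) * x.

def mean_grad(w, samples):
    """Mean gradient of the squared error over a list of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in samples) / len(samples)

# Hypothetical data: N = 8 samples, split into k = 4 mini-batches of n = 2.
data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0),
        (5.0, 9.0), (6.0, 12.5), (7.0, 13.0), (8.0, 17.0)]
w = 0.5
n = 2
mini_batches = [data[i:i + n] for i in range(0, len(data), n)]
k = len(mini_batches)

full_batch_grad = mean_grad(w, data)

# Accumulate: each mini-batch contributes its mean gradient divided by k,
# mirroring the division of the loss by accumulation_steps before backward().
accumulated_grad = 0.0
for batch in mini_batches:
    accumulated_grad += mean_grad(w, batch) / k

print(full_batch_grad, accumulated_grad)
```

The two values agree up to floating-point rounding, but only because every mini-batch has the same size n. A smaller final mini-batch would be over-weighted by the 1/k average.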

Implementation

Basic PyTorch

import torch

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

accumulation_steps = 4

for i, (inputs, targets) in enumerate(dataloader):
    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # Normalize loss to account for accumulation
    loss = loss / accumulation_steps
    
    # Backward pass (accumulates gradients)
    loss.backward()
    
    # Update weights every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

With Gradient Clipping

for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        # Clip accumulated gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
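
Clipping belongs after accumulation because clipping each mini-batch gradient separately is a different operation. The plain-Python sketch below illustrates the difference; `clip_by_norm` is a simplified stand-in for `clip_grad_norm_` (an assumption for illustration, not PyTorch's exact implementation), and the gradient values are made up.

```python
import math

def clip_by_norm(grad, max_norm):
    """Rescale grad so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    return [g * max_norm / norm for g in grad]

# Hypothetical per-mini-batch gradients (already divided by accumulation_steps).
mini_batch_grads = [[3.0, 0.0], [-2.0, 0.0], [0.0, 0.5], [0.0, 0.5]]

# Clip once, after accumulation (what the loop above does).
accumulated = [sum(col) for col in zip(*mini_batch_grads)]
clipped_after = clip_by_norm(accumulated, max_norm=1.0)

# Clip each mini-batch gradient first, then sum (a different operation).
clipped_each = [clip_by_norm(g, max_norm=1.0) for g in mini_batch_grads]
summed_clipped = [sum(col) for col in zip(*clipped_each)]

print(clipped_after, summed_clipped)
```

With these numbers the accumulated gradient points along both coordinates, while the per-mini-batch variant cancels the first coordinate entirely, so the two update directions differ.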

With Mixed Precision (AMP)

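# Note: recent PyTorch releases deprecate torch.cuda.amp in favor of torch.amp
# (torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda"))
from torch.cuda.amp import autocast, GradScaler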

scaler = GradScaler()
accumulation_steps = 4

for i, (inputs, targets) in enumerate(dataloader):
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
    
    # Scale loss and backward
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)  # Unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Hugging Face Transformers

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # Effective batch = 8 × 4 = 32
    learning_rate=2e-5,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

Important Considerations

1. Learning Rate Scaling

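# Linear scaling rule (common for SGD)
base_lr = 0.01
base_batch_size = 32

# Example values; substitute your own configuration
per_device_batch = 8
accumulation_steps = 4
num_gpus = 2
effective_batch_size = per_device_batch * accumulation_steps * num_gpus  # 64

scaled_lr = base_lr * (effective_batch_size / base_batch_size)  # 0.02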
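
The rule can be wrapped in a small helper. The sketch below is illustrative: the function name and the numbers are assumptions, and the "sqrt" option is a square-root heuristic sometimes suggested for adaptive optimizers such as Adam, included only for comparison. Either rule is a starting point for tuning, not a guarantee.

```python
import math

def rescale_lr(base_lr, base_batch_size, effective_batch_size, rule="linear"):
    """Rescale a learning rate tuned at base_batch_size for a new effective batch size."""
    ratio = effective_batch_size / base_batch_size
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")

# Hypothetical setup: lr 0.01 tuned at batch 32, now 8 per device x 8 steps x 4 GPUs.
new_effective_batch = 8 * 8 * 4  # 256
linear_lr = rescale_lr(0.01, 32, new_effective_batch, rule="linear")
sqrt_lr = rescale_lr(0.01, 32, new_effective_batch, rule="sqrt")
print(linear_lr, sqrt_lr)
```

For an 8× larger effective batch, the linear rule multiplies the learning rate by 8 and the square-root rule by about 2.83.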

2. Batch Normalization

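BatchNorm computes its statistics per mini-batch during each forward pass. They are not accumulated across steps, so normalization uses the small mini-batch's statistics rather than those of the effective batch:

# Option 1 (multi-GPU only): SyncBatchNorm pools statistics across devices
# within a step; it does not pool them across accumulation steps
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Option 2: Use LayerNorm or GroupNorm instead
# (statistics independent of batch size)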
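
To see why this matters, the plain-Python sketch below compares the statistics a normalization layer would compute over a full batch with those it computes over each mini-batch. The activation values are hypothetical, and population variance is used as a simple stand-in for the layer's batch variance.

```python
import statistics

# Hypothetical activations for one feature across a "full" batch of 8 samples.
activations = [0.1, 0.4, 0.2, 0.3, 2.0, 2.4, 1.9, 2.1]

# Statistics seen when the full batch goes through one forward pass.
full_mean = statistics.fmean(activations)
full_var = statistics.pvariance(activations)

# Statistics seen with accumulation: two mini-batches of 4, normalized separately.
mini_batches = [activations[:4], activations[4:]]
mini_stats = [(statistics.fmean(b), statistics.pvariance(b)) for b in mini_batches]

print("full batch:", full_mean, full_var)
for j, (mean, var) in enumerate(mini_stats, start=1):
    print(f"mini-batch {j}:", mean, var)
```

Each mini-batch is normalized around its own mean, so the forward pass differs from the full-batch one even though the accumulated gradients are averaged the same way.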

3. Steps vs Epochs

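from transformers import get_linear_schedule_with_warmup

# With accumulation, fewer optimizer steps per epoch
# (use ceil division instead if you also step on a trailing partial group)
steps_per_epoch = len(dataloader) // accumulation_steps

# Adjust schedulers accordingly, and call scheduler.step() once per
# optimizer.step(), not once per mini-batch
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=epochs * steps_per_epoch  # Not epochs * len(dataloader)!
)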
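
The step count depends on whether the training loop also updates on a trailing partial group of mini-batches. A minimal sketch of the arithmetic follows; the helper name and its `step_on_remainder` flag are hypothetical.

```python
import math

def optimizer_steps_per_epoch(num_batches, accumulation_steps, step_on_remainder=False):
    """Count optimizer updates in one epoch.

    step_on_remainder=True assumes the loop also steps on a trailing
    partial group of mini-batches.
    """
    if step_on_remainder:
        return math.ceil(num_batches / accumulation_steps)
    return num_batches // accumulation_steps

# Hypothetical numbers: 1000 mini-batches per epoch, 3 epochs, accumulate over 8.
steps = optimizer_steps_per_epoch(1000, 8)
total_training_steps = 3 * steps
print(steps, total_training_steps)
```

Passing `total_training_steps` to the scheduler, rather than the raw mini-batch count, keeps warmup and decay aligned with the updates that actually happen.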

4. Dropout and Augmentation

Each mini-batch gets different dropout masks and augmentations:

Batch 1: Dropout pattern A, Augmentation A
Batch 2: Dropout pattern B, Augmentation B
...
Accumulate → Update

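Samples within a single large batch also draw independent dropout masks and augmentations, so this per-sample randomness does not break the equivalence with large-batch training.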

Memory Comparison

Batch size 64 directly:
  Memory ≈ 64 × per-sample activations + parameter gradients

Batch size 8 with accumulation=8:
  Memory ≈ 8 × per-sample activations + parameter gradients
  (Gradients are accumulated in place, one buffer per parameter, not stored per mini-batch)
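
A back-of-the-envelope estimate makes the saving concrete. The sketch below is a rough model with made-up sizes: it assumes activation memory grows linearly with batch size while the parameter-gradient buffer stays fixed, and it ignores weights, optimizer state, and framework overhead.

```python
def estimate_training_memory(batch_size, activation_bytes_per_sample, gradient_bytes):
    """Rough peak memory: activations scale with batch size, parameter gradients do not."""
    return batch_size * activation_bytes_per_sample + gradient_bytes

# Hypothetical figures: 50 MB of activations per sample, 400 MB of parameter gradients.
MB = 1024 ** 2
direct = estimate_training_memory(64, 50 * MB, 400 * MB)
accumulated = estimate_training_memory(8, 50 * MB, 400 * MB)
print(direct // MB, accumulated // MB)
```

Under these assumptions the direct run needs 3600 MB and the accumulated run 800 MB. The saving comes entirely from the activation term.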

Trade-offs

Aspect           Direct Large Batch   Gradient Accumulation
Memory           High                 Low
Speed            Fast                 Slower (more passes)
BN statistics    Full batch           Mini-batch only
Variance         Lower                Slightly higher
Implementation   Simple               Requires care

When to Use

Use Gradient Accumulation When:

  • GPU memory is insufficient for desired batch size
  • Training large models (LLMs, ViT)
  • Need large batches for stability (caveat: contrastive losses with in-batch negatives still see only each mini-batch's negatives)

Consider Alternatives When:

  • Memory allows direct large batches (faster)
  • Using batch-size-sensitive techniques (BatchNorm)
  • Training time is critical

Common Patterns

Pattern 1: Fixed Effective Batch Size

# Want effective batch of 256 across different GPUs
effective_batch = 256
per_device = 8
num_gpus = 4
accumulation = effective_batch // (per_device * num_gpus)  # = 8
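
Integer division silently rounds down when the target is not an exact multiple of the per-step batch. A small hypothetical helper can make that case explicit:

```python
def accumulation_steps_for(effective_batch, per_device, num_gpus):
    """Return the accumulation steps needed to hit effective_batch exactly."""
    per_step = per_device * num_gpus
    if effective_batch % per_step != 0:
        raise ValueError(
            f"effective batch {effective_batch} is not a multiple of "
            f"per_device * num_gpus = {per_step}"
        )
    return effective_batch // per_step

# Same target of 256 with 8 samples per device, across different GPU counts.
for gpus in (1, 2, 4, 8):
    print(gpus, accumulation_steps_for(256, 8, gpus))
```

Raising an error is one design choice. Rounding to the nearest achievable batch size and logging the result is another.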

Pattern 2: Dynamic Accumulation

# Increase batch size during training
for epoch in range(num_epochs):
    accum_steps = min(16, 2 ** (epoch // 2))  # 1, 1, 2, 2, 4, 4, 8, 8, 16, ...
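
To preview what this schedule does to the effective batch size, it can be tabulated ahead of training. The sketch below assumes a per-device batch of 8, and the function name is hypothetical.

```python
def accumulation_schedule(num_epochs, per_device_batch=8, max_accum=16):
    """Return (epoch, accum_steps, effective_batch) tuples using the doubling rule."""
    schedule = []
    for epoch in range(num_epochs):
        accum_steps = min(max_accum, 2 ** (epoch // 2))
        schedule.append((epoch, accum_steps, per_device_batch * accum_steps))
    return schedule

for epoch, accum_steps, effective in accumulation_schedule(10):
    print(epoch, accum_steps, effective)
```

Because the effective batch size changes during training, the learning rate and the scheduler's total step count may need to follow the same schedule.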

Pattern 3: Handle Remainder Batches

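num_batches = len(dataloader)
remainder = num_batches % accumulation_steps

for i, batch in enumerate(dataloader):
    # The final group may hold fewer mini-batches; divide by its true size
    # so its gradient is not under-scaled
    in_partial_group = remainder != 0 and i >= num_batches - remainder
    group_size = remainder if in_partial_group else accumulation_steps
    loss = compute_loss(batch) / group_size
    loss.backward()
    
    # Update at end of accumulation OR end of epoch
    if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
        optimizer.step()
        optimizer.zero_grad()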

Key Takeaways

  1. Gradient accumulation simulates large batches with limited memory
  2. Divide loss by accumulation steps to maintain correct gradient scale
  3. Update weights only after accumulating all mini-batches
  4. Adjust learning rate schedulers for fewer steps per epoch
  5. BatchNorm statistics are computed per mini-batch (use alternatives if critical)
  6. Speed is slower than direct large batches but enables otherwise impossible training