Batch Normalization
Batch normalization (BatchNorm) is a technique that normalizes the inputs to each layer during training. It's one of the most important innovations in deep learning, making networks faster and easier to train.
The Problem: Internal Covariate Shift
As training progresses, the distribution of inputs to each layer keeps changing (because earlier layers' weights change). This "internal covariate shift" makes training unstable.
Symptoms:
- Need very small learning rates
- Careful weight initialization critical
- Deep networks hard to train
- Saturation in sigmoid/tanh activations
The Solution
Normalize each layer's inputs to have zero mean and unit variance:
1. Compute batch statistics:
μ = mean(x) over batch
σ² = var(x) over batch
2. Normalize:
x̂ = (x - μ) / √(σ² + ε)
3. Scale and shift (learnable):
y = γx̂ + β
Where:
- ε prevents division by zero
- γ, β are learned parameters that restore representational power
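The three steps above can be sketched in NumPy (a minimal illustration; names and shapes are arbitrary):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Normalize a batch of activations, then scale and shift.

    x: (batch, features); gamma, beta: (features,)
    """
    mu = x.mean(axis=0)                    # per-feature batch mean
    var = x.var(axis=0)                    # per-feature batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)  # zero mean, unit variance
    return gamma * x_hat + beta            # learnable scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 8))
y = batchnorm_forward(x, gamma=np.ones(8), beta=np.zeros(8))
# With gamma=1 and beta=0, each output feature is ~zero-mean, unit-variance,
# regardless of the input's original mean (3.0) and scale (2.0).
```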
Why Scale and Shift?
Normalizing to mean=0, var=1 might not be optimal. The network should be able to learn the best distribution.
If γ = √(σ² + ε) and β = μ, we recover the original activations exactly. So BatchNorm can learn to undo itself if needed.
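This identity-recovery property is easy to verify numerically: choosing γ = √(σ² + ε) and β = μ cancels the normalization (NumPy sketch, values arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4))
eps = 1e-5

mu = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)

# These particular gamma/beta exactly undo the normalization:
gamma = np.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta
# y equals x up to floating-point error
```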
During Training vs Inference
Training
- Use batch statistics (μ, σ² from current batch)
- Update running averages for inference
Inference
- Use running averages (fixed statistics)
- Deterministic predictions
running_mean = (1 - momentum) × running_mean + momentum × batch_mean
Typically momentum = 0.1 (the PyTorch convention, where momentum weights the current batch statistic and 90% of the running value is kept).
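In the PyTorch convention, momentum weights the new batch statistic, so the running mean is an exponential moving average that converges to the true mean over the data stream. A small sketch (synthetic data, names illustrative):

```python
import numpy as np

def update_running(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum weights the *current batch* statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

running_mean = 0.0
rng = np.random.default_rng(2)
# Simulate a stream of batches whose true mean is 2.0.
for _ in range(200):
    batch = rng.normal(loc=2.0, scale=1.0, size=64)
    running_mean = update_running(running_mean, batch.mean())
# running_mean converges toward the true data mean of 2.0,
# and is then frozen and reused at inference time.
```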
Where to Place BatchNorm?
Original paper: After linear, before activation
Linear → BatchNorm → ReLU
Common alternative in some codebases: after the activation (empirically, both orderings can work well)
Linear → ReLU → BatchNorm
With convolutions:
Conv → BatchNorm → ReLU
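The convolutional ordering can be written directly in PyTorch (layer sizes here are arbitrary; note the conv bias is disabled because BatchNorm's β makes it redundant):

```python
import torch
import torch.nn as nn

# Conv → BatchNorm → ReLU, the ordering from the original paper.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

x = torch.randn(8, 3, 32, 32)   # batch of 8 RGB 32x32 images
out = block(x)                  # shape: (8, 16, 32, 32)
```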
Benefits
1. Higher Learning Rates
Normalized inputs mean gradients are better behaved. You can use 10x higher learning rates.
2. Less Sensitive to Initialization
Bad initialization effects are normalized away.
3. Regularization Effect
Batch statistics add noise (different batches → different normalizations). Acts like a mild regularizer.
4. Enables Very Deep Networks
Gradients flow better through normalized layers.
5. Reduces Need for Dropout
The regularization effect often replaces dropout.
Limitations
1. Batch Size Dependence
- Small batches → noisy statistics → unstable training
- Need batch size ≥ 16-32 typically
2. Different Train/Test Behavior
Must track running statistics and switch modes.
3. Sequence Models
Variable-length sequences are problematic: padded positions distort the batch statistics.
4. Not Great for RNNs
Statistics vary across time steps.
Variants
Layer Normalization
Normalize across features (not batch):
μ, σ = computed over features for each sample
- Works with batch size = 1
- Standard in Transformers
- Better for RNNs
Instance Normalization
Normalize each sample, each channel separately:
μ, σ = computed per sample, per channel (spatial dimensions)
- Used in style transfer
- Removes instance-specific contrast
Group Normalization
Compromise between Layer and Instance:
Divide channels into groups, normalize within groups
- Works with small batches
- Used in detection/segmentation
Comparison
| Method | Normalizes Over | Good For |
|---|---|---|
| BatchNorm | Batch | CNNs, large batches |
| LayerNorm | Features | Transformers, RNNs |
| InstanceNorm | H, W per channel | Style transfer |
| GroupNorm | Groups of channels | Small batch CNNs |
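The "Normalizes Over" column can be made concrete with NumPy on a 4-D activation tensor of shape (N, C, H, W). The axis choices are the point here, not the API (a sketch; real frameworks also add γ, β and running statistics):

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(8, 6, 4, 4))  # (batch N, channels C, height H, width W)
eps = 1e-5

def normalize(x, axes):
    mu = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

batch_norm    = normalize(x, (0, 2, 3))  # per channel, over batch + spatial
layer_norm    = normalize(x, (1, 2, 3))  # per sample, over all features
instance_norm = normalize(x, (2, 3))     # per sample, per channel, spatial only

# Group norm: split C=6 channels into 2 groups of 3, normalize within each group.
g = x.reshape(8, 2, 3, 4, 4)
group_norm = normalize(g, (2, 3, 4)).reshape(x.shape)
```

Note that only batch_norm averages over axis 0, which is why it alone depends on the batch size.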
Implementation Details
Affine Parameters
# PyTorch
nn.BatchNorm2d(num_features, affine=True) # γ, β learned
nn.BatchNorm2d(num_features, affine=False) # No γ, β
Tracking Running Stats
nn.BatchNorm2d(num_features, track_running_stats=True) # Default
Training Mode
model.train() # Use batch statistics
model.eval() # Use running statistics
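A quick PyTorch check of the mode switch (shapes arbitrary): after training-mode passes have populated the running statistics, eval mode uses those fixed statistics and is fully deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(16, 4, 8, 8) * 2.0 + 1.0

bn.train()
_ = bn(x)      # normalizes with batch statistics, updates running_mean/var

bn.eval()
y1 = bn(x)     # normalizes with the (now fixed) running statistics
y2 = bn(x)     # eval mode: same input always gives the same output
```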
Key Takeaways
- BatchNorm normalizes layer inputs to zero mean, unit variance
- Learnable γ, β restore representational power
- Enables higher learning rates and easier training
- Use running statistics at inference time
- Layer Norm is preferred for Transformers and RNNs
- Group Norm works better with small batches