
Learn about batch normalization - the technique that stabilizes training by normalizing layer inputs, enabling faster and more reliable deep learning.

Tags: normalization, training, regularization, neural-networks

Batch Normalization

Batch normalization (BatchNorm) is a technique that normalizes the inputs to each layer during training. It's one of the most important innovations in deep learning, making networks faster and easier to train.

The Problem: Internal Covariate Shift

As training progresses, the distribution of inputs to each layer keeps changing (because earlier layers' weights change). This "internal covariate shift" makes training unstable.

Symptoms:

  • Need very small learning rates
  • Careful weight initialization critical
  • Deep networks hard to train
  • Saturation in sigmoid/tanh activations

The Solution

Normalize each layer's inputs to have zero mean and unit variance:

1. Compute batch statistics:
   μ = mean(x)  over batch
   σ² = var(x)  over batch

2. Normalize:
   x̂ = (x - μ) / √(σ² + ε)

3. Scale and shift (learnable):
   y = γx̂ + β

Where:

  • ε prevents division by zero
  • γ, β are learned parameters that restore representational power
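The three steps above can be sketched in NumPy. This is a minimal forward pass for a fully connected layer's activations; the function name and shapes are illustrative, not a library API:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Normalize activations x of shape (batch, features), then scale and shift."""
    mu = x.mean(axis=0)                       # step 1: per-feature batch mean
    var = x.var(axis=0)                       # step 1: per-feature batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)     # step 2: normalize
    return gamma * x_hat + beta               # step 3: learnable scale and shift

rng = np.random.default_rng(0)
x = rng.normal(5.0, 3.0, size=(64, 10))       # 64 samples, 10 features
y = batchnorm_forward(x, gamma=np.ones(10), beta=np.zeros(10))
print(y.mean(axis=0).round(4))                # each feature now has mean ≈ 0
print(y.std(axis=0).round(4))                 # and standard deviation ≈ 1
```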

Why Scale and Shift?

Normalizing to mean=0, var=1 might not be optimal. The network should be able to learn the best distribution.

If γ = √(σ² + ε) and β = μ, we get back the original values exactly. So BatchNorm can learn to undo itself if needed.
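A quick NumPy check confirms that this choice of γ and β inverts the normalization (illustrative code, not a library API):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(2.0, 4.0, size=(256, 3))
eps = 1e-5

mu = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)         # normalized activations

gamma = np.sqrt(var + eps)                    # choose gamma and beta so the
beta = mu                                     # scale-and-shift undoes the
y = gamma * x_hat + beta                      # normalization...
print(np.allclose(y, x))                      # ...recovering x: True
```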

During Training vs Inference

Training

  • Use batch statistics (μ, σ² from current batch)
  • Update running averages for inference

Inference

  • Use running averages (fixed statistics)
  • Deterministic predictions

running_mean = (1 - momentum) × running_mean + momentum × batch_mean

Typically momentum = 0.1, so each new batch contributes 10% to the running average. (This is PyTorch's convention, where momentum weights the new batch statistic; some texts use the reverse convention, with momentum ≈ 0.9 weighting the old average.)
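The update can be sketched as follows, using PyTorch's convention in which momentum weights the new batch statistic. The repeated identical batches are just for illustration:

```python
import numpy as np

def update_running_stats(running_mean, running_var,
                         batch_mean, batch_var, momentum=0.1):
    """Exponential moving average of batch statistics (PyTorch convention)."""
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    running_var = (1 - momentum) * running_var + momentum * batch_var
    return running_mean, running_var

rm, rv = np.zeros(3), np.ones(3)              # PyTorch's initial values
for _ in range(200):                          # feed batches with mean 4, var 9
    rm, rv = update_running_stats(rm, rv, np.full(3, 4.0), np.full(3, 9.0))
print(rm.round(3), rv.round(3))               # converges to the true statistics
```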

Where to Place BatchNorm?

Original paper: After linear, before activation

Linear → BatchNorm → ReLU

Common practice: After activation (works equally well)

Linear → ReLU → BatchNorm

With convolutions:

Conv → BatchNorm → ReLU
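In PyTorch, the original-paper placement for a convolutional block looks like this (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Conv -> BatchNorm -> ReLU: normalize the pre-activation outputs
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),            # one (gamma, beta) pair per channel
    nn.ReLU(),
)

x = torch.randn(8, 3, 32, 32)      # batch of 8 RGB images
out = block(x)
print(out.shape)                   # torch.Size([8, 16, 32, 32])
```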

Benefits

1. Higher Learning Rates

Normalized inputs mean gradients are better behaved. You can use 10x higher learning rates.

2. Less Sensitive to Initialization

Bad initialization effects are normalized away.

3. Regularization Effect

Batch statistics add noise (different batches → different normalizations). Acts like a mild regularizer.

4. Enables Very Deep Networks

Gradients flow better through normalized layers.

5. Reduces Need for Dropout

The regularization effect often replaces dropout.

Limitations

1. Batch Size Dependence

  • Small batches → noisy statistics → unstable training
  • Need batch size ≥ 16-32 typically
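The noise is easy to see: the spread of the batch mean grows as the batch shrinks, roughly as 1/√batch_size. A quick simulation (numbers illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
population = rng.normal(0.0, 1.0, size=100_000)   # stand-in for activations

spread = {}
for batch_size in (4, 32, 256):
    # How much does the batch mean wobble from batch to batch?
    means = [rng.choice(population, size=batch_size).mean() for _ in range(1000)]
    spread[batch_size] = float(np.std(means))
    print(batch_size, round(spread[batch_size], 3))
```

Smaller batches give noisier statistics, hence less stable normalization.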

2. Different Train/Test Behavior

Must track running statistics and switch modes.

3. Sequence Models

Variable-length sequences are problematic.

4. Not Great for RNNs

Statistics vary across time steps.

Variants

Layer Normalization

Normalize across features (not batch):

μ, σ = computed over features for each sample
  • Works with batch size = 1
  • Standard in Transformers
  • Better for RNNs
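A minimal PyTorch example; note that it normalizes a single sample with no batch at all (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(10)              # normalize over the last (feature) dimension
x = torch.randn(1, 10) * 5 + 3     # batch size 1: fine for LayerNorm

y = ln(x)
print(round(y.mean().item(), 4))               # ≈ 0
print(round(y.var(unbiased=False).item(), 4))  # ≈ 1
```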

Instance Normalization

Normalize each sample, each channel separately:

μ, σ = computed per sample, per channel (spatial dimensions)
  • Used in style transfer
  • Removes instance-specific contrast

Group Normalization

Compromise between Layer and Instance:

Divide channels into groups, normalize within groups
  • Works with small batches
  • Used in detection/segmentation
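Group Norm actually subsumes the other two: with a single group it matches Layer Norm over (C, H, W), and with one group per channel it matches Instance Norm. A PyTorch sketch (shapes arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 4)        # (N, C, H, W)

layer_like = nn.GroupNorm(num_groups=1, num_channels=8)     # all channels, one group
instance_like = nn.GroupNorm(num_groups=8, num_channels=8)  # one channel per group
group_norm = nn.GroupNorm(num_groups=4, num_channels=8)     # 2 channels per group

for norm in (layer_like, instance_like, group_norm):
    print(norm(x).shape)           # shape is preserved in every case
```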

Comparison

Method         Normalizes Over       Good For
BatchNorm      Batch                 CNNs, large batches
LayerNorm      Features              Transformers, RNNs
InstanceNorm   H, W per channel      Style transfer
GroupNorm      Groups of channels    Small-batch CNNs

Implementation Details

Affine Parameters

# PyTorch
nn.BatchNorm2d(num_features, affine=True)  # γ, β learned
nn.BatchNorm2d(num_features, affine=False)  # No γ, β

Tracking Running Stats

nn.BatchNorm2d(num_features, track_running_stats=True)  # Default

Training Mode

model.train()   # Use batch statistics
model.eval()    # Use running statistics
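
The mode switch matters in practice: in eval mode, repeated calls on the same input are identical because the running statistics are frozen. A small check:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(32, 4) * 2 + 1

bn.train()
_ = bn(x)                          # batch statistics; updates running averages

bn.eval()
y1, y2 = bn(x), bn(x)              # frozen running statistics
print(torch.equal(y1, y2))         # deterministic: True
```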

Key Takeaways

  1. BatchNorm normalizes layer inputs to zero mean, unit variance
  2. Learnable γ, β restore representational power
  3. Enables higher learning rates and easier training
  4. Use running statistics at inference time
  5. Layer Norm is preferred for Transformers and RNNs
  6. Group Norm works better with small batches