
Understand layer normalization - the normalization technique used in transformers that normalizes across features instead of across the batch.


Layer Normalization

Layer normalization normalizes inputs across features for each example independently. It is the standard normalization in transformers and handles sequences better than batch normalization.

The Formula

LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β

μ = mean across features (for each example)
σ² = variance across features
γ, β = learned scale and shift
ε = small constant for stability
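
The formula can be checked numerically against PyTorch's built-in nn.LayerNorm (a small sketch; the tensor values are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

mu = x.mean(dim=-1, keepdim=True)                   # mean across features
var = x.var(dim=-1, keepdim=True, unbiased=False)   # biased variance, as in the formula
eps = 1e-5                                          # nn.LayerNorm's default epsilon
manual = (x - mu) / torch.sqrt(var + eps)           # gamma=1, beta=0 at initialization

ln = nn.LayerNorm(4)                                # gamma starts at 1, beta at 0
print(torch.allclose(manual, ln(x), atol=1e-5))     # → True
```

Note the biased variance (`unbiased=False`): the formula divides by N, not N-1, which is also what nn.LayerNorm computes.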

Layer Norm vs Batch Norm

Normalization Dimensions

Input shape: [Batch, Sequence, Features]

Batch Norm:  Normalize across [Batch, Sequence] for each feature
Layer Norm:  Normalize across [Features] for each position
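
These reduction axes can be made concrete with the shape of the resulting statistics (a sketch; the tensor sizes are arbitrary):

```python
import torch

x = torch.randn(8, 16, 32)  # [Batch, Sequence, Features]

# Batch norm: one statistic per feature, pooled over batch and sequence
bn_mean = x.mean(dim=(0, 1))   # shape [32]
# Layer norm: one statistic per (example, position), pooled over features
ln_mean = x.mean(dim=-1)       # shape [8, 16]

print(bn_mean.shape, ln_mean.shape)
```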

Visualization

         Batch Norm              Layer Norm
         
  Features →              Features →
  ┌─┬─┬─┬─┐               ┌─┬─┬─┬─┐
  │▓│▓│▓│▓│ ← same color  │▓│░│▒│▓│ ← normalize
B │▓│▓│▓│▓│   normalized  │░│▒│▓│░│   this row
a │▓│▓│▓│▓│   together    │▒│▓│░│▒│
t │▓│▓│▓│▓│               │▓│░│▒│▓│

Key Differences

Aspect               Batch Norm   Layer Norm
Normalizes across    Batch        Features
Depends on batch     Yes          No
Works with batch=1   No           Yes
Running stats        Yes          No
Best for             CNNs         Transformers, RNNs
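
The batch-size-of-one row can be demonstrated directly (a sketch; the feature size is arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 32)  # a single example

# Layer norm works fine with batch_size=1
out = nn.LayerNorm(32)(x)

# Batch norm in training mode cannot estimate statistics
# from a single value per feature
try:
    nn.BatchNorm1d(32).train()(x)
except ValueError as e:
    print("BatchNorm failed:", e)
```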

Why Layer Norm for Transformers?

1. Independent of Batch Size

# Works the same for batch_size=1 or 1000
output = layer_norm(x)  # Each example normalized independently

2. Variable Sequence Lengths

# Batch norm would pool statistics across positions of different
# sequences (including padding); layer norm handles each position on its own

3. Inference Consistency

No need for running mean/variance—same computation at train and test.
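
This train/test consistency is easy to verify (a sketch with arbitrary shapes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 32)

# Layer norm: identical computation in train and eval mode
ln = nn.LayerNorm(32)
ln.train(); out_train = ln(x)
ln.eval();  out_eval = ln(x)
print(torch.allclose(out_train, out_eval))  # → True

# Batch norm switches from batch stats (train) to running stats (eval)
bn = nn.BatchNorm1d(32)
bn.train(); bn_train = bn(x)   # uses batch stats, updates running stats
bn.eval();  bn_eval = bn(x)    # uses running stats
print(torch.allclose(bn_train, bn_eval))  # typically False
```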

Implementation

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance, with eps inside the sqrt, to match the formula
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

# PyTorch built-in (hidden_size = feature dimension)
layer_norm = nn.LayerNorm(hidden_size)

RMSNorm

Simplified variant used in LLaMA:

RMSNorm(x) = x / √(mean(x²) + ε) × γ

# No mean subtraction, no β shift

Why RMSNorm?

  • Slightly faster (no mean computation)
  • Works just as well in practice
  • Used in modern LLMs

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
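
A quick sanity check of the computation inside RMSNorm: with γ at its initial value of 1, every row of the output should have root-mean-square ≈ 1 (sketch; the shapes are arbitrary):

```python
import torch

x = torch.randn(4, 32)
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
y = x / rms

# Each row's RMS is ≈ 1 (up to the tiny epsilon)
print(torch.sqrt((y ** 2).mean(dim=-1)))
```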

Pre-Norm vs Post-Norm

Post-Norm (Original Transformer)

x → Sublayer → Add x → LayerNorm

output = LayerNorm(x + Sublayer(x))

Pre-Norm (Modern Default)

x → LayerNorm → Sublayer → Add x

output = x + Sublayer(LayerNorm(x))

Comparison

Aspect              Post-Norm             Pre-Norm
Training stability  Less stable           More stable
Needs warmup        Often yes             Usually no
Final layer norm    After last block      Before output
Used in             Original Transformer  GPT-2, LLaMA

Pre-Norm is generally preferred for stability.

In Transformer Blocks

Post-Norm (Original)

def forward(self, x):
    x = self.norm1(x + self.attention(x))
    x = self.norm2(x + self.ffn(x))
    return x

Pre-Norm (Modern)

def forward(self, x):
    x = x + self.attention(self.norm1(x))
    x = x + self.ffn(self.norm2(x))
    return x

Other Normalizations

Group Normalization

Normalize across groups of features:

Group Norm: Divide features into G groups, normalize each

Good for small batches in vision.

Instance Normalization

Normalize each channel per example:

Used in style transfer

Comparison

                    What's normalized together?
                    
Batch Norm:    Same feature across batch
Layer Norm:    All features for one example
Instance Norm: Each channel for one example
Group Norm:    Groups of channels for one example
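
All four have built-in PyTorch modules; on an image-shaped batch, only the axes the statistics are pooled over differ, not the output shape (a sketch; sizes are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 16, 4, 4)  # [N, C, H, W]

outs = {
    "batch":    nn.BatchNorm2d(16)(x),        # stats per channel, over N, H, W
    "layer":    nn.LayerNorm([16, 4, 4])(x),  # stats per example, over C, H, W
    "instance": nn.InstanceNorm2d(16)(x),     # stats per (example, channel), over H, W
    "group":    nn.GroupNorm(4, 16)(x),       # stats per (example, group of 4 channels)
}
for name, out in outs.items():
    print(name, out.shape)  # shape is unchanged in every case
```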

Common Issues

Numerical Stability

# Always use epsilon
std = torch.sqrt(var + eps)  # Not torch.sqrt(var)

Affine Parameters

# With learnable parameters (default)
nn.LayerNorm(dim, elementwise_affine=True)

# Without
nn.LayerNorm(dim, elementwise_affine=False)
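
A quick way to see the difference is to count trainable parameters (sketch; the dimension is arbitrary):

```python
import torch.nn as nn

# With affine parameters: gamma (weight) and beta (bias) are trainable
ln = nn.LayerNorm(64, elementwise_affine=True)
print(len(list(ln.parameters())))        # → 2

# Without: pure normalization, nothing to learn
ln_plain = nn.LayerNorm(64, elementwise_affine=False)
print(len(list(ln_plain.parameters())))  # → 0
```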

Key Takeaways

  1. Layer norm normalizes across features, not batch
  2. Standard for transformers (batch-independent)
  3. Pre-norm is more stable than post-norm
  4. RMSNorm is simpler variant used in modern LLMs
  5. γ (scale) and β (shift) are learned parameters
  6. Always include epsilon for numerical stability
