# Layer Normalization
Layer normalization normalizes inputs across features for each example independently. It's the standard normalization in transformers and works better than batch normalization for sequences.
## The Formula

```
LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β
```

where:

- μ = mean across features (computed per example)
- σ² = variance across features (computed per example)
- γ, β = learned scale and shift parameters
- ε = small constant for numerical stability
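As a quick sanity check, the formula can be applied by hand to a small vector. A NumPy sketch, with γ = 1 and β = 0 assumed for illustration:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-6

mu = x.mean()                       # mean across features: 2.5
var = x.var()                       # population variance: 1.25
y = (x - mu) / np.sqrt(var + eps)   # gamma = 1, beta = 0

print(y)          # approximately [-1.342, -0.447, 0.447, 1.342]
print(y.mean())   # ~0
print(y.std())    # ~1
```

The output has zero mean and unit variance; γ and β would then rescale and shift it.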
## Layer Norm vs Batch Norm

### Normalization Dimensions

```
Input shape: [Batch, Sequence, Features]

Batch Norm: normalize across [Batch, Sequence] for each feature
Layer Norm: normalize across [Features] for each position
```
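The difference is just which axes the statistics are reduced over. A NumPy sketch on a `[Batch, Sequence, Features]` array (shapes assumed for illustration):

```python
import numpy as np

x = np.random.randn(8, 16, 64)   # [Batch, Sequence, Features]

# Batch norm: one mean per feature, reduced over batch and sequence
bn_mean = x.mean(axis=(0, 1))              # shape (64,)

# Layer norm: one mean per position, reduced over features
ln_mean = x.mean(axis=-1, keepdims=True)   # shape (8, 16, 1)

print(bn_mean.shape, ln_mean.shape)
```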
### Visualization

```
 Batch Norm                     Layer Norm

  Features →                     Features →
  ┌─┬─┬─┬─┐                      ┌─┬─┬─┬─┐
B │▓│▓│▓│▓│  each column       B │▓│░│▒│▓│  each row
a │▓│▓│▓│▓│  (feature) is      a │░│▒│▓│░│  (example) is
t │▓│▓│▓│▓│  normalized        t │▒│▓│░│▒│  normalized
c │▓│▓│▓│▓│  across the batch  c │▓│░│▒│▓│  on its own
h └─┴─┴─┴─┘                    h └─┴─┴─┴─┘
```
### Key Differences
| Aspect | Batch Norm | Layer Norm |
|---|---|---|
| Normalizes across | Batch | Features |
| Depends on batch | Yes | No |
| Works with batch=1 | No | Yes |
| Running stats | Yes | No |
| Best for | CNNs | Transformers, RNNs |
## Why Layer Norm for Transformers?

### 1. Independent of Batch Size

```python
# Works the same for batch_size=1 or 1000
output = layer_norm(x)  # each example is normalized independently
```
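This batch independence is easy to verify numerically. A NumPy sketch (the `layer_norm` helper is illustrative, with γ = 1 and β = 0 assumed):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

batch = np.random.randn(32, 64)   # 32 examples
single = batch[:1]                # the first example alone

# The first row's output does not depend on the rest of the batch
print(np.allclose(layer_norm(batch)[0], layer_norm(single)[0]))  # True
```

With batch norm the same comparison would fail, because the statistics would change with the batch contents.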
### 2. Variable Sequence Lengths

```python
# Batch norm would mix statistics across different positions
# Layer norm normalizes each position independently
```
### 3. Inference Consistency

No running mean/variance is needed: the computation is identical at train and test time.
## Implementation

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, with eps inside the square root,
        # matching the formula above and nn.LayerNorm
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
```
```python
# PyTorch built-in
layer_norm = nn.LayerNorm(hidden_size)
```
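To see the forward pass end to end, here is the same computation in NumPy on a small batch (shapes assumed for illustration). Before γ and β are applied, each row comes out with mean ≈ 0 and variance ≈ 1 regardless of the input's scale and shift:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalize each row over its last (feature) axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)   # population variance
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.randn(4, 8) * 5 + 3   # [batch, features], arbitrary scale/shift
y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))

print(np.allclose(y.mean(axis=-1), 0, atol=1e-6))  # True
print(np.allclose(y.var(axis=-1), 1, atol=1e-4))   # True
```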
## RMSNorm

A simplified variant used in LLaMA:

```
RMSNorm(x) = x / √(mean(x²) + ε) × γ
```

There is no mean subtraction and no β shift.
### Why RMSNorm?
- Slightly faster (no mean computation)
- Works just as well in practice
- Used in modern LLMs
```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Scale by the root mean square of the features; no centering
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma
```
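A quick NumPy check of the same computation (γ = 1 assumed): after dividing by the RMS, the root-mean-square of each row is 1, but the mean is in general not 0, which is the key difference from layer norm:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms

x = np.random.randn(4, 8) + 2.0   # shifted input, so row means are far from 0
y = rms_norm(x)

rms_out = np.sqrt(np.mean(y ** 2, axis=-1))
print(np.allclose(rms_out, 1, atol=1e-3))   # True
print(np.allclose(y.mean(axis=-1), 0))      # False: the mean is not removed
```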
## Pre-Norm vs Post-Norm

### Post-Norm (Original Transformer)

```
x → Sublayer → Add x → LayerNorm

output = LayerNorm(x + Sublayer(x))
```

### Pre-Norm (Modern Default)

```
x → LayerNorm → Sublayer → Add x

output = x + Sublayer(LayerNorm(x))
```
### Comparison
| Aspect | Post-Norm | Pre-Norm |
|---|---|---|
| Training stability | Less stable | More stable |
| Needs warmup | Often yes | Usually no |
| Final layer norm | After last block | Before output |
| Used in | Original Transformer | GPT-2, LLaMA |
Pre-Norm is generally preferred for stability.
## In Transformer Blocks

### Post-Norm (Original)

```python
def forward(self, x):
    x = self.norm1(x + self.attention(x))
    x = self.norm2(x + self.ffn(x))
    return x
```

### Pre-Norm (Modern)

```python
def forward(self, x):
    x = x + self.attention(self.norm1(x))
    x = x + self.ffn(self.norm2(x))
    return x
```
## Other Normalizations

### Group Normalization

Divide the features into G groups and normalize within each group. Works well with small batches in vision models.

### Instance Normalization

Normalize each channel independently for each example. Commonly used in style transfer.
### Comparison

What's normalized together?

- Batch Norm: the same feature across the batch
- Layer Norm: all features for one example
- Instance Norm: each channel for one example
- Group Norm: groups of channels for one example
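The four schemes differ only in their reduction axes. As one concrete case, a group-norm sketch in NumPy (group count and shapes assumed for illustration; the affine parameters are omitted):

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    # x: [batch, channels, height, width]
    b, c, h, w = x.shape
    g = x.reshape(b, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)   # per example, per group
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(b, c, h, w)

x = np.random.randn(2, 8, 4, 4)
y = group_norm(x, num_groups=4)
print(y.shape)   # (2, 8, 4, 4)
```

With `num_groups=1` this reduces to layer norm over [C, H, W]; with `num_groups=c` it reduces to instance norm.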
## Common Issues

### Numerical Stability

```python
# Always add epsilon before the square root
std = torch.sqrt(var + eps)  # not torch.sqrt(var)
```
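The failure mode is easy to reproduce: a constant input has exactly zero variance, and without ε the division produces NaNs. A NumPy sketch:

```python
import numpy as np

x = np.full(4, 3.0)   # constant features -> variance is exactly 0
var = x.var()

with np.errstate(invalid="ignore"):
    bad = (x - x.mean()) / np.sqrt(var)        # 0 / 0 -> nan
good = (x - x.mean()) / np.sqrt(var + 1e-6)    # 0 / small -> 0

print(np.isnan(bad).all())   # True
print(good)                  # [0. 0. 0. 0.]
```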
### Affine Parameters

```python
# With learnable gamma/beta (the default)
nn.LayerNorm(dim, elementwise_affine=True)

# Without learnable parameters
nn.LayerNorm(dim, elementwise_affine=False)
```
## Key Takeaways
- Layer norm normalizes across features, not batch
- Standard for transformers (batch-independent)
- Pre-norm is more stable than post-norm
- RMSNorm is simpler variant used in modern LLMs
- γ (scale) and β (shift) are learned parameters
- Always include epsilon for numerical stability