Vanishing and Exploding Gradients
Vanishing and exploding gradients are phenomena that make training deep neural networks difficult. Understanding and addressing them was key to the deep learning revolution.
The Problem
During backpropagation, gradients are multiplied through layers:
∂L/∂w₁ = ∂L/∂h₄ × ∂h₄/∂h₃ × ∂h₃/∂h₂ × ∂h₂/∂h₁ × ∂h₁/∂w₁
If the magnitude of each term is consistently:
- < 1: gradient shrinks exponentially with depth → vanishing
- > 1: gradient grows exponentially with depth → exploding
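The compounding effect can be sketched numerically; the per-layer factors 0.9 and 1.1 below are arbitrary illustrations, not derived from any real network:

```python
# Compounding of per-layer gradient factors across 50 layers.
# The factors 0.9 and 1.1 are arbitrary illustrations.
depth = 50

vanishing = 0.9 ** depth   # each layer scales the gradient by 0.9
exploding = 1.1 ** depth   # each layer scales the gradient by 1.1

print(f"0.9^{depth} = {vanishing:.6f}")   # ~0.005: effectively no signal
print(f"1.1^{depth} = {exploding:.1f}")   # ~117: runaway growth
```

Even factors quite close to 1 diverge dramatically once fifty of them are multiplied together.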
Vanishing Gradients
The Effect
Layer 1 gradient: 0.0001 ← Barely updates
Layer 2 gradient: 0.001
Layer 3 gradient: 0.01
Layer 4 gradient: 0.1
Layer 5 gradient: 1.0 ← Updates normally
Early layers learn extremely slowly or not at all.
Causes
1. Sigmoid/Tanh Activation
Sigmoid derivative: max = 0.25 at x=0
Chaining layers multiplies these small factors: 0.25 × 0.25 × 0.25 × ... vanishes quickly (0.25¹⁰ ≈ 10⁻⁶)
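A minimal check of this in pure Python, with the sigmoid and its derivative written out by hand:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_deriv(x):
    s = sigmoid(x)
    return s * (1.0 - s)    # peaks at 0.25 when x = 0

print(sigmoid_deriv(0.0))   # 0.25, the maximum possible value
print(0.25 ** 10)           # ~1e-6: ten sigmoid layers nearly kill the gradient
```

Away from x = 0 the derivative is even smaller, so the real-world shrinkage is usually worse than the 0.25-per-layer best case.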
2. Poor Weight Initialization
Weights too small → activations shrink → gradients shrink
Symptoms
- Training loss stops decreasing
- Early layer weights don't change
- Model performs like a shallow network
Exploding Gradients
The Effect
Gradient: 1 → 10 → 100 → 1000 → 10000 → NaN
Weights become unstable, loss goes to infinity.
Causes
1. Large Weights
Weight magnitudes > 1 → per-layer gradient factors > 1 → explosion
2. Recurrent Networks
Same weight matrix multiplied many times:
W × W × W × ... explodes if the largest singular value of W exceeds 1
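A toy sketch of the unrolled-RNN case, using a hand-rolled 2×2 matrix-vector product (the matrix and its 1.1 scaling are illustrative assumptions):

```python
# Unrolled RNN sketch: the same weight matrix W applied at every step.
# W here is a toy 2x2 diagonal matrix with largest singular value 1.1 > 1.

def matvec(W, v):
    return [W[0][0] * v[0] + W[0][1] * v[1],
            W[1][0] * v[0] + W[1][1] * v[1]]

def norm(v):
    return (v[0] ** 2 + v[1] ** 2) ** 0.5

W = [[1.1, 0.0], [0.0, 1.1]]
v = [1.0, 1.0]
for _ in range(50):        # 50 "time steps"
    v = matvec(W, v)
print(norm(v))             # ~166: grew as 1.1^50, i.e. exploded
```

With 1.1 replaced by 0.9, the same loop would shrink the vector toward zero: the RNN version of vanishing gradients.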
Symptoms
- Loss becomes NaN or Inf
- Weights become very large
- Training crashes
Solutions
1. ReLU Activation
ReLU(x) = max(0, x)
ReLU'(x) = 1 if x > 0, else 0
Gradient is exactly 1 for positive inputs - no shrinking!
Variants for "dying ReLU" problem:
- Leaky ReLU: small slope for x < 0
- ELU, SELU, GELU: smooth alternatives
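A minimal sketch of ReLU vs. Leaky ReLU on negative inputs (the 0.01 slope is a common default, used here as an assumption):

```python
def relu(x):
    return max(0.0, x)

def leaky_relu(x, slope=0.01):
    # Small nonzero slope keeps a gradient alive for x < 0
    return x if x > 0 else slope * x

print(relu(-2.0), leaky_relu(-2.0))   # 0.0 vs ~-0.02: Leaky ReLU never fully "dies"
```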
2. Proper Weight Initialization
Xavier/Glorot (for tanh, sigmoid):
W ~ N(0, 2 / (fan_in + fan_out))
He/Kaiming (for ReLU):
W ~ N(0, 2 / fan_in)
# PyTorch: pick the initializer matching your activation
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')  # He, for ReLU
nn.init.xavier_normal_(layer.weight)  # Xavier, for tanh/sigmoid
Maintains variance of activations across layers.
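A rough simulation of why the scale matters, pushing an input through a stack of ReLU layers with He-scaled versus too-small weights (pure Python; the width and depth are toy values chosen for illustration):

```python
import math
import random

random.seed(0)

def forward_variance(std, width=64, depth=20):
    """Push a random input through `depth` ReLU layers whose weights are
    drawn from N(0, std^2); return the variance of the final activations."""
    x = [random.gauss(0, 1) for _ in range(width)]
    for _ in range(depth):
        x = [max(0.0, sum(random.gauss(0, std) * xi for xi in x))
             for _ in range(width)]
    mean = sum(x) / width
    return sum((xi - mean) ** 2 for xi in x) / width

he_std = math.sqrt(2.0 / 64)        # He initialization for fan_in = 64
print(forward_variance(he_std))     # stays O(1) across all 20 layers
print(forward_variance(0.01))       # collapses toward 0 -- vanishing signal
```

Since gradients flow back through the same weights, activations that collapse in the forward pass imply gradients that collapse in the backward pass.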
3. Batch Normalization
x_norm = (x - mean) / sqrt(var + ε)
output = γ × x_norm + β
Normalizes activations, keeps gradients well-behaved.
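The normalization formula above can be sketched directly (γ = 1, β = 0 as default assumptions; real batch norm learns these per channel and tracks running statistics):

```python
import math

def batch_norm(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize a batch of scalar activations, then apply the scale/shift."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in batch]

out = batch_norm([10.0, 20.0, 30.0, 40.0])
print(out)   # zero mean, ~unit variance, whatever the input scale was
```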
4. Residual Connections (Skip Connections)
output = F(x) + x
Gradient can flow directly through the skip connection:
∂L/∂x = ∂L/∂output × (∂F(x)/∂x + 1)
The "+1" from the identity path gives the gradient a direct route back to earlier layers, so it cannot vanish even when ∂F(x)/∂x is tiny.
This is why ResNets can be 100+ layers deep.
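A numeric check of the "+1" effect, using a toy residual block F(x) = w·x and a finite-difference gradient (both are illustrative assumptions, not how autograd works internally):

```python
# Toy residual block: output = F(x) + x with F(x) = w * x,
# so d(output)/dx = w + 1. Checked with a central finite difference.

def residual(x, w):
    return w * x + x

def grad_residual(x, w, h=1e-6):
    return (residual(x + h, w) - residual(x - h, w)) / (2 * h)

print(grad_residual(2.0, 0.001))   # ~1.001: gradient survives even with tiny w
print(grad_residual(2.0, 0.0))     # ~1.0: the skip path alone carries it
```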
5. Gradient Clipping
For exploding gradients (especially RNNs):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
If ||gradient|| > max_norm:
    gradient = gradient × (max_norm / ||gradient||)
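The clipping rule can be written out by hand as a sketch (assuming a flat list of gradient values rather than PyTorch tensors):

```python
import math

def clip_by_norm(grads, max_norm):
    """Rescale the gradient vector if its L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads]
    return grads

print(clip_by_norm([3.0, 4.0], max_norm=1.0))   # norm 5.0 -> rescaled to ~[0.6, 0.8]
print(clip_by_norm([0.3, 0.4], max_norm=1.0))   # norm 0.5 -> left untouched
```

Rescaling by the norm preserves the gradient's direction; only its length is capped.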
6. LSTM/GRU for RNNs
Gating mechanisms provide gradient highways:
Cell state: cₜ = fₜ × cₜ₋₁ + iₜ × c̃ₜ
Gradient flows through fₜ × cₜ₋₁ with minimal transformation
7. Layer Normalization
Alternative to batch norm, normalizes across features:
nn.LayerNorm(hidden_size)
Preferred for transformers and RNNs.
Modern Architectures Address This
ResNet
Residual blocks: output = F(x) + x
→ Can train 100+ layers
Transformer
Residual + LayerNorm + Attention (no recurrence)
→ Can train very deep models
Dense Connections (DenseNet)
Each layer receives the feature maps of all earlier layers in its block
→ Short paths for gradients
Diagnosing Gradient Issues
Monitor During Training
# Check gradient norms
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: {param.grad.norm():.4f}")
Warning Signs
| Issue | Symptom |
|---|---|
| Vanishing | Early layer gradients ≈ 0 |
| Vanishing | Loss plateaus early |
| Exploding | Gradients become NaN |
| Exploding | Loss becomes Inf |
| Exploding | Weights grow unbounded |
Summary of Solutions
| Problem | Solution |
|---|---|
| Vanishing (activation) | ReLU, GELU, Swish |
| Vanishing (initialization) | He/Xavier initialization |
| Vanishing (depth) | Residual connections |
| Vanishing (RNN) | LSTM, GRU |
| Exploding | Gradient clipping |
| Both | Batch/Layer normalization |
Key Takeaways
- Gradients multiply through layers - can vanish or explode
- Sigmoid/tanh saturate and cause vanishing; ReLU largely avoids this
- Proper initialization maintains gradient scale
- Skip connections provide gradient highways
- Batch/layer normalization stabilizes training
- Gradient clipping prevents explosion
- Modern architectures (ResNet, Transformer) have built-in solutions