# Gradient Clipping
Gradient clipping is a technique to prevent exploding gradients by limiting the magnitude of gradients during backpropagation.
## The Exploding Gradient Problem
During backpropagation, gradients can become extremely large:
```
Layer 1 ← Layer 2 ← Layer 3 ← ... ← Layer n
   g₁        g₂        g₃              gₙ
```

If |gᵢ| > 1 across many layers, the chain-rule product grows without bound:

```
Final gradient ≈ g₁ × g₂ × ... × gₙ → ∞
```
This causes:
- NaN losses: Weights overflow to infinity
- Unstable training: Large oscillations
- Failed convergence: Model never learns
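A back-of-the-envelope calculation shows how quickly this compounds (a toy illustration with made-up per-layer factors, not from the text above):

```python
# A gradient factor of just 1.5 per layer, multiplied across
# 50 layers, already exceeds 600 million.
factor = 1.5
layers = 50
grad = 1.0
for _ in range(layers):
    grad *= factor
print(grad)  # ~6.4e8
```

With factors below 1 the same product collapses toward zero instead, which is the mirror-image vanishing gradient problem.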
## Two Clipping Methods

### 1. Gradient Value Clipping

Clip each individual gradient element to a fixed range:

```python
# Clip each gradient element to [-max_value, max_value]
clipped_grad = torch.clamp(grad, -max_value, max_value)
```
Pros:
- Simple to implement
- Bounds every gradient element, so no single value can blow up
Cons:
- Can distort gradient direction
- Different scaling per parameter
### 2. Gradient Norm Clipping (Recommended)

Scale the entire gradient vector down if its norm exceeds a threshold:

```python
# If ||g|| > max_norm, rescale g so that its norm equals max_norm
if g.norm() > max_norm:
    g = g * (max_norm / g.norm())
```
Pros:
- Preserves gradient direction
- Consistent scaling across all parameters
Cons:
- Requires computing global norm first
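The direction-preservation difference between the two methods can be checked directly on a toy gradient vector; this sketch uses a plain tensor rather than model parameters, and the numbers are illustrative:

```python
import torch

g = torch.tensor([3.0, 0.1])

# Value clipping: each element clamped independently -> direction changes,
# because the large component is cut while the small one is untouched
value_clipped = torch.clamp(g, -1.0, 1.0)  # tensor([1.0, 0.1])

# Norm clipping: the whole vector is rescaled -> direction preserved
max_norm = 1.0
norm = g.norm()
norm_clipped = g * (max_norm / norm) if norm > max_norm else g

print(value_clipped / value_clipped.norm())  # direction shifted toward small components
print(norm_clipped / norm_clipped.norm())    # identical direction to g / g.norm()
```

After norm clipping the vector still points the same way as the original gradient, only shorter; after value clipping it does not.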
## Implementation

### PyTorch

```python
import torch

optimizer.zero_grad()
loss.backward()

# Gradient norm clipping (most common)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or gradient value clipping (in practice, use one method or the other, not both)
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

optimizer.step()
```
### TensorFlow

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(clipnorm=1.0)   # norm clipping
# or
optimizer = tf.keras.optimizers.Adam(clipvalue=0.5)  # value clipping
```
## Choosing the Clipping Threshold

### Common Values
| Model Type | Typical max_norm |
|---|---|
| RNN/LSTM | 1.0 - 5.0 |
| Transformer | 1.0 |
| General | 1.0 - 10.0 |
### Finding the Right Value

```python
# Monitor gradient norms during training
def get_grad_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.data.norm(2).item() ** 2
    return total_norm ** 0.5

# Log gradient norms each step
grad_norm = get_grad_norm(model)
print(f"Gradient norm: {grad_norm:.4f}")
```
Set threshold above typical norm but below explosion point.
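One way to turn logged norms into a concrete threshold (a heuristic sketch; the norm values below are hypothetical) is to pick a high percentile of the observed norms, so routine gradients pass through untouched and only spikes get clipped:

```python
# Hypothetical norms logged with get_grad_norm over early training steps;
# note the one spike at 5.7
logged_norms = [0.8, 1.1, 0.9, 1.3, 1.0, 0.95, 1.2, 5.7, 1.05, 0.85]

# Take roughly the 90th percentile as the clipping threshold
sorted_norms = sorted(logged_norms)
idx = int(0.9 * (len(sorted_norms) - 1))
max_norm = sorted_norms[idx]
print(max_norm)  # 1.3 -- the spike at 5.7 would be clipped, typical steps would not
```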
## When to Use Gradient Clipping

### Always Recommended
- RNNs and LSTMs: Prone to exploding gradients
- Transformers: Standard practice
- Deep networks: More multiplication = more risk
- Reinforcement learning: High variance gradients
### Usually Not Needed
- Shallow networks (< 5 layers)
- Well-initialized networks with batch normalization
- When using adaptive optimizers with small learning rates
## Gradient Clipping vs Other Techniques

| Technique | Mitigates Exploding | Mitigates Vanishing |
|---|---|---|
| Gradient Clipping | ✓ | ✗ |
| Batch Normalization | ✓ | ✓ |
| Residual Connections | ✓ | ✓ |
| Proper Initialization | ✓ | ✓ |
| LSTM/GRU | ✓ | ✓ |
## Debugging with Gradient Clipping

```python
# Check how often clipping is triggered
# (optimizer, criterion, and get_grad_norm are assumed to be defined elsewhere)
def train_step_with_monitoring(model, data, target, max_norm=1.0):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()

    # Gradient norm before clipping
    pre_clip_norm = get_grad_norm(model)

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # Gradient norm after clipping
    post_clip_norm = get_grad_norm(model)

    clipped = pre_clip_norm > max_norm
    if clipped:
        print(f"Clipped: {pre_clip_norm:.2f} → {post_clip_norm:.2f}")

    optimizer.step()
    return loss.item(), clipped
```
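The core invariant worth verifying is that after `clip_grad_norm_` the global norm never exceeds `max_norm`. A self-contained check, using a throwaway linear model (not the model from the text) and an artificially inflated loss to force a large gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
x, y = torch.randn(8, 10), torch.randn(8, 1)

# Scale the output by 100 so the gradient norm is well above the threshold
loss = nn.functional.mse_loss(model(x) * 100, y)
loss.backward()

# clip_grad_norm_ returns the total norm *before* clipping, handy for logging
pre = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
print(f"pre={pre.item():.2f}, post={post.item():.2f}")  # post is at most 1.0
```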
## Common Mistakes

### 1. Zeroing Gradients Before Clipping

```python
# Wrong! zero_grad() erases the gradients before they can be clipped
loss.backward()
optimizer.zero_grad()  # this clears the gradients!
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```

### 2. Clipping After optimizer.step()

```python
# Wrong - the weights have already been updated, so clipping has no effect!
optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```

### 3. Too Aggressive Clipping

```python
# If max_norm is far too small, every update is tiny and training slows drastically
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)  # too small!
```
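A quick calculation with toy numbers shows the damage: with a threshold of 0.001, even a modest gradient is crushed to a near-zero parameter update.

```python
import torch

g = torch.tensor([2.0, -3.0])        # gradient with norm ~3.6
max_norm = 0.001
clipped = g * (max_norm / g.norm())  # norm clipped all the way down to 0.001

lr = 0.01
step = lr * clipped
print(step.norm().item())  # ~1e-5: a step ~3600x smaller than intended
```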
## Key Takeaways
- Gradient clipping prevents exploding gradients
- Norm clipping (preferred) preserves gradient direction
- Common threshold values: 1.0 - 5.0
- Essential for RNNs, Transformers, and deep networks
- Apply after backward() but before optimizer.step()
- Monitor gradient norms to tune the threshold