Weight Initialization
Weight initialization sets the starting values of neural network parameters. Good initialization enables training; bad initialization can make training impossible.
Why Initialization Matters
The Problem with Zero Initialization
All weights = 0
→ All neurons compute same thing
→ All gradients are same
→ Neurons never differentiate
→ Network never learns
The Problem with Large Weights
Large weights → Large activations → Saturated sigmoids/tanh
→ Gradients ≈ 0 → Vanishing gradients
The Problem with Small Weights
Too small weights → Activations shrink through layers
→ Signal dies → Gradients vanish
The Goal
Maintain stable activation/gradient magnitudes across layers:
- Variance of activations ≈ 1 in each layer
- Variance of gradients ≈ 1 in each layer
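The effect of scale on signal propagation can be seen directly. Below is a minimal sketch (the helper `signal_std_after` is illustrative, not from the text) that pushes a random batch through a deep stack of ReLU layers and reports the output std for too-small weights versus He-scaled weights:

```python
import torch

torch.manual_seed(0)

def signal_std_after(depth, std, width=256):
    # Push a standard-normal batch through `depth` random linear layers
    # (no bias, ReLU between them) and return the std of the final output.
    x = torch.randn(1024, width)
    for _ in range(depth):
        w = torch.randn(width, width) * std
        x = torch.relu(x @ w)
    return x.std().item()

# Too-small weights: the signal shrinks toward zero layer by layer.
print(signal_std_after(depth=20, std=0.01))
# He scaling (std = sqrt(2 / fan_in)): the signal stays roughly stable.
print(signal_std_after(depth=20, std=(2.0 / 256) ** 0.5))
```

With std=0.01 the output collapses to essentially zero after 20 layers; with He scaling it stays on the order of 1.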
Xavier/Glorot Initialization
For Sigmoid and Tanh
W ~ N(0, σ²) where σ² = 2 / (fan_in + fan_out)
Or uniform:
W ~ U(-√(6/(fan_in + fan_out)), +√(6/(fan_in + fan_out)))
Where:
- fan_in: number of input units
- fan_out: number of output units
Derivation Intuition
Preserving activation variance in the forward pass requires σ² = 1/fan_in; preserving gradient variance in the backward pass requires σ² = 1/fan_out. Xavier compromises between the two with σ² = 2/(fan_in + fan_out).
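A quick check, assuming PyTorch's `(fan_out, fan_in)` weight layout: initialize a tensor with `nn.init.xavier_normal_` and confirm the empirical std matches the formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 400, 200

# PyTorch stores linear weights as (out_features, in_features).
w = torch.empty(fan_out, fan_in)
nn.init.xavier_normal_(w)

# The sample std should match sqrt(2 / (fan_in + fan_out)).
target_std = (2.0 / (fan_in + fan_out)) ** 0.5
print(w.std().item(), target_std)
```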
He/Kaiming Initialization
For ReLU and Variants
W ~ N(0, σ²) where σ² = 2 / fan_in
Why Different from Xavier?
ReLU zeros out roughly half of its inputs, which halves the activation variance. Doubling the weight variance (2/fan_in instead of 1/fan_in) compensates.
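The factor of 2 can be verified numerically: for zero-mean input, the second moment after ReLU is half the input variance.

```python
import torch

torch.manual_seed(0)
z = torch.randn(1_000_000)
# ReLU zeros the negative half, so for z ~ N(0, 1):
# E[relu(z)^2] = Var(z) / 2 = 0.5
print(torch.relu(z).pow(2).mean().item())  # close to 0.5
```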
Variants
# For ReLU
nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')
# For Leaky ReLU (a is the negative slope)
nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='leaky_relu', a=0.01)
LeCun Initialization
For SELU Activation
W ~ N(0, σ²) where σ² = 1 / fan_in
Designed for self-normalizing networks.
Orthogonal Initialization
For RNNs
W = orthogonal matrix
WᵀW = I
Benefits:
- Preserves gradient norm
- Prevents vanishing/exploding in RNNs
- Good for very deep networks
Initialization by Layer Type
Linear/Dense Layers
# For ReLU
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(layer.bias)
Convolutional Layers
Same as linear, but fans are computed over the kernel: fan_in = kernel_size² × in_channels, fan_out = kernel_size² × out_channels
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
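As a sanity check on the fan computation (layer sizes below are illustrative): with mode='fan_out', the target std for a conv is sqrt(2 / (kernel_size² × out_channels)).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

# fan_out = kernel_size^2 * out_channels = 9 * 128 = 1152,
# so the target std is sqrt(2 / 1152).
target = (2.0 / (3 * 3 * 128)) ** 0.5
print(conv.weight.std().item(), target)
```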
LSTM/GRU
for name, param in lstm.named_parameters():
    if 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)
        # Set forget gate bias to 1 (PyTorch gate order: input, forget, cell, output)
        n = param.size(0)
        param.data[n // 4:n // 2].fill_(1.0)
Embeddings
nn.init.normal_(embedding.weight, mean=0, std=0.02)
Batch Normalization
nn.init.ones_(bn.weight) # gamma = 1
nn.init.zeros_(bn.bias) # beta = 0
Transformer Initialization
GPT-style
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)
Residual Connections
Scale the residual-branch output projections by 1/√N, where N is the number of residual layers (2 per Transformer block: attention and MLP):
nn.init.normal_(output_proj.weight, std=0.02 / math.sqrt(2 * num_layers))
Common Mistakes
1. Wrong Initialization for Activation
# Wrong: Xavier for ReLU
nn.init.xavier_normal_(relu_layer.weight)
# Right: He for ReLU
nn.init.kaiming_normal_(relu_layer.weight)
2. Forgetting Bias Initialization
# Bias usually initialized to zero
nn.init.zeros_(layer.bias)
3. Not Reinitializing After Model Creation
# PyTorch has defaults, but explicit is better
model.apply(init_weights) # Apply custom initialization
Debugging Initialization
Check Activation Statistics
def check_activations(model, x):
    activations = []
    hooks = []
    def hook(module, input, output):
        activations.append(output.detach())
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            hooks.append(layer.register_forward_hook(hook))
    model(x)
    for h in hooks:
        h.remove()  # clean up so the hooks don't fire on later forward passes
    for i, act in enumerate(activations):
        print(f"Layer {i}: mean={act.mean():.4f}, std={act.std():.4f}")
Good signs:
- Mean ≈ 0 (for layers before activation)
- Std ≈ 1 or stable across layers
Bad signs:
- Mean or std growing/shrinking rapidly
- NaN or Inf values
Summary Table
| Activation | Initialization | Variance |
|---|---|---|
| Sigmoid/Tanh | Xavier | 2/(fan_in + fan_out) |
| ReLU | He/Kaiming | 2/fan_in |
| SELU | LeCun | 1/fan_in |
| RNN hidden | Orthogonal | Wᵀ W = I |
| Any (safe default) | He | 2/fan_in |
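The table rows can be collected into a small dispatch helper. `init_for` is a hypothetical convenience function (not a PyTorch API), and it assumes the 2-D `(out, in)` weight layout:

```python
import math
import torch
import torch.nn as nn

def init_for(weight: torch.Tensor, activation: str) -> None:
    # Pick the initialization from the summary table for a 2-D (out, in) weight.
    if activation in ('sigmoid', 'tanh'):
        nn.init.xavier_normal_(weight)                        # 2/(fan_in+fan_out)
    elif activation == 'relu':
        nn.init.kaiming_normal_(weight, nonlinearity='relu')  # 2/fan_in
    elif activation == 'selu':
        fan_in = weight.shape[1]
        nn.init.normal_(weight, std=1.0 / math.sqrt(fan_in))  # LeCun: 1/fan_in
    else:
        nn.init.kaiming_normal_(weight, nonlinearity='relu')  # safe default

torch.manual_seed(0)
w = torch.empty(128, 512)
init_for(w, 'selu')
print(w.std().item())  # near 1/sqrt(512) ≈ 0.0442
```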
Code Template
import torch.nn as nn

def initialize_model(model):
    for name, param in model.named_parameters():
        if 'weight' in name:
            if 'bn' in name or 'norm' in name:
                nn.init.ones_(param)  # BatchNorm/LayerNorm scale (gamma)
            elif param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')
            else:
                nn.init.normal_(param, std=0.02)  # remaining 1-D weights
        elif 'bias' in name:
            nn.init.zeros_(param)

model = MyModel()  # your model class
initialize_model(model)
Key Takeaways
- Initialization determines if training is possible
- Xavier for sigmoid/tanh, He/Kaiming for ReLU
- Goal: maintain variance ≈ 1 across layers
- Zero init fails to break symmetry → neurons stay identical and never learn
- Use orthogonal for RNNs
- Check activation statistics to debug