
Learn about weight initialization strategies: how to set initial neural network weights for stable and efficient training.


Weight Initialization

Weight initialization sets the starting values of neural network parameters. Good initialization enables training; bad initialization can make training impossible.

Why Initialization Matters

The Problem with Zero Initialization

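All weights = 0 (or any identical constant)
→ All neurons in a layer compute the same output
→ All of them receive the same gradient
→ They remain identical after every update (symmetry is never broken)
→ Each layer effectively acts as a single neuron, so the network cannot learn useful features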
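A small PyTorch sketch of the symmetry problem. It uses the constant 0.5 rather than 0 so the weights visibly change during training while still staying identical; the layer sizes, random data, and learning rate are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny MLP in float64; every weight and bias starts at the same constant
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1)).double()
for p in model.parameters():
    nn.init.constant_(p, 0.5)

x = torch.randn(16, 4, dtype=torch.float64)
y = torch.randn(16, 1, dtype=torch.float64)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(50):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()

with torch.no_grad():
    w = model[0].weight  # shape (3 hidden neurons, 4 inputs)
    changed = not torch.allclose(w, torch.full_like(w, 0.5))
    still_symmetric = torch.allclose(w, w[0].expand_as(w))

print(f"weights changed: {changed}, all hidden rows identical: {still_symmetric}")
```

The hidden weights move away from 0.5, but every row moves by the same amount, so the three hidden neurons remain copies of each other.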

The Problem with Large Weights

Large weights → Large pre-activations → Saturated sigmoid/tanh units
→ Local gradients ≈ 0 → Vanishing gradients
(With ReLU, large weights instead make activations and gradients explode.)
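The saturation effect can be measured directly through the tanh derivative, 1 - tanh(z)². This sketch compares weights scaled by 1/√fan_in with unscaled N(0, 1) weights; the sizes are arbitrary choices.

```python
import torch

torch.manual_seed(0)
fan_in = 256
x = torch.randn(1024, fan_in)  # unit-variance inputs

mean_local_grad = {}
for label, std in (("fan_in-scaled", fan_in ** -0.5), ("large", 1.0)):
    W = torch.randn(fan_in, fan_in) * std
    z = x @ W                              # pre-activations, std ≈ std * sqrt(fan_in)
    local_grad = 1 - torch.tanh(z) ** 2    # derivative of tanh at z
    mean_local_grad[label] = local_grad.mean().item()
    print(f"{label:>14}: pre-activation std={z.std():.2f}, "
          f"mean tanh'(z)={mean_local_grad[label]:.3f}")
```

With std 1.0 the pre-activations have a standard deviation around 16, most units sit on the flat part of tanh, and the average local gradient collapses toward zero.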

The Problem with Small Weights

Too small weights → Activations shrink through layers
→ Signal dies → Gradients vanish
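A sketch of the shrinking signal: push unit-variance data through 10 tanh layers of width 256 (arbitrary choices) and compare a tiny weight std of 0.01 with a std of 1/√fan_in.

```python
import torch

torch.manual_seed(0)
width, depth = 256, 10
x = torch.randn(1024, width)

def final_std(weight_std):
    """Push x through `depth` tanh layers with N(0, weight_std²) weights."""
    h = x
    for _ in range(depth):
        W = torch.randn(width, width) * weight_std
        h = torch.tanh(h @ W)
    return h.std().item()

small = final_std(0.01)            # each layer scales std by roughly 0.01 * sqrt(256) = 0.16
scaled = final_std(width ** -0.5)  # variance 1/fan_in
print(f"too small: {small:.2e}   fan_in-scaled: {scaled:.3f}")
```

With std 0.01 the output standard deviation falls by many orders of magnitude; with 1/√fan_in it shrinks only gradually, because tanh itself is slightly contractive.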

The Goal

Maintain stable activation/gradient magnitudes across layers:

  • Variance of activations roughly constant from layer to layer (≈ 1 for unit-variance inputs)
  • Variance of back-propagated gradients roughly constant from layer to layer

Xavier/Glorot Initialization

For Sigmoid and Tanh

W ~ N(0, σ²) where σ² = 2 / (fan_in + fan_out)

Or uniform:
W ~ U(-√(6/(fan_in + fan_out)), +√(6/(fan_in + fan_out)))

Where:

  • fan_in: number of input units
  • fan_out: number of output units
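The formulas can be checked against PyTorch's built-in initializers. This sketch assumes a weight of shape (fan_out, fan_in), which is how `nn.Linear` stores it; the sizes 512 and 256 are arbitrary.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 512, 256
# nn.Linear stores weights as (out_features, in_features) = (fan_out, fan_in)
w = torch.empty(fan_out, fan_in)

target_var = 2.0 / (fan_in + fan_out)
bound = math.sqrt(6.0 / (fan_in + fan_out))

nn.init.xavier_uniform_(w)
uniform_max, uniform_var = w.abs().max().item(), w.var().item()

nn.init.xavier_normal_(w)
normal_var = w.var().item()

print(f"uniform: max|w|={uniform_max:.4f} (bound {bound:.4f}), var={uniform_var:.2e}")
print(f"normal:  var={normal_var:.2e} (target {target_var:.2e})")
```

Both variants hit the same variance: a uniform distribution on (-a, a) has variance a²/3, and a² = 6/(fan_in + fan_out) gives 2/(fan_in + fan_out).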

Derivation Intuition

It balances variance preservation in the forward pass, which depends on fan_in, and in the backward pass, which depends on fan_out.
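The balance can be made concrete with the standard variance argument. Writing n_in for fan_in and n_out for fan_out, and assuming weights and inputs are independent and zero-mean and the activation is roughly linear near zero (as tanh is around the origin):

```latex
% Forward pass: y_i = \sum_{j=1}^{n_{in}} W_{ij} x_j
\mathrm{Var}(y_i) = \sum_{j=1}^{n_{in}} \mathrm{Var}(W_{ij})\,\mathrm{Var}(x_j)
                  = n_{in}\,\mathrm{Var}(W)\,\mathrm{Var}(x)
\quad\Rightarrow\quad \mathrm{Var}(y) = \mathrm{Var}(x) \text{ requires } \mathrm{Var}(W) = \frac{1}{n_{in}}

% Backward pass: \partial L / \partial x_j = \sum_{i=1}^{n_{out}} W_{ij}\, \partial L / \partial y_i
\mathrm{Var}\!\left(\frac{\partial L}{\partial x}\right)
  = n_{out}\,\mathrm{Var}(W)\,\mathrm{Var}\!\left(\frac{\partial L}{\partial y}\right)
\quad\Rightarrow\quad \text{gradient variance is preserved when } \mathrm{Var}(W) = \frac{1}{n_{out}}

% Xavier/Glorot averages the two conditions: 1/\mathrm{Var}(W) = (n_{in} + n_{out})/2
\mathrm{Var}(W) = \frac{2}{n_{in} + n_{out}}
```

When fan_in = fan_out, both conditions coincide and the compromise is exact.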

He/Kaiming Initialization

For ReLU and Variants

W ~ N(0, σ²) where σ² = 2 / fan_in

Why Different from Xavier?

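ReLU zeros out the negative half of its zero-mean, symmetric inputs, which halves the second moment E[x²] passed to the next layer. Doubling the weight variance (2/fan_in instead of 1/fan_in) compensates.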
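A quick check of the factor-of-2 argument: push unit-variance data through 20 random ReLU layers of width 512 (arbitrary choices, no biases for simplicity) and compare weight variance 2/fan_in with 1/fan_in.

```python
import torch

torch.manual_seed(0)
width, depth = 512, 20
x = torch.randn(1024, width)

def final_rms(weight_var):
    """Root-mean-square activation after `depth` ReLU layers with N(0, weight_var) weights."""
    h = x
    for _ in range(depth):
        W = torch.randn(width, width) * weight_var ** 0.5
        h = torch.relu(h @ W)
    return h.pow(2).mean().sqrt().item()

he = final_rms(2.0 / width)       # He/Kaiming: 2 / fan_in
no_comp = final_rms(1.0 / width)  # no ReLU compensation: 1 / fan_in
print(f"He: {he:.3f}   1/fan_in: {no_comp:.6f}")
```

With 2/fan_in the second moment stays near 1 at every layer; with 1/fan_in it halves per layer, so after 20 layers the RMS is roughly 2^-10 ≈ 0.001.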

Variants

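import torch
import torch.nn as nn

tensor = torch.empty(256, 128)  # e.g. a Linear weight: (fan_out, fan_in)

# For ReLU
nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')

# For Leaky ReLU with negative slope a: gain = sqrt(2 / (1 + a²))
nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='leaky_relu', a=0.01)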
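The variants differ only in the gain that multiplies 1/√fan_in. This sketch prints the gains returned by `nn.init.calculate_gain` and checks that `kaiming_normal_` in `fan_in` mode produces std = gain / √fan_in; the tensor size is an arbitrary assumption.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in = 256

# kaiming_* compute these gains internally; xavier_* accept one via gain=...
gains = {
    "linear": nn.init.calculate_gain("linear"),
    "relu": nn.init.calculate_gain("relu"),
    "leaky_relu(0.01)": nn.init.calculate_gain("leaky_relu", 0.01),
    "tanh": nn.init.calculate_gain("tanh"),
}
for name, gain in gains.items():
    print(f"{name:>16}: gain={gain:.4f}, std=gain/sqrt(fan_in)={gain / math.sqrt(fan_in):.4f}")

# kaiming_normal_ with mode='fan_in' should give std = gain / sqrt(fan_in)
w = torch.empty(512, fan_in)
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="leaky_relu", a=0.01)
expected = gains["leaky_relu(0.01)"] / math.sqrt(fan_in)
print(f"empirical std={w.std():.4f}, expected={expected:.4f}")
```

A gain of √2 for ReLU is exactly the doubling of variance described above; a gain of 1 ("linear") gives the LeCun variance 1/fan_in.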

LeCun Initialization

For SELU Activation

W ~ N(0, σ²) where σ² = 1 / fan_in

Designed for self-normalizing networks.
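PyTorch has no initializer named after LeCun, but `kaiming_normal_` with `nonlinearity='linear'` (gain 1) samples from the same N(0, 1/fan_in) distribution, and PyTorch's `calculate_gain` documentation suggests this setting for self-normalizing networks. The sketch below (width, depth and batch size are arbitrary) checks that activations stay near zero mean and unit standard deviation through a SELU stack.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 256, 16

layers = []
for _ in range(depth):
    linear = nn.Linear(width, width)
    # LeCun normal (std = 1/sqrt(fan_in)) via kaiming_normal_ with gain 1
    nn.init.kaiming_normal_(linear.weight, mode="fan_in", nonlinearity="linear")
    nn.init.zeros_(linear.bias)
    layers += [linear, nn.SELU()]
net = nn.Sequential(*layers)

with torch.no_grad():
    out = net(torch.randn(4096, width))
out_mean, out_std = out.mean().item(), out.std().item()
print(f"after {depth} SELU layers: mean={out_mean:.3f}, std={out_std:.3f}")
```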

Orthogonal Initialization

For RNNs

W = orthogonal matrix
WᵀW = I

Benefits:

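  • Preserves the norm of vectors (and of back-propagated gradients) through the linear map
  • Reduces vanishing/exploding gradients in RNNs, since the linear recurrence h ← W h neither shrinks nor grows the hidden state at initialization
  • Also useful for very deep feedforward networks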
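A sketch of the norm-preservation property, assuming a square 128×128 hidden-to-hidden matrix and ignoring the RNN's nonlinearity (only the linear part h_t = W_hh h_{t-1} is iterated).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, steps = 128, 50

W_hh = torch.empty(n, n)
nn.init.orthogonal_(W_hh)
is_orthogonal = torch.allclose(W_hh.T @ W_hh, torch.eye(n), atol=1e-4)

# Linear part of an RNN step, h_t = W_hh @ h_{t-1}, applied repeatedly
h = torch.randn(n)
start_norm = h.norm().item()
for _ in range(steps):
    h = W_hh @ h
end_norm = h.norm().item()

print(f"WᵀW ≈ I: {is_orthogonal}; norm {start_norm:.4f} → {end_norm:.4f} after {steps} steps")
```

The hidden-state norm is unchanged after 50 steps up to floating-point error; a matrix whose singular values were all slightly below or above 1 would shrink or blow it up geometrically.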

Initialization by Layer Type

Linear/Dense Layers

import torch.nn as nn

layer = nn.Linear(512, 256)  # followed by a ReLU
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(layer.bias)

Convolutional Layers

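Same as linear, but fan_in = in_channels × kernel_height × kernel_width (kernel_size² × in_channels for a square kernel) and fan_out = out_channels × kernel_height × kernel_width. The call below uses mode='fan_out', which preserves gradient variance in the backward pass (torchvision's ResNets use this setting); mode='fan_in' preserves activation variance in the forward pass.

conv = nn.Conv2d(64, 128, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')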
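A sketch that computes the fans from the convolution weight's shape and checks the resulting std; the channel counts and kernel size are arbitrary assumptions.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

# Conv weights have shape (out_channels, in_channels, kernel_height, kernel_width)
out_c, in_c, k_h, k_w = conv.weight.shape
fan_in = in_c * k_h * k_w    # 64 * 3 * 3 = 576
fan_out = out_c * k_h * k_w  # 128 * 3 * 3 = 1152

nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
empirical_std = conv.weight.std().item()
print(f"empirical std={empirical_std:.4f}, sqrt(2/fan_out)={math.sqrt(2 / fan_out):.4f}")
```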

LSTM/GRU

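import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256)

for name, param in lstm.named_parameters():
    if 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)
        # Set forget gate bias to 1. PyTorch orders LSTM gates as
        # (input, forget, cell, output), so the forget slice is [n//4 : n//2].
        # Both bias_ih and bias_hh get it, so the effective forget bias is 2.
        # LSTM only: a GRU has 3 gates (reset, update, new) and a different layout.
        n = param.size(0)
        with torch.no_grad():
            param[n//4:n//2].fill_(1.0)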

Embeddings

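embedding = nn.Embedding(num_embeddings=10000, embedding_dim=512)  # PyTorch default init: N(0, 1)
nn.init.normal_(embedding.weight, mean=0, std=0.02)  # smaller, GPT-style std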

Batch Normalization

bn = nn.BatchNorm2d(64)    # PyTorch already defaults to gamma = 1, beta = 0
nn.init.ones_(bn.weight)   # gamma = 1
nn.init.zeros_(bn.bias)    # beta = 0

Transformer Initialization

GPT-style

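import torch.nn as nn

def init_weights(module):
    # N(0, 0.02²) for Linear and Embedding weights, zero biases
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)

model.apply(init_weights)  # model: your transformer; apply() visits every submodule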

Residual Connections

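Scale the output projections of residual branches by 1/√N, where N is the number of residual layers. Each transformer block contains two residual branches (attention and MLP), so N = 2 × num_layers:

import math
nn.init.normal_(output_proj.weight, std=0.02 / math.sqrt(2 * num_layers))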
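A simplified variance argument for the 1/√N factor, assuming each residual branch output f_i is zero-mean, uncorrelated with the input stream and with the other branches, and has variance σ_f² (real branches are not exactly independent, so treat this as intuition rather than a proof):

```latex
% Residual stream after N residual branches: x_N = x_0 + \sum_{i=1}^{N} f_i
\mathrm{Var}(x_N) = \mathrm{Var}(x_0) + \sum_{i=1}^{N} \mathrm{Var}(f_i)
                  = \mathrm{Var}(x_0) + N\,\sigma_f^2

% Scaling each branch's output projection by 1/\sqrt{N} scales f_i by 1/\sqrt{N}, so Var(f_i) by 1/N
\mathrm{Var}(x_N) = \mathrm{Var}(x_0) + N \cdot \frac{\sigma_f^2}{N}
                  = \mathrm{Var}(x_0) + \sigma_f^2
```

Without the scaling, the residual stream's variance grows linearly with depth; with it, the total contribution of all branches stays the same as that of a single unscaled branch.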

Common Mistakes

1. Wrong Initialization for Activation

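relu_layer = nn.Linear(256, 256)  # a layer followed by ReLU

# Wrong: Xavier for ReLU (for square layers, activation variance halves at every layer)
nn.init.xavier_normal_(relu_layer.weight)

# Right: He for ReLU
nn.init.kaiming_normal_(relu_layer.weight, nonlinearity='relu')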

2. Forgetting Bias Initialization

# Bias is usually initialized to zero (nn.Linear's default is uniform in ±1/√fan_in, not zero)
nn.init.zeros_(layer.bias)

3. Not Reinitializing After Model Creation

# PyTorch layers have defaults (nn.Linear and nn.Conv* use kaiming_uniform_ with a=√5), but explicit is better
model.apply(init_weights)  # Apply custom initialization
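To see how far the default is from He initialization: in recent PyTorch releases, `nn.Linear.reset_parameters` calls `kaiming_uniform_` with a=√5, which works out to U(-1/√fan_in, 1/√fan_in) and a variance of 1/(3·fan_in). This sketch measures it on an arbitrary 512→256 layer; treat the exact default as version-dependent.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in = 512
layer = nn.Linear(fan_in, 256)  # PyTorch default init, no explicit call

default_var = layer.weight.var().item()
pytorch_default_var = 1 / (3 * fan_in)  # variance of U(-1/sqrt(fan_in), 1/sqrt(fan_in))
he_var = 2 / fan_in

print(f"default: {default_var:.2e}  expected: {pytorch_default_var:.2e}  He: {he_var:.2e}")
```

The default variance is about 6× smaller than He's 2/fan_in, which is one reason deep ReLU stacks benefit from an explicit initialization.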

Debugging Initialization

Check Activation Statistics

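import torch
import torch.nn as nn

def check_activations(model, x):
    activations = []
    hooks = []

    def hook(module, input, output):
        activations.append(output.detach())

    # Record the output of every nn.Linear (pre-activation values)
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            hooks.append(layer.register_forward_hook(hook))

    with torch.no_grad():
        model(x)

    for h in hooks:
        h.remove()  # detach hooks so later forward passes aren't recorded

    for i, act in enumerate(activations):
        print(f"Layer {i}: mean={act.mean():.4f}, std={act.std():.4f}")
    return activations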

Good signs:

  • Mean ≈ 0 (for layers before activation)
  • Std ≈ 1 or stable across layers

Bad signs:

  • Mean or std growing/shrinking rapidly
  • NaN or Inf values
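The same good and bad signs apply to gradients: per-layer gradient std should not grow or shrink by orders of magnitude with depth. A minimal sketch, assuming an MSE loss and random data; `grad_stats` is a hypothetical helper and the 8-layer ReLU MLP is an arbitrary example.

```python
import math
import torch
import torch.nn as nn

def grad_stats(model, x, y):
    """Hypothetical helper: one forward/backward pass, then the std of each weight gradient."""
    model.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    stats = {}
    for name, param in model.named_parameters():
        if param.grad is not None and param.dim() >= 2:  # weight matrices only
            stats[name] = param.grad.std().item()
            print(f"{name}: grad std={stats[name]:.2e}")
    return stats

torch.manual_seed(0)
layers = []
for _ in range(8):
    linear = nn.Linear(128, 128)
    nn.init.kaiming_normal_(linear.weight, nonlinearity="relu")
    nn.init.zeros_(linear.bias)
    layers += [linear, nn.ReLU()]
model = nn.Sequential(*layers, nn.Linear(128, 1))

stats = grad_stats(model, torch.randn(256, 128), torch.randn(256, 1))
```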

Summary Table

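| Activation | Initialization | Weight variance |
|---|---|---|
| Sigmoid/Tanh | Xavier/Glorot | 2/(fan_in + fan_out) |
| ReLU | He/Kaiming | 2/fan_in |
| SELU | LeCun | 1/fan_in |
| RNN hidden-to-hidden | Orthogonal | WᵀW = I |
| Unsure, ReLU-family network (common default) | He | 2/fan_in |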

Code Template

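import torch.nn as nn

def initialize_model(model):
    # Name-based matching: assumes norm-layer parameters contain 'bn' or 'norm' in their names
    for name, param in model.named_parameters():
        if 'weight' in name:
            if 'bn' in name or 'norm' in name:
                nn.init.ones_(param)  # norm scale (gamma) = 1
            elif param.dim() >= 2:
                # Linear/conv weights; assumes ReLU activations
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')
            else:
                nn.init.normal_(param, std=0.02)  # any other 1-D weight
        elif 'bias' in name:
            nn.init.zeros_(param)

model = MyModel()  # MyModel: your own nn.Module
initialize_model(model)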
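Name-based matching depends on how layers are named: inside an `nn.Sequential`, a LayerNorm's scale is called `1.weight`, contains neither 'bn' nor 'norm', and would be drawn from N(0, 0.02²) instead of set to 1. A variant that dispatches on module type avoids this. `init_by_module_type` and the example model are hypothetical, and a ReLU network is assumed.

```python
import torch
import torch.nn as nn

def init_by_module_type(module):
    """Hypothetical helper: choose the init from the layer type, not the parameter name."""
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)):
        if module.weight is not None:  # None when affine/elementwise_affine is off
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)

torch.manual_seed(0)
# Hypothetical example model
model = nn.Sequential(
    nn.Embedding(1000, 64),
    nn.LayerNorm(64),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
model.apply(init_by_module_type)  # apply() visits every submodule recursively
```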

Key Takeaways

  1. Initialization determines if training is possible
  2. Xavier for sigmoid/tanh, He/Kaiming for ReLU
  3. Goal: keep activation and gradient variance roughly constant across layers
  4. Zero (or constant) init fails to break symmetry → neurons stay identical
  5. Use orthogonal for RNNs
  6. Check activation statistics to debug
