
Weight Initialization in Deep Learning: Why It Matters More Than You Think

Bad weight initialization can make a deep network untrainable. This guide explains the theory behind Xavier, He, and modern initialization schemes — and when each one matters.


Here’s a thought experiment: initialize all weights in a neural network to zero. What happens? Every neuron computes the same output, receives the same gradient, and updates identically. Your 1000-neuron layer behaves as one neuron. The network can’t learn.

Weight initialization solves this symmetry-breaking problem, but it does much more. The right initialization keeps gradients flowing through deep networks. The wrong one causes gradients to explode or vanish, making training impossible.
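
The symmetry problem is easy to see empirically. A minimal PyTorch sketch (constant initialization stands in for zero initialization so the gradients are nonzero but still identical):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Every weight gets the same constant, so every hidden neuron is identical
net = nn.Sequential(nn.Linear(4, 1000), nn.ReLU(), nn.Linear(1000, 3))
for p in net.parameters():
    nn.init.constant_(p, 0.01)

x = torch.randn(8, 4)
net(x).sum().backward()

# All 1000 rows of the first layer's gradient are identical, so gradient
# descent keeps the neurons interchangeable forever -- one neuron's worth
# of capacity from a 1000-neuron layer.
grad = net[0].weight.grad
print(torch.allclose(grad, grad[0].expand_as(grad)))  # True
```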

The Fundamental Problem

Consider a layer: y = Wx + b, where W is the weight matrix. If you stack 50 of these layers:

  • Weights too large: Each layer amplifies the signal. After 50 layers, activations explode to infinity (gradient explosion)
  • Weights too small: Each layer shrinks the signal. After 50 layers, activations collapse to zero (gradient vanishing)
  • Weights just right: The signal maintains a stable variance through all layers

“Just right” is what initialization schemes compute.
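
All three regimes show up in a rough sketch: a plain tanh stack with weights drawn at three different scales (the width, depth, and scales here are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
depth, width = 50, 256
x = torch.randn(1024, width)

results = {}
for std in (0.5, 0.01, (1.0 / width) ** 0.5):  # too large, too small, "just right"
    h = x
    for _ in range(depth):
        h = torch.tanh(h @ (torch.randn(width, width) * std))
    results[std] = h.std().item()
    print(f"weight std {std:.4f} -> activation std after {depth} layers: {results[std]:.2e}")
```

Exact numbers vary with the seed, but the qualitative pattern is robust: the large scale saturates tanh near ±1, the small scale collapses activations to zero, and the balanced scale keeps a usable signal.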

The Key Schemes

Xavier/Glorot Initialization (2010)

Designed for networks with sigmoid or tanh activations. The idea: keep the variance of activations and gradients constant across layers.

# For a layer with fan_in inputs and fan_out outputs:
# Uniform: W ~ U(-limit, limit) where limit = sqrt(6 / (fan_in + fan_out))
# Normal:  W ~ N(0, 2 / (fan_in + fan_out))

import torch.nn as nn

# PyTorch (any module with a .weight works; a Linear layer shown for concreteness)
layer = nn.Linear(256, 128)
nn.init.xavier_uniform_(layer.weight)
nn.init.xavier_normal_(layer.weight)

When to use: Linear layers, sigmoid/tanh activations, autoencoders, vanilla RNNs.

Why it works: Glorot and Bengio analyzed how variance propagates through the forward and backward passes and found that balancing fan_in and fan_out keeps both approximately stable. The math assumes linear or symmetric, zero-centered activations.
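
As a quick sanity check, the sampled standard deviation of a Xavier-normal tensor matches the formula above (the dimensions here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_out, fan_in = 300, 400
W = torch.empty(fan_out, fan_in)  # PyTorch stores Linear weights as (out_features, in_features)
nn.init.xavier_normal_(W)

target = (2.0 / (fan_in + fan_out)) ** 0.5
print(f"sampled std {W.std():.4f} vs formula {target:.4f}")  # both ≈ 0.0535
```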

He/Kaiming Initialization (2015)

ReLU violates Xavier’s assumption of a symmetric, zero-centered activation — it zeros out half its inputs, halving the variance at every layer. He initialization compensates by doubling the variance:

# W ~ N(0, 2 / fan_in)  [for ReLU]
# W ~ N(0, 2 / fan_out) [for backward pass stability]

nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')  # for backward

When to use: Any network with ReLU or its variants (LeakyReLU, PReLU, ELU). This is the default for most modern CNNs and MLPs.

fan_in vs fan_out: fan_in preserves forward pass variance (good for activations), fan_out preserves backward pass variance (good for gradients). In practice, fan_in is the default and works well.
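
The factor of 2 is visible empirically: through a deep ReLU stack, Xavier-scaled weights lose roughly half the variance per layer, while He-scaled weights hold it approximately constant (a sketch with arbitrary width and depth):

```python
import torch

torch.manual_seed(0)
depth, width = 50, 256

def final_std(weight_std):
    h = torch.randn(1024, width)
    for _ in range(depth):
        h = torch.relu(h @ (torch.randn(width, width) * weight_std))
    return h.std().item()

xavier_out = final_std((1.0 / width) ** 0.5)  # loses half the variance each ReLU layer
he_out = final_std((2.0 / width) ** 0.5)      # doubled variance offsets ReLU's halving
print(f"Xavier through ReLU: {xavier_out:.2e}, He through ReLU: {he_out:.2e}")
```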

Orthogonal Initialization

Initialize weights as orthogonal matrices, which preserve vector norms exactly:

nn.init.orthogonal_(layer.weight, gain=1.0)

When to use: RNNs and very deep networks. Orthogonal initialization prevents gradient explosion/vanishing more reliably than random initialization for recurrent architectures.
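
A quick way to see why this helps recurrent models: apply the same orthogonal matrix 100 times, as a recurrent layer would across 100 timesteps, and the vector norm barely moves (a minimal sketch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
W = torch.empty(512, 512)
nn.init.orthogonal_(W)

# An orthogonal matrix preserves Euclidean norms, even applied repeatedly --
# exactly the situation a recurrent weight matrix faces at every timestep
x = torch.randn(512)
h = x.clone()
for _ in range(100):
    h = W @ h
print(f"norm before: {x.norm():.4f}, after 100 steps: {h.norm():.4f}")
```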

Scaled Initialization for Transformers

Transformers need special care because of residual connections. GPT-style models typically scale the output projection of each residual block:

# Scale by 1/sqrt(2 * num_layers) for residual-stream stability
import math

# num_layers = number of transformer blocks; output_projection = the Linear
# layer that writes each block's result back into the residual stream
std = 0.02 / math.sqrt(2 * num_layers)
nn.init.normal_(output_projection.weight, mean=0.0, std=std)

This prevents the residual stream from growing with depth. Without it, the accumulated signal from residual connections can cause instability.
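
One way this might be applied across a whole model — the helper and the `out_proj` module name are hypothetical, not from a specific library:

```python
import math
import torch.nn as nn

def init_gpt_style(model, num_layers):
    """Hypothetical GPT-style initializer: plain 0.02 std everywhere,
    downscaled on residual-stream output projections."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            std = 0.02
            if name.endswith("out_proj"):  # writes into the residual stream
                std = 0.02 / math.sqrt(2 * num_layers)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

# Toy check on a single block
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 64)
        self.out_proj = nn.Linear(64, 64)

block = Block()
init_gpt_style(block, num_layers=12)
print(block.fc.weight.std().item(), block.out_proj.weight.std().item())
```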

Why Modern Networks Are Less Sensitive

Three developments have reduced (but not eliminated) sensitivity to initialization:

1. Batch/Layer Normalization

Normalization layers re-center and re-scale activations at each layer, partially correcting for bad initialization. This is why BatchNorm “allows higher learning rates” — it makes training more robust to initialization.

But normalization doesn’t fix everything. The first few training steps still depend on initialization, and very bad initialization can cause NaN losses before normalization has a chance to stabilize.

2. Residual Connections

Skip connections provide a “gradient highway” that bypasses the vanishing gradient problem. Even if some layers are poorly initialized, gradients flow through the skip connections.

3. Adaptive Optimizers

Adam and its variants adapt per-parameter learning rates, partially compensating for initialization-induced scale differences. SGD is much more sensitive to initialization than Adam.

Practical Guidelines

| Architecture | Recommended Initialization |
|---|---|
| MLP with ReLU | He/Kaiming (fan_in) |
| CNN with ReLU | He/Kaiming (fan_in) |
| Transformer | Scaled normal (GPT-style) |
| LSTM/GRU | Orthogonal for recurrent weights, Xavier for input weights |
| Embedding layers | Normal(0, 0.02) or uniform |
| Output/logit layers | Small normal or zeros |
| Bias terms | Zeros (almost always) |
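
The table above can be folded into one catch-all initializer, dispatching on module type (a sketch, not a universal recipe — adapt the cases to your architecture):

```python
import torch.nn as nn

def initialize(model):
    """Apply the guidelines by module type (a sketch, not a universal recipe)."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, (nn.LSTM, nn.GRU)):
            for name, param in module.named_parameters():
                if 'weight_hh' in name:    # recurrent weights: orthogonal
                    nn.init.orthogonal_(param)
                elif 'weight_ih' in name:  # input weights: Xavier
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
initialize(model)
```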

Debugging Initialization Issues

Symptoms of bad initialization:

  • Loss starts at an unexpected value (cross-entropy should start near ln(num_classes) for uniform predictions)
  • Loss doesn’t decrease in the first few steps
  • Activations are all zero or all saturated
  • Gradients are NaN or extremely large/small
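
The first symptom is easy to verify: with a zeroed logit layer, the softmax is uniform and the initial cross-entropy lands exactly at ln(num_classes) (a minimal sketch):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 10
head = nn.Linear(128, num_classes)
nn.init.zeros_(head.weight)  # zero logits -> uniform predictions
nn.init.zeros_(head.bias)

x = torch.randn(64, 128)
targets = torch.randint(0, num_classes, (64,))
loss = nn.functional.cross_entropy(head(x), targets)
print(f"{loss.item():.4f} vs ln({num_classes}) = {math.log(num_classes):.4f}")  # 2.3026 vs 2.3026
```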

Diagnostic checks:

# Check activation statistics after a forward pass
handles = []
for name, module in model.named_modules():
    if hasattr(module, 'weight'):
        handles.append(module.register_forward_hook(
            lambda m, i, o, n=name: print(f"{n}: mean={o.mean():.4f}, std={o.std():.4f}")
        ))

model(sample_batch)  # run one batch to trigger the hooks
for h in handles:
    h.remove()       # clean up so the hooks don't fire on every future forward

If you see means far from 0 or stds far from 1 in early layers, your initialization is likely wrong.
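
Gradient norms deserve the same treatment. A self-contained sketch with a toy model (swap in your own model and batch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
x = torch.randn(32, 64)
loss = nn.functional.cross_entropy(model(x), torch.randint(0, 10, (32,)))
loss.backward()

# Healthy gradients are finite and within a few orders of magnitude of each other
for name, param in model.named_parameters():
    print(f"{name}: grad norm = {param.grad.norm():.3e}")
```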

The Bottom Line

For most practical work: use He initialization for ReLU networks, use your framework’s transformer-specific defaults for transformers, and use Xavier for everything else. Initialize biases to zero. If training is unstable, check initialization before tuning hyperparameters — it’s a common root cause that’s easy to fix.
