Multi-Head Attention
Multi-head attention is a key component of the Transformer architecture that runs multiple attention operations in parallel, allowing the model to jointly attend to information from different representation subspaces.
Single-Head vs Multi-Head
Single-Head Attention
Q, K, V → Attention → Output
Limitation: Only one "perspective" on the relationships
Multi-Head Attention
     Q          K          V
     │          │          │
┌────┴──────────┴──────────┴────┐
│      Split into h heads       │
└───────────────┬───────────────┘
                │
      ┌─────────┼─────────┐
      ↓         ↓         ↓
 [Head 1]  [Head 2] ... [Head h]
      │         │           │
      └─────────┴───────────┘
                │
           Concatenate
                │
          Linear (W_O)
                │
             Output
Mathematical Formulation
Single Attention Head
Attention(Q, K, V) = softmax(QK^T / √d_k) × V
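As a sanity check, the formula can be computed directly with tensor operations (a minimal sketch; the shapes below are illustrative, not from any particular model):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # scores: [seq_q, seq_k], scaled by sqrt(d_k) to keep logits well-conditioned
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V, weights

Q = torch.randn(5, 64)   # 5 query positions
K = torch.randn(7, 64)   # 7 key positions
V = torch.randn(7, 64)
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape, w.shape)  # torch.Size([5, 64]) torch.Size([5, 7])
```

Note that the output always has the query's sequence length: each query position gets a weighted mixture of the value rows.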
Multi-Head Attention
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) × W_O
where head_i = Attention(Q×W_i^Q, K×W_i^K, V×W_i^V)
Parameters:
- W_i^Q: Query projection for head i, shape [d_model, d_k]
- W_i^K: Key projection for head i, shape [d_model, d_k]
- W_i^V: Value projection for head i, shape [d_model, d_v]
- W_O: Output projection, shape [h×d_v, d_model]
Dimensions
Typical configuration (BERT-base):
d_model = 768 # Model dimension
h = 12 # Number of heads
d_k = d_v = 64 # Dimension per head (768/12)
Total parameters per MHA layer:
3 × d_model × d_model (Q, K, V projections)
+ d_model × d_model (output projection)
= 4 × d_model²
= 4 × 768² ≈ 2.36M parameters (weights only; bias terms add a further 4 × d_model)
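The arithmetic above can be verified in a couple of lines (counting projection weights only, as in the derivation):

```python
d_model, h = 768, 12
d_k = d_model // h  # 64 dimensions per head

# Q, K, V projections plus the output projection (weight matrices only)
n_params = 3 * d_model * d_model + d_model * d_model
print(n_params)                       # 2359296
print(n_params == 4 * d_model ** 2)   # True
print(round(n_params / 1e6, 2))       # 2.36 (million)
```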
Implementation
From Scratch
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. Linear projections
        Q = self.W_q(query)  # [batch, seq_q, d_model]
        K = self.W_k(key)    # [batch, seq_k, d_model]
        V = self.W_v(value)  # [batch, seq_k, d_model]

        # 2. Split into heads
        # [batch, seq, d_model] -> [batch, heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: [batch, heads, seq_q, seq_k]
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        # context: [batch, heads, seq_q, d_k]

        # 4. Concatenate heads
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.d_model)
        # context: [batch, seq_q, d_model]

        # 5. Final linear projection
        output = self.W_o(context)
        return output, attn_weights
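The reshapes in steps 2 and 4 are exact inverses of each other, so splitting into heads and concatenating back loses no information. A quick standalone check (independent of the class above):

```python
import torch

batch, seq, d_model, heads = 2, 5, 768, 12
d_k = d_model // heads
x = torch.randn(batch, seq, d_model)

# Step 2: split into heads -> [batch, heads, seq, d_k]
split = x.view(batch, seq, heads, d_k).transpose(1, 2)

# Step 4: concatenate heads back -> [batch, seq, d_model]
merged = split.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(split.shape)             # torch.Size([2, 12, 5, 64])
print(torch.equal(x, merged))  # True
```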
Using PyTorch Built-in
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=768,
    num_heads=12,
    dropout=0.1,
    batch_first=True
)

# Self-attention
query = key = value = torch.randn(32, 100, 768)  # [batch, seq, dim]
output, attn_weights = mha(query, key, value)
# output: [32, 100, 768]
# attn_weights: [32, 100, 100] (averaged over heads by default)
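By default the returned weights are averaged over heads. To inspect each head separately, pass `average_attn_weights=False` to the forward call (a sketch; this flag assumes a reasonably recent PyTorch, 1.12 or later):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(32, 100, 768)

# average_attn_weights=False keeps the per-head dimension
out, per_head = mha(x, x, x, average_attn_weights=False)
print(out.shape)       # torch.Size([32, 100, 768])
print(per_head.shape)  # torch.Size([32, 12, 100, 100])
```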
Why Multiple Heads?
1. Different Attention Patterns
Head 1: Focuses on adjacent words (local syntax)
Head 2: Connects subject to verb (long-range)
Head 3: Tracks coreference (pronouns to nouns)
Head 4: Attends to punctuation
...
2. Subspace Learning
Each head operates in a different d_k-dimensional subspace
Full space (d_model = 768):
[────────────────────────────────]
Head 1 subspace (d_k = 64):
[────]
Head 2:       [────]
Head 3:             [────]
...
3. Stabilizes Training
- Multiple heads = ensemble effect
- Noise in one head doesn't dominate
- More robust gradients
Attention Types in Transformers
Self-Attention (Encoder)
# Query, Key, Value all come from the same sequence
mha = MultiHeadAttention(d_model=768, num_heads=12)
output, attn_weights = mha(x, x, x)
Self-Attention with Causal Mask (Decoder)
# Prevent attending to future tokens
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
output, attn_weights = mha(x, x, x, mask=~causal_mask)
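To see that the causal mask really works, the check below applies it directly to a small score matrix (a standalone sketch; it uses `-inf` rather than the `-1e9` of the module above, which makes the masked weights exactly zero):

```python
import math
import torch

seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

scores = Q @ K.T / math.sqrt(d_k)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)

# Every weight strictly above the diagonal is zero: no attention to the future
print(weights.triu(diagonal=1).sum().item())  # 0.0
```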
Cross-Attention (Decoder)
# Query from decoder, Key/Value from encoder
output, attn_weights = mha(
    query=decoder_hidden,
    key=encoder_output,
    value=encoder_output
)
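In cross-attention the source and target sequences can have different lengths; the output follows the query. A quick shape check with the built-in module (dimensions here are illustrative):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

decoder_hidden = torch.randn(2, 10, 512)   # query: 10 target positions
encoder_output = torch.randn(2, 25, 512)   # key/value: 25 source positions

out, attn = mha(decoder_hidden, encoder_output, encoder_output)
print(out.shape)   # torch.Size([2, 10, 512])  -- follows the query length
print(attn.shape)  # torch.Size([2, 10, 25])   -- one row per query position
```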
Visualizing Attention Heads
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(tokens, attention_weights, head_idx=0):
    """Visualize the attention pattern for one head."""
    # attention_weights: [batch, heads, seq_q, seq_k]
    weights = attention_weights[0, head_idx].detach().cpu().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        annot=True,
        fmt='.2f'
    )
    plt.title(f'Attention Head {head_idx}')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.show()
Efficient Attention Variants
Problem: O(n²) Complexity
Standard attention: O(n² × d)
For n = 1000, d = 768:
1000² × 768 = 768M operations for the score matrix alone (summed over heads; each head contributes n² × d_k)
Solutions
Sparse Attention:
- Only attend to a subset of positions
- Local (window) + global (sparse) patterns
- Examples: Longformer, BigBird

Linear Attention:
- Approximate softmax(QK^T)V with a kernel trick
- O(n × d²) instead of O(n² × d)
- Examples: Linear Transformer, Performer

Flash Attention:
- Same computation, better memory access pattern
- Fuses operations to reduce memory bandwidth
- Used in most modern implementations
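The arithmetic behind the standard-vs-linear trade-off is easy to check directly. Linear attention wins whenever n > d, since the dominant terms are n²·d and n·d²:

```python
# Dominant multiply-add counts for the two attention variants
n, d = 1000, 768

standard = n * n * d   # QK^T: an [n, d] x [d, n] product
linear   = n * d * d   # kernelized: compute (K^T V) first, then Q times that

print(standard)        # 768000000  (~768M)
print(linear)          # 589824000  (~590M)
# The ratio is n/d, so the gap widens rapidly for longer sequences
```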
Number of Heads Trade-offs
| Heads | Per-head dim | Behavior |
|---|---|---|
| 1 | 768 | Rich per-head, no diversity |
| 8 | 96 | Good balance |
| 12 | 64 | BERT default |
| 16 | 48 | More diversity, less capacity each |
| 64 | 12 | Extreme diversity |
Research shows not all heads are equally useful; some can be pruned.
Key Takeaways
- Multi-head attention runs h parallel attention operations
- Each head projects to lower-dimensional subspaces (d_k = d_model/h)
- Heads capture different relationship types in the data
- Outputs are concatenated and projected back to d_model
- Typical configurations: 8-16 heads, 64-128 dim per head
- Different attention types: self, causal, cross-attention