Intermediate · Deep Learning

Learn how multi-head attention allows transformers to attend to information from different representation subspaces at different positions.

Tags: attention, transformers, deep-learning, nlp, architecture

Multi-Head Attention

Multi-head attention is a key component of the Transformer architecture that runs multiple attention operations in parallel, allowing the model to jointly attend to information from different representation subspaces.

Single-Head vs Multi-Head

Single-Head Attention

Q, K, V → Attention → Output

Limitation: each query gets a single softmax distribution over the keys, so it sees only one "perspective" on the relationships

Multi-Head Attention

        Q  K  V
        │  │  │
   ┌────┴──┴──┴─────┐
   │  Linear proj.  │
   │  + split into  │
   │    h heads     │
   └───────┬────────┘
           │
   ┌───────┼───────┬───────────┐
   ↓       ↓       ↓           ↓
[Head1] [Head2] [Head3] ... [Head_h]
   │       │       │           │
   └───────┼───────┴───────────┘
           │
     Concatenate
           │
     Linear (W_O)
           │
        Output

Mathematical Formulation

Single Attention Head

Attention(Q, K, V) = softmax(QK^T / √d_k) × V

Multi-Head Attention

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) × W_O

where head_i = Attention(Q×W_i^Q, K×W_i^K, V×W_i^V)

Parameters:

  • W_i^Q: Query projection for head i, shape [d_model, d_k]
  • W_i^K: Key projection for head i, shape [d_model, d_k]
  • W_i^V: Value projection for head i, shape [d_model, d_v]
  • W_O: Output projection, shape [h×d_v, d_model]

Dimensions

Typical configuration (BERT-base):
  d_model = 768        # Model dimension
  h = 12               # Number of heads
  d_k = d_v = 64       # Dimension per head (768/12)

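Total parameters per MHA layer (weights only, assuming h × d_k = h × d_v = d_model):
  3 × d_model × d_model (Q, K, V projections)
  + d_model × d_model   (output projection)
  = 4 × d_model²
  = 4 × 768² ≈ 2.36M parameters
  (bias vectors add another 4 × d_model = 3,072)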
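The arithmetic above can be checked with a short function. This is a sketch: the helper name `mha_param_count` is hypothetical, and it uses the general per-head shapes from the Parameters list so it also covers d_k ≠ d_model/h. The optional bias count assumes one bias vector per projection output, which is PyTorch's default.

```python
def mha_param_count(d_model, num_heads, d_k=None, d_v=None, bias=True):
    """Count weights (and optionally biases) in one multi-head attention layer.

    Hypothetical helper using the per-head shapes
    W_i^Q, W_i^K: [d_model, d_k], W_i^V: [d_model, d_v], W_O: [h*d_v, d_model].
    """
    h = num_heads
    d_k = d_k if d_k is not None else d_model // h
    d_v = d_v if d_v is not None else d_model // h

    weights = h * d_model * (2 * d_k + d_v)   # all per-head Q, K, V projections
    weights += h * d_v * d_model              # output projection W_O
    if not bias:
        return weights
    biases = h * (2 * d_k + d_v) + d_model    # one bias per projection output
    return weights + biases


print(mha_param_count(768, 12, bias=False))  # 2359296  (= 4 * 768**2)
print(mha_param_count(768, 12))              # 2362368  (weights + 3,072 biases)
```

For `nn.MultiheadAttention(768, 12)` with its default biases, summing `p.numel()` over its parameters should give the second figure.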

Implementation

From Scratch

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
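    def forward(self, query, key, value, mask=None):
        """Shapes: query [batch, seq_q, d_model]; key, value [batch, seq_k, d_model].

        mask (optional) must broadcast to [batch, heads, seq_q, seq_k]; positions
        where mask == 0 are blocked. Returns (output [batch, seq_q, d_model],
        attn_weights [batch, heads, seq_q, seq_k]).
        """
        batch_size = query.size(0)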
        
        # 1. Linear projections
        Q = self.W_q(query)  # [batch, seq_q, d_model]
        K = self.W_k(key)    # [batch, seq_k, d_model]
        V = self.W_v(value)  # [batch, seq_k, d_model]
        
        # 2. Split into heads
        # [batch, seq, d_model] -> [batch, heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: [batch, heads, seq_q, seq_k]
        
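        if mask is not None:
            # Fill blocked positions with the dtype's most negative value;
            # a literal like -1e9 cannot be represented in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)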
        
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        # context: [batch, heads, seq_q, d_k]
        
        # 4. Concatenate heads
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.d_model)
        # context: [batch, seq_q, d_model]
        
        # 5. Final linear projection
        output = self.W_o(context)
        
        return output, attn_weights
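Steps 2 and 4 are exact inverses: splitting only reshapes, and concatenation undoes it. The sketch below (the helpers `split_heads` and `merge_heads` are hypothetical names) checks the round trip and shows why `.contiguous()` comes before `.view()`: after `transpose(1, 2)`, the head and d_k dimensions are no longer adjacent in memory, while `.reshape()` copies automatically when needed.

```python
import torch

def split_heads(x, num_heads):
    """[batch, seq, d_model] -> [batch, heads, seq, d_k] (hypothetical helper)."""
    batch, seq, d_model = x.shape
    return x.view(batch, seq, num_heads, d_model // num_heads).transpose(1, 2)

def merge_heads(x):
    """[batch, heads, seq, d_k] -> [batch, seq, heads * d_k] (hypothetical helper)."""
    batch, heads, seq, d_k = x.shape
    return x.transpose(1, 2).contiguous().view(batch, seq, heads * d_k)

x = torch.randn(2, 5, 16)
heads = split_heads(x, num_heads=4)
print(heads.shape)                          # torch.Size([2, 4, 5, 4])
print(torch.equal(merge_heads(heads), x))   # True

# A freshly computed [batch, heads, seq, d_k] tensor (like attn_weights @ V)
# is contiguous; after transpose(1, 2) its last two dims cannot be merged by view.
context = torch.randn(2, 4, 5, 4)
view_failed = False
try:
    context.transpose(1, 2).view(2, 5, 16)
except RuntimeError:
    view_failed = True
print(view_failed)                          # True

# .reshape() copies when necessary and matches merge_heads.
same = torch.equal(context.transpose(1, 2).reshape(2, 5, 16), merge_heads(context))
print(same)                                 # True
```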

Using PyTorch Built-in

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(
    embed_dim=768,
    num_heads=12,
    dropout=0.1,
    batch_first=True
)

# Self-attention
query = key = value = torch.randn(32, 100, 768)  # [batch, seq, dim]
output, attn_weights = mha(query, key, value)
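# output: [32, 100, 768]
# attn_weights: [32, 100, 100] (averaged over the 12 heads by default;
#   pass average_attn_weights=False to get [32, 12, 100, 100])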
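To connect the built-in module to the formulas above, the sketch below recomputes its output by hand. It assumes PyTorch's parameter layout when key/value widths equal `embed_dim`: `in_proj_weight` stacks the Q, K and V projections row-wise and `out_proj` is W_O. Sizes are arbitrary and kept small.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 6, 32, 4
d_k = d_model // num_heads

mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
mha.eval()  # no dropout (dropout also defaults to 0.0)

def split_heads(t):
    # [batch, seq, d_model] -> [batch, heads, seq, d_k]
    return t.view(batch, -1, num_heads, d_k).transpose(1, 2)

x = torch.randn(batch, seq_len, d_model)
with torch.no_grad():
    ref_out, ref_weights = mha(x, x, x, average_attn_weights=False)

    # Assumed layout: in_proj_weight is [3 * d_model, d_model] = [W_q; W_k; W_v]
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

    q = split_heads(F.linear(x, w_q, b_q))
    k = split_heads(F.linear(x, w_k, b_k))
    v = split_heads(F.linear(x, w_v, b_v))

    # softmax(QK^T / sqrt(d_k)) V per head, then concatenate and apply W_O
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
    context = (weights @ v).transpose(1, 2).reshape(batch, seq_len, d_model)
    out = mha.out_proj(context)

print(torch.allclose(out, ref_out, atol=1e-5))          # True
print(torch.allclose(weights, ref_weights, atol=1e-5))  # True
```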

Why Multiple Heads?

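1. Different Attention Patterns

Heads in trained models often specialize. Illustrative patterns of the kind reported in analyses of models such as BERT (head numbers are arbitrary):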

Head 1: Focuses on adjacent words (local syntax)
Head 2: Connects subject to verb (long-range)
Head 3: Tracks coreference (pronouns to nouns)
Head 4: Attends to punctuation
...

2. Subspace Learning

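Each head works in its own learned d_k-dimensional subspace. Every head still reads the full d_model-dimensional input (W_i^Q has shape [d_model, d_k]); in the implementation above, head i then owns a contiguous d_k-wide slice of the projected vector:

Projected space (d_model = 768):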
  [───────────────────────────────]
  
Head 1 subspace (d_k = 64):
  [────]                           
           Head 2: [────]          
                        Head 3: [────]
                                    ...
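The fused d_model × d_model projection in the from-scratch class is equivalent to h separate per-head projections. In the sketch below a standalone `nn.Linear` stands in for `W_q` (sizes are arbitrary). Head i's slice of the fused output matches a projection that uses only rows i·d_k to (i+1)·d_k of the weight matrix. PyTorch stores that block as [d_k, d_model], the transpose of W_i^Q in the notation above, so every head reads all input dimensions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_heads = 16, 4
d_k = d_model // num_heads

W_q = nn.Linear(d_model, d_model)   # fused projection, as in the class above
x = torch.randn(3, d_model)         # 3 token vectors

with torch.no_grad():
    fused = W_q(x)                  # [3, d_model]
    for i in range(num_heads):
        rows = slice(i * d_k, (i + 1) * d_k)
        # Head i's own projection: a [d_k, d_model] block, i.e. it sees every input dim
        head_i = F.linear(x, W_q.weight[rows], W_q.bias[rows])
        assert torch.allclose(fused[:, rows], head_i, atol=1e-5)

print("every head slice equals its own full-input projection")
```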

3. Stabilizes Training (a common intuition)

Multiple heads act somewhat like an ensemble
Noise in one head is less likely to dominate the output
This is a heuristic explanation rather than an established result

Attention Types in Transformers

Self-Attention (Encoder)

# Query, Key, Value all come from the same sequence x: [batch, seq, d_model]
self_attn = MultiHeadAttention(d_model=768, num_heads=12)
output, attn_weights = self_attn(x, x, x)

Self-Attention with Causal Mask (Decoder)

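# Prevent attending to future tokens (True above the diagonal = future position)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
masked_self_attn = MultiHeadAttention(d_model=768, num_heads=12)
# The class above blocks positions where mask == 0, so pass the complement
output, attn_weights = masked_self_attn(x, x, x, mask=~causal_mask)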
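The mask convention matters. The from-scratch class blocks positions where `mask == 0`, which is why the causal snippet passes `~causal_mask`. PyTorch's `nn.MultiheadAttention` uses the opposite convention for a boolean `attn_mask`: `True` marks positions that may not be attended. The sketch below (arbitrary small sizes) passes the upper-triangular mask directly to the built-in and checks the defining property of causal attention: changing later tokens leaves earlier outputs unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, num_heads = 5, 16, 4
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True).eval()

# True above the diagonal = future position, NOT allowed (built-in convention)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[:, 3:] = torch.randn(1, seq_len - 3, d_model)  # alter tokens 3 and 4 only

with torch.no_grad():
    out, weights = mha(x, x, x, attn_mask=causal_mask)
    out_changed, _ = mha(x_changed, x_changed, x_changed, attn_mask=causal_mask)

# Positions 0-2 only attend to positions 0-2, so their outputs are unaffected.
print(torch.allclose(out[:, :3], out_changed[:, :3], atol=1e-6))  # True
# No attention weight lands on a future position.
print(bool(weights[0].triu(diagonal=1).abs().max() < 1e-6))        # True
```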

Cross-Attention (Decoder)

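# Query from decoder, Key/Value from encoder (a separate module with its own weights)
cross_attn = MultiHeadAttention(d_model=768, num_heads=12)
output, attn_weights = cross_attn(
    query=decoder_hidden,
    key=encoder_output,
    value=encoder_output
)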
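In cross-attention the query and key/value sequences can have different lengths, and with PyTorch's built-in module the encoder output can also have a different width via the `kdim`/`vdim` arguments. The sketch below uses made-up sizes (7 decoder positions, 11 encoder positions, encoder width 48) to show the resulting shapes.

```python
import torch
import torch.nn as nn

batch, tgt_len, src_len = 2, 7, 11
d_model, d_encoder, num_heads = 32, 48, 4

cross_attn = nn.MultiheadAttention(
    embed_dim=d_model,   # decoder width (queries and output)
    num_heads=num_heads,
    kdim=d_encoder,      # encoder width feeding the keys
    vdim=d_encoder,      # encoder width feeding the values
    batch_first=True,
)

decoder_hidden = torch.randn(batch, tgt_len, d_model)
encoder_output = torch.randn(batch, src_len, d_encoder)

output, attn_weights = cross_attn(decoder_hidden, encoder_output, encoder_output)
print(output.shape)        # torch.Size([2, 7, 32]): one vector per decoder position
print(attn_weights.shape)  # torch.Size([2, 7, 11]): decoder x encoder positions
```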

Visualizing Attention Heads

import matplotlib.pyplot as plt
import seaborn as sns

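def visualize_attention(tokens, attention_weights, head_idx=0):
    """Plot one head's attention pattern for the first example in the batch.

    attention_weights: per-head self-attention weights [batch, heads, seq, seq]
    over `tokens` (as returned by the MultiHeadAttention class above).
    """
    weights = attention_weights[0, head_idx].detach().cpu().numpy()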
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        annot=True,
        fmt='.2f'
    )
    plt.title(f'Attention Head {head_idx}')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.show()
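`visualize_attention` indexes `attention_weights[0, head_idx]`, so it needs per-head weights of shape [batch, heads, seq_q, seq_k]. PyTorch's built-in module averages over heads unless asked not to. The sketch below (random stand-in embeddings, arbitrary sizes) shows how to get the per-head tensor and checks that each query row is a probability distribution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tokens = ["the", "cat", "sat", "on", "the", "mat"]
d_model, num_heads = 16, 4
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True).eval()
x = torch.randn(1, len(tokens), d_model)   # stand-in for token embeddings

with torch.no_grad():
    _, averaged = mha(x, x, x)                               # default: mean over heads
    _, per_head = mha(x, x, x, average_attn_weights=False)

print(averaged.shape)   # torch.Size([1, 6, 6])
print(per_head.shape)   # torch.Size([1, 4, 6, 6])  [batch, heads, seq_q, seq_k]

# Each query row sums to 1, and the default output is the mean over heads.
print(torch.allclose(per_head.sum(dim=-1), torch.ones(1, num_heads, len(tokens))))  # True
print(torch.allclose(per_head.mean(dim=1), averaged, atol=1e-6))                  # True
```

With `per_head` in hand, `visualize_attention(tokens, per_head, head_idx=0)` plots a single head.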

Efficient Attention Variants

Problem: O(n²) Complexity

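Standard attention: O(n² × d) per layer (n = sequence length)

For n = 1000, d_model = 768 (12 heads, d_k = 64):
  QK^T per head: 1000² × 64 = 64M multiply-adds
  All 12 heads:  1000² × 768 = 768M multiply-adds
  (weights × V costs the same again, and each head stores a 1000 × 1000 score matrix)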
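The quadratic growth is easy to tabulate. The hypothetical helper below counts multiply-adds for the two attention matmuls (QK^T and weights × V) and the number of score entries, summed over all heads of one layer. It ignores the Q/K/V/output projections, which grow only linearly in n.

```python
def attention_cost(n, d_model, num_heads):
    """Rough per-layer cost of the attention core (hypothetical helper).

    Returns (multiply-adds for QK^T, multiply-adds for weights @ V,
             attention-score entries across all heads).
    """
    d_k = d_model // num_heads
    qk = num_heads * n * n * d_k   # each head: [n, d_k] @ [d_k, n]
    av = num_heads * n * n * d_k   # each head: [n, n] @ [n, d_k]
    scores = num_heads * n * n     # one n x n matrix per head
    return qk, av, scores


for n in (1_000, 2_000, 4_000):
    qk, av, scores = attention_cost(n, d_model=768, num_heads=12)
    print(f"n={n:>5}: QK^T {qk / 1e6:,.0f}M, AV {av / 1e6:,.0f}M, "
          f"scores {scores / 1e6:,.0f}M entries")
# n= 1000: QK^T 768M, AV 768M, scores 12M entries
# n= 2000: QK^T 3,072M, AV 3,072M, scores 48M entries
# n= 4000: QK^T 12,288M, AV 12,288M, scores 192M entries
```

Doubling n quadruples every column, which is the O(n²) problem the variants below try to avoid.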

Solutions

Sparse Attention:

Each query attends to only a subset of positions
- Local (window) + global (sparse) patterns
- Examples: Longformer, BigBird
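A sketch of the "local + global" pattern as a boolean mask, assuming a window of `window` positions on each side and a hypothetical list of global token indices (for example a classification token at position 0). `True` means "may attend", matching the from-scratch class, which blocks `mask == 0`. Production implementations such as Longformer avoid materializing the full n × n mask; this illustration does not.

```python
import torch

def local_global_mask(seq_len, window, global_idx=()):
    """Boolean [seq_len, seq_len] mask; True = query i may attend to key j.

    Hypothetical helper: a band of +/- `window` around the diagonal, plus
    global positions that attend everywhere and are attended by everyone.
    """
    pos = torch.arange(seq_len)
    mask = (pos[:, None] - pos[None, :]).abs() <= window   # local band
    for g in global_idx:
        mask[g, :] = True   # global token sees every position
        mask[:, g] = True   # every position sees the global token
    return mask


mask = local_global_mask(seq_len=8, window=1, global_idx=[0])
print(mask.int())
print(mask.sum().item(), "of", mask.numel(), "pairs kept")  # 34 of 64 pairs kept
```

The saving grows with length: the band keeps roughly n × (2·window + 1) pairs instead of n².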

Linear Attention:

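Replace softmax(QK^T) with a kernel φ(Q)φ(K)^T (a different kernel in Linear Transformer; a random-feature approximation of softmax in Performer)
Computing φ(K)^T V first costs O(n × d²) instead of O(n² × d)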
Examples: Linear Transformer, Performer
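A minimal sketch of the non-causal Linear Transformer idea, assuming its feature map φ(x) = elu(x) + 1 (Performer's random-feature approximation of softmax is not shown). Once softmax is replaced by φ(Q)φ(K)^T, associativity lets φ(K)^T V be computed first as a d_k × d_v matrix, so no n × n matrix is formed. The check compares against the quadratic computation with the same kernel; it does not reproduce softmax attention.

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # phi(x) = elu(x) + 1 keeps features positive, so weights stay positive
    return F.elu(x) + 1

def linear_attention(q, k, v):
    """q, k: [n, d_k], v: [n, d_v]. Cost O(n * d_k * d_v); no n x n matrix."""
    q, k = feature_map(q), feature_map(k)
    kv = k.T @ v                     # [d_k, d_v]: summed over all keys once
    normalizer = q @ k.sum(dim=0)    # [n]: phi(q_i) . sum_j phi(k_j)
    return (q @ kv) / normalizer[:, None]

def quadratic_same_kernel(q, k, v):
    """The same kernel computed the O(n^2) way, for comparison."""
    weights = feature_map(q) @ feature_map(k).T            # [n, n]
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights @ v


torch.manual_seed(0)
n, d_k, d_v = 50, 8, 8
q, k, v = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)
fast, slow = linear_attention(q, k, v), quadratic_same_kernel(q, k, v)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```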

Flash Attention:

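Exact attention (no approximation), computed in tiles so the n × n score matrix is never written to GPU main memory
Fuses the matmuls and softmax into one kernel, cutting memory traffic; extra memory grows as O(n) instead of O(n²)
Widely used in modern training and inference stacks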
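The trick that makes tiling possible is an online softmax: keys and values are visited block by block while a running row maximum and running denominator are updated, so the full score matrix never exists at once. The sketch below is a plain-PyTorch, single-head illustration of that idea (the function name and block size are arbitrary); it does not reproduce the fused GPU kernel or its memory savings. In PyTorch 2.x, `torch.nn.functional.scaled_dot_product_attention` can dispatch to a FlashAttention kernel on supported hardware.

```python
import math
import torch

def tiled_attention(q, k, v, block_size=16):
    """Exact softmax attention for one head, visiting keys in blocks.

    q: [n_q, d], k: [n_k, d], v: [n_k, d_v]. Only an [n_q, block_size]
    block of scores exists at any time.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    n_q = q.shape[0]
    row_max = torch.full((n_q, 1), float("-inf"), dtype=q.dtype)
    row_sum = torch.zeros(n_q, 1, dtype=q.dtype)               # softmax denominator
    acc = torch.zeros(n_q, v.shape[-1], dtype=q.dtype)         # unnormalized output

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale                         # [n_q, block]
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        p = torch.exp(scores - new_max)                        # this block's exp-scores
        rescale = torch.exp(row_max - new_max)                 # correct earlier blocks
        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        acc = acc * rescale + p @ v_blk
        row_max = new_max
    return acc / row_sum


torch.manual_seed(0)
q, k, v = torch.randn(10, 32), torch.randn(45, 32), torch.randn(45, 32)
reference = torch.softmax(q @ k.T / math.sqrt(32), dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), reference, atol=1e-5))  # True
```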

Number of Heads Trade-offs

Heads   Per-head dim   Behavior (d_model = 768)
1       768            Rich per-head, no diversity
8       96             Good balance
12      64             BERT-base default
16      48             More diversity, less capacity each
64      12             Extreme diversity

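Research shows that not all heads are equally useful: head-pruning studies (e.g., Michel et al., 2019, "Are Sixteen Heads Really Better than One?") found that many heads can be removed at test time with only a small drop in accuracy.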
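Pruning a head can be expressed in two equivalent ways: zero its d_k-wide slice of the concatenated context before W_O, or delete the matching columns of W_O. The sketch below checks that equivalence with arbitrary sizes; the per-head `head_mask` gating vector is a hypothetical illustration in the spirit of head-masking experiments, not a specific library API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

W_o = nn.Linear(d_model, d_model)
context = torch.randn(batch, num_heads, seq_len, d_k)   # per-head outputs
head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])          # hypothetical: prune head 1

with torch.no_grad():
    # Option A: gate each head, then concatenate and project.
    gated = context * head_mask[None, :, None, None]
    out_masked = W_o(gated.transpose(1, 2).reshape(batch, seq_len, d_model))

    # Option B: keep every head but zero head 1's columns of W_O.
    pruned = nn.Linear(d_model, d_model)
    pruned.load_state_dict(W_o.state_dict())
    pruned.weight[:, 1 * d_k:2 * d_k] = 0.0
    out_pruned = pruned(context.transpose(1, 2).reshape(batch, seq_len, d_model))

print(torch.allclose(out_masked, out_pruned, atol=1e-6))  # True
```

Once a head is pruned this way, its columns of W_O (and its rows of W_q, W_k, W_v) can be physically removed to shrink the layer.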

Key Takeaways

  1. Multi-head attention runs h parallel attention operations
  2. Each head projects Q, K, and V into its own lower-dimensional subspace (typically d_k = d_v = d_model/h)
  3. Heads capture different relationship types in the data
  4. Outputs are concatenated and projected back to d_model
  5. Typical configurations: 8-16 heads, 64-128 dim per head
  6. Different attention types: self, causal, cross-attention
