Beam Search
Beam search is a decoding algorithm for sequence generation that keeps multiple candidate sequences (beams) at each step, striking a balance between greedy search and exhaustive search.
The Decoding Problem
When generating sequences (text, translations), we need to find:
Y* = argmax_Y P(Y | X)
But exhaustive search over all possible sequences is intractable:
- Vocabulary size V = 50,000
- Sequence length T = 20
- Possible sequences = V^T = 50,000^20 ≈ 10^94
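A quick back-of-the-envelope check of that count, using the same constants as above:

```python
import math

V = 50_000   # vocabulary size
T = 20       # sequence length

num_sequences = V ** T
print(f"V^T is about 10^{round(math.log10(num_sequences))} sequences")
```

Even at a billion sequences scored per second, enumerating them all would take many orders of magnitude longer than the age of the universe.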
Greedy vs Beam Search
Greedy Search
Always pick the highest probability token:
Step 1: "The" (0.4) ← pick
Step 2: "cat" (0.3) ← pick
Step 3: "sat" (0.5) ← pick
Result: "The cat sat" P = 0.4 × 0.3 × 0.5 = 0.06
Problem: Locally optimal choices may not be globally optimal.
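The greedy rule fits in a few lines. Here `step_log_probs` is a hypothetical stand-in for the model: a callable that takes the prefix generated so far and returns log-probabilities over the vocabulary.

```python
import numpy as np

def greedy_decode(step_log_probs, max_len=20, eos_id=0):
    """Pick the single most likely token at every step.

    `step_log_probs(prefix)` is a hypothetical callable standing in for
    the model: it returns a 1-D array of next-token log-probabilities.
    """
    seq = []
    for _ in range(max_len):
        # Locally optimal choice: the argmax of the next-token distribution
        token = int(np.argmax(step_log_probs(seq)))
        seq.append(token)
        if token == eos_id:
            break
    return seq
```

Because each argmax commits immediately, a slightly worse token now can never be traded for a much better continuation later, which is exactly the failure beam search mitigates.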
Beam Search
Keep top-k candidates at each step:
Beam width k = 2
Step 1:
"The" (0.4) ← keep
"A" (0.35) ← keep
Step 2:
"The cat" (0.4 × 0.3 = 0.12) ← keep
"A small" (0.35 × 0.4 = 0.14) ← keep
"The dog" (0.4 × 0.2 = 0.08) ✗
"A cat" (0.35 × 0.3 = 0.105) ✗
Step 3:
Continue with top 2 candidates...
Algorithm
def beam_search(model, input_seq, beam_width=5, max_length=50):
    # Each beam is a (sequence, cumulative log-probability) pair
    beams = [([], 0.0)]
    for _ in range(max_length):
        all_candidates = []
        for seq, score in beams:
            # Finished beams pass through unchanged
            if seq and seq[-1] == EOS_TOKEN:
                all_candidates.append((seq, score))
                continue
            # Log-probabilities over the vocabulary for the next token
            log_probs = model.predict_next(input_seq, seq)
            # Add every possible one-token extension
            for token_id, log_prob in enumerate(log_probs):
                all_candidates.append((seq + [token_id], score + log_prob))
        # Keep only the top-k candidates
        beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        # Stop once every surviving beam has emitted EOS
        if all(seq[-1] == EOS_TOKEN for seq, _ in beams):
            break
    return beams[0][0]  # beams are sorted, so this is the best sequence
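A self-contained toy run makes the behavior concrete. The transition table below is hypothetical (hand-picked numbers over a 3-token vocabulary) and is chosen so that greedy decoding, which starts with token 1 (P = 0.6), ends up with a lower-probability sequence than the one beam search finds.

```python
import math

EOS = 0  # end-of-sequence id in a hypothetical 3-token vocabulary

# Hand-made next-token probabilities given the previous token (None = start).
# Greedy decoding yields [1, 2, 0] with P = 0.6 * 0.55 * 0.9 = 0.297;
# beam search recovers [2, 0] with P = 0.4 * 0.9 = 0.36.
NEXT = {
    None: {1: 0.6, 2: 0.4},
    1:    {0: 0.45, 2: 0.55},
    2:    {0: 0.9, 1: 0.05, 2: 0.05},
}

def toy_beam_search(beam_width=2, max_len=5):
    beams = [([], 0.0)]                       # (sequence, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq and seq[-1] == EOS:        # finished beams pass through
                candidates.append((seq, score))
                continue
            prev = seq[-1] if seq else None
            for tok, p in NEXT[prev].items():
                candidates.append((seq + [tok], score + math.log(p)))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(seq[-1] == EOS for seq, _ in beams):
            break
    return beams[0][0]

print(toy_beam_search())  # [2, 0]
```

With beam width 2, the lower-scoring first token "2" survives the first step, and its strong continuation overtakes the greedy path.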
Length Normalization
The Problem
Longer sequences have lower probabilities (more multiplications):
"Hello" → P = 0.1
"Hello world" → P = 0.1 × 0.05 = 0.005
Beam search unfairly favors shorter sequences.
Solution: Normalize by Length
def length_normalized_score(log_prob, length, alpha=0.7):
    # alpha controls normalization strength:
    #   alpha = 0: no normalization
    #   alpha = 1: full normalization (per-token average log-probability)
    return log_prob / (length ** alpha)
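A small numeric check shows the normalization flipping a ranking. The two candidates are hypothetical: a 1-token sequence with P = 0.2, and a 3-token sequence whose tokens each have P = 0.3 (joint P = 0.027).

```python
import math

def length_normalized_score(log_prob, length, alpha=0.7):
    return log_prob / (length ** alpha)

short_raw = math.log(0.2)       # ~ -1.61
long_raw = 3 * math.log(0.3)    # ~ -3.61: raw score prefers the short one

short_norm = length_normalized_score(short_raw, 1, alpha=1.0)  # ~ -1.61
long_norm = length_normalized_score(long_raw, 3, alpha=1.0)    # ~ -1.20
# With full normalization the per-token-stronger long sequence now wins.
```

The raw log-probability always penalizes extra tokens, while the normalized score compares sequences by average per-token quality.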
Beam Search Parameters
Beam Width (k)
| Beam Width | Behavior |
|---|---|
| k = 1 | Greedy search |
| k = 3-5 | Common for production |
| k = 10-20 | Higher quality, slower |
| k → ∞ | Exhaustive search |
Length Penalty (alpha)
alpha = 0.0: No penalty (favors short)
alpha = 0.6-0.7: Typical value
alpha = 1.0: Full normalization
alpha > 1.0: Favors longer sequences
Diverse Beam Search
Standard beam search often produces similar outputs:
Beam 1: "The cat sat on the mat"
Beam 2: "The cat sat on the rug"
Beam 3: "The cat sat on the floor"
Solutions
1. Group-based diversity:
# Penalize similarity between beams in different groups
diverse_score = score - diversity_penalty * similarity_to_other_groups
2. Sampling from top-k:
# Instead of always keeping the top-k extensions, sample from them
probs = softmax(logits / temperature)  # temperature < 1 sharpens, > 1 flattens
sampled = np.random.choice(len(probs), p=probs)
Comparison of Decoding Methods
| Method | Speed | Quality | Diversity |
|---|---|---|---|
| Greedy | Fast | Low | None |
| Beam Search | Medium | High | Low |
| Sampling | Fast | Variable | High |
| Top-k Sampling | Fast | Good | Medium |
| Nucleus (top-p) | Fast | Good | Medium |
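For comparison, the two sampling rows can be sketched as standalone helpers (these are illustrative implementations, not taken from any particular library; `logits` are raw pre-softmax scores):

```python
import numpy as np

def top_k_sample(logits, k=5, rng=None):
    """Sample the next token from only the k highest-scoring tokens."""
    rng = np.random.default_rng() if rng is None else rng
    top = np.argsort(logits)[-k:]                  # ids of the k largest logits
    probs = np.exp(logits[top] - logits[top].max())
    probs /= probs.sum()                           # renormalize over the top-k
    return int(rng.choice(top, p=probs))

def top_p_sample(logits, p=0.9, rng=None):
    """Nucleus sampling: the smallest token set with total mass >= p."""
    rng = np.random.default_rng() if rng is None else rng
    order = np.argsort(logits)[::-1]               # ids by descending score
    probs = np.exp(logits[order] - logits.max())
    probs /= probs.sum()
    cutoff = int(np.searchsorted(np.cumsum(probs), p)) + 1
    keep = order[:cutoff]
    kp = probs[:cutoff] / probs[:cutoff].sum()
    return int(rng.choice(keep, p=kp))
```

Top-k fixes the candidate count; nucleus sampling adapts it to the shape of the distribution, keeping fewer tokens when the model is confident.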
When to Use Beam Search
Good For
- Machine translation
- Summarization
- Image captioning
- Any task with a "correct" output
Not Ideal For
- Creative text generation (use sampling)
- Chatbots (can be repetitive)
- Open-ended generation
Implementation Tips
1. Use Log Probabilities
# Avoid numerical underflow: sum logs instead of multiplying probabilities
log_prob = sum(log(p) for p in token_probs)  # Good
prob = prod(token_probs)                     # Underflows for long sequences!
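The underflow is easy to reproduce. For a hypothetical 400-token sequence with per-token probability 0.1, the direct product is far below the smallest representable double, while the log-space value is a perfectly ordinary float:

```python
import math

p, n = 0.1, 400                # 400 tokens, each with P = 0.1

direct = p ** n                # multiplying probabilities underflows to 0.0
in_log = n * math.log(p)       # log-space value stays representable

print(direct)                  # 0.0
print(in_log)                  # ~ -921.03
```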
2. Batch Beam Search
# Process all beams in parallel
beam_batch = torch.stack([seq for seq, _ in beams])
all_probs = model(input_seq.expand(beam_width, -1), beam_batch)
3. Early Stopping
# Stop when the best completed beam cannot be beaten by any incomplete one
if best_complete_score > max(incomplete_scores):
    break
Practical Example: Hugging Face
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

input_ids = tokenizer.encode("The future of AI is", return_tensors='pt')

# Beam search
outputs = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,             # Beam width
    length_penalty=0.7,      # Length normalization
    early_stopping=True,
    no_repeat_ngram_size=2,  # Prevent repeated bigrams
)
print(tokenizer.decode(outputs[0]))
Key Takeaways
- Beam search maintains k candidates at each step
- Beam width balances quality vs speed (k=3-5 typical)
- Length normalization prevents bias toward short sequences
- Use log probabilities to avoid numerical underflow
- For creative tasks, consider sampling instead
- Diverse beam search helps when you need variety