
Learn how beam search balances exploration and efficiency in sequence generation by maintaining multiple candidate sequences.


Beam Search

Beam search is a search algorithm used in sequence generation that maintains multiple candidate sequences (beams) at each step, balancing between greedy search and exhaustive search.

The Decoding Problem

When generating sequences (text, translations), we need to find:

Y* = argmax P(Y|X)
    Y

But exhaustive search over all possible sequences is intractable:

  • Vocabulary size V = 50,000
  • Sequence length T = 20
  • Possible sequences = V^T = 50,000^20 ≈ 10^94

Greedy vs Beam Search

Greedy Search

Always pick the highest probability token:

Step 1: "The" (0.4) ← pick
Step 2: "cat" (0.3) ← pick
Step 3: "sat" (0.5) ← pick

Result: "The cat sat" P = 0.4 × 0.3 × 0.5 = 0.06

Problem: Locally optimal choices may not be globally optimal.
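The greedy procedure above can be sketched in a few lines. The probability table mirrors the example picks; all tokens and values are made up for illustration:

```python
# Hypothetical next-token probabilities: context (tuple of tokens) -> candidates.
TOY_MODEL = {
    (): {"The": 0.4, "A": 0.35},
    ("The",): {"cat": 0.3, "dog": 0.2},
    ("The", "cat"): {"sat": 0.5, "ran": 0.4},
}

def greedy_decode(model, max_steps=3):
    seq, prob = [], 1.0
    for _ in range(max_steps):
        candidates = model.get(tuple(seq))
        if not candidates:
            break
        # Always take the locally best token
        token, p = max(candidates.items(), key=lambda kv: kv[1])
        seq.append(token)
        prob *= p
    return seq, prob

seq, prob = greedy_decode(TOY_MODEL)
print(seq)             # ['The', 'cat', 'sat']
print(round(prob, 2))  # 0.06
```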

Beam Search

Keep top-k candidates at each step:

Beam width k = 2

Step 1:
  "The" (0.4) ← keep
  "A" (0.35)  ← keep

Step 2:
  "The cat" (0.4 × 0.3 = 0.12)    ← keep
  "A small" (0.35 × 0.4 = 0.14)   ← keep
  "The dog" (0.4 × 0.2 = 0.08)    ✗
  "A cat" (0.35 × 0.3 = 0.105)    ✗

Step 3:
  Continue with top 2 candidates...

Algorithm

def beam_search(model, input_seq, beam_width=5, max_length=50):
    # Each beam is a (sequence, cumulative log-probability) pair;
    # start from a single empty sequence with score 0.
    # Assumes EOS_TOKEN is defined and model.predict_next returns
    # log-probabilities over the vocabulary.
    beams = [([], 0.0)]

    for _ in range(max_length):
        all_candidates = []

        for seq, score in beams:
            # Finished beams are carried over unchanged
            if seq and seq[-1] == EOS_TOKEN:
                all_candidates.append((seq, score))
                continue

            # Log-probabilities for the next token given this beam
            log_probs = model.predict_next(input_seq, seq)

            # Extend this beam with every possible next token
            for token_id, log_prob in enumerate(log_probs):
                all_candidates.append((seq + [token_id], score + log_prob))

        # Keep only the k highest-scoring candidates
        beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

        # Stop once every surviving beam has emitted EOS
        if all(seq and seq[-1] == EOS_TOKEN for seq, _ in beams):
            break

    return beams[0][0]  # best-scoring sequence
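A minimal runnable variant of the algorithm, over a hypothetical toy model (the same probabilities as the worked example above), shows the beam recovering the "A small ..." continuation that greedy search misses:

```python
import math

# Hypothetical next-token probabilities, chosen so the greedy path
# ("The cat", P = 0.12) loses to the beam ("A small", P = 0.14).
TOY_MODEL = {
    (): {"The": 0.4, "A": 0.35},
    ("The",): {"cat": 0.3, "dog": 0.2},
    ("A",): {"small": 0.4, "cat": 0.3},
    ("The", "cat"): {"<eos>": 1.0},
    ("The", "dog"): {"<eos>": 1.0},
    ("A", "small"): {"<eos>": 1.0},
    ("A", "cat"): {"<eos>": 1.0},
}
EOS = "<eos>"

def beam_search(model, beam_width=2, max_length=5):
    beams = [([], 0.0)]  # (sequence, cumulative log-probability)
    for _ in range(max_length):
        candidates = []
        for seq, score in beams:
            if seq and seq[-1] == EOS:
                candidates.append((seq, score))  # finished beam
                continue
            for tok, p in model.get(tuple(seq), {}).items():
                candidates.append((seq + [tok], score + math.log(p)))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(seq and seq[-1] == EOS for seq, _ in beams):
            break
    return beams[0]

seq, score = beam_search(TOY_MODEL)
print(seq)  # ['A', 'small', '<eos>']
```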

Length Normalization

The Problem

Longer sequences have lower probabilities (more multiplications):

"Hello" → P = 0.1
"Hello world" → P = 0.1 × 0.05 = 0.005

Beam search unfairly favors shorter sequences.

Solution: Normalize by Length

def length_normalized_score(log_prob, length, alpha=0.7):
    # alpha controls normalization strength
    # alpha = 0: no normalization
    # alpha = 1: full normalization
    return log_prob / (length ** alpha)
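Using that function, a quick check with two hypothetical candidates (a 3-token sequence with P = 0.06 and a 6-token one with P = 0.04) shows normalization flipping the ranking:

```python
import math

def length_normalized_score(log_prob, length, alpha=0.7):
    return log_prob / (length ** alpha)

# Hypothetical candidates: (log-probability, length)
short = (math.log(0.06), 3)   # 3 tokens, P = 0.06
longer = (math.log(0.04), 6)  # 6 tokens, P = 0.04

# Raw log-probability favors the short sequence:
print(short[0] > longer[0])  # True

# With length normalization (alpha = 0.7) the longer one ranks higher:
print(length_normalized_score(*short) > length_normalized_score(*longer))  # False
```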

Beam Search Parameters

Beam Width (k)

Beam Width    Behavior
k = 1         Greedy search
k = 3-5       Common for production
k = 10-20     Higher quality, slower
k → ∞         Exhaustive search

Length Penalty (alpha)

alpha = 0.0: No penalty (favors short)
alpha = 0.6-0.7: Typical value
alpha = 1.0: Full normalization
alpha > 1.0: Favors longer sequences

Diverse Beam Search

Standard beam search often produces similar outputs:

Beam 1: "The cat sat on the mat"
Beam 2: "The cat sat on the rug"
Beam 3: "The cat sat on the floor"

Solutions

1. Group-based diversity:

# Penalize similarity between beams in different groups
diverse_score = score - diversity_penalty * similarity_to_other_groups

2. Sampling from top-k:

# Instead of always keeping the top-k deterministically,
# sample the next token from the temperature-scaled distribution
probs = softmax(logits / temperature)
next_token = np.random.choice(len(probs), p=probs)
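A self-contained sketch of top-k sampling (function name and logits are hypothetical): keep the k highest-scoring tokens, renormalize over them, then sample.

```python
import numpy as np

def top_k_sample(logits, k=5, temperature=1.0, rng=None):
    # Sample a token id from the k highest-scoring logits.
    if rng is None:
        rng = np.random.default_rng()
    logits = np.asarray(logits, dtype=float) / temperature
    top = np.argsort(logits)[-k:]                # indices of the k best tokens
    probs = np.exp(logits[top] - logits[top].max())
    probs /= probs.sum()                         # softmax over the survivors
    return int(rng.choice(top, p=probs))

logits = [2.0, 1.0, 0.5, -1.0, -3.0]
print(top_k_sample(logits, k=3))  # one of 0, 1, 2
```

With k = 1 this degenerates to greedy search; larger k trades determinism for diversity.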

Comparison of Decoding Methods

Method            Speed    Quality    Diversity
Greedy            Fast     Low        None
Beam Search       Medium   High       Low
Sampling          Fast     Variable   High
Top-k Sampling    Fast     Good       Medium
Nucleus (top-p)   Fast     Good       Medium

When to Use Beam Search

Good For

  • Machine translation
  • Summarization
  • Image captioning
  • Any task with a "correct" output

Not Ideal For

  • Creative text generation (use sampling)
  • Chatbots (can be repetitive)
  • Open-ended generation

Implementation Tips

1. Use Log Probabilities

# Avoid numerical underflow
log_prob = sum(log(p_i))  # Good
prob = prod(p_i)          # Underflows!
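The underflow is easy to demonstrate: 500 token probabilities of 0.1 each give a true product of 1e-500, far below what a float64 can represent, while the log-space sum stays perfectly usable.

```python
import math

probs = [0.1] * 500

product = math.prod(probs)                    # underflows to exactly 0.0
log_sum = sum(math.log(p) for p in probs)     # 500 * ln(0.1)

print(product)  # 0.0
print(round(log_sum, 1))  # -1151.3
```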

2. Batch Beam Search

# Process all beams in one forward pass
# (sketch: assumes each beam sequence is already a tensor)
beam_batch = torch.stack([seq for seq, _ in beams])
all_probs = model(input_seq.expand(beam_width, -1), beam_batch)

3. Early Stopping

# Stop when top beam is complete and no other can beat it
if best_complete_score > max(incomplete_scores):
    break

Practical Example: Hugging Face

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

input_ids = tokenizer.encode("The future of AI is", return_tensors='pt')

# Beam search
outputs = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,          # Beam width
    length_penalty=0.7,   # Length normalization
    early_stopping=True,
    no_repeat_ngram_size=2  # Prevent repetition
)

print(tokenizer.decode(outputs[0]))

Key Takeaways

  1. Beam search maintains k candidates at each step
  2. Beam width balances quality vs speed (k=3-5 typical)
  3. Length normalization prevents bias toward short sequences
  4. Use log probabilities to avoid numerical underflow
  5. For creative tasks, consider sampling instead
  6. Diverse beam search helps when you need variety