# Mixture of Experts

Mixture of Experts (MoE) is an architecture that uses multiple specialized sub-networks (experts) and a gating mechanism to route each input to the most relevant experts, enabling massive parameter counts with manageable computation.
## Core Concept

```
Input → [Gating Network] → Select top-k experts
                 ↓
           ┌─────┼─────┐
           ↓     ↓     ↓
      [Expert1][Expert2]...[ExpertN]
           │     │
           ↓     ↓
      Weighted combination → Output
```
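The routing arithmetic above can be sketched without any framework: score all experts, keep the top-k, softmax only those scores, and mix the selected experts' outputs. All numbers below are toy values for illustration:

```python
import math

def top_k_route(scores, k=2):
    """Pick the k highest-scoring experts and softmax their scores."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    exps = [math.exp(scores[i]) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]

# Gating logits for four experts on one input; only the top two run
scores = [2.0, 0.5, 1.0, -1.0]
routes = top_k_route(scores, k=2)   # experts 0 and 2, weights ~0.73 and ~0.27

# Toy scalar "expert outputs"; real experts return vectors
expert_out = {0: 10.0, 2: 4.0}
output = sum(w * expert_out[i] for i, w in routes)
```

Note that the weights of the selected experts always sum to 1, since the softmax is taken only over the surviving top-k scores.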
## How It Works

### 1. Expert Networks

Each expert is typically a feed-forward network:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)
```
### 2. Gating Network (Router)

Decides which experts to use for each input:
```python
class Router(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        # Get expert scores
        logits = self.gate(x)  # [batch, num_experts]
        # Select top-k experts and renormalize their scores
        top_k_values, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_values, dim=-1)
        return top_k_weights, top_k_indices
```
### 3. MoE Layer
```python
class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim)
            for _ in range(num_experts)
        ])
        self.router = Router(input_dim, num_experts, top_k)

    def forward(self, x):
        # x: [batch, seq, dim]
        batch_size, seq_len, dim = x.shape
        x_flat = x.view(-1, dim)  # [batch*seq, dim]

        # Get routing weights; both tensors are [batch*seq, top_k]
        weights, indices = self.router(x_flat)

        # Compute expert outputs (sparse: each expert only sees its tokens)
        output = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            # Find tokens that routed to this expert in any top-k slot
            mask = (indices == i).any(dim=-1)
            if mask.any():
                expert_input = x_flat[mask]
                expert_output = expert(expert_input)
                # Routing weight of expert i for each selected token
                # (an expert appears at most once in a token's top-k)
                expert_weight = weights[mask][indices[mask] == i]
                output[mask] += expert_weight.unsqueeze(-1) * expert_output
        return output.view(batch_size, seq_len, dim)
```
## Sparse vs Dense

| Aspect | Dense Model | MoE Model |
|---|---|---|
| Parameters | All active | Subset active |
| FLOPs | Proportional to total params | Proportional to *active* params |
| Example | 7B params, 7B active | 47B params, ~13B active |

```
Dense 7B:   ████████████████████ (all params used)
MoE 8×7B:   ████░░░░░░░░░░░░████ (2 of 8 experts)
            47B params, ~13B active
```

(An "8×7B" model totals ~47B rather than 8×7=56B parameters because attention layers and embeddings are shared across experts; only the FFN weights are replicated.)
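As a rough check of why active compute stays low, count just the FFN weights per layer. This sketch assumes the two-matrix `Expert` FFN from above and illustrative, loosely Mixtral-sized dimensions; real models often use gated FFNs with a third matrix:

```python
def moe_ffn_params(d_model, d_ff, num_experts, top_k):
    """Per-layer FFN weight counts (biases ignored)."""
    per_expert = 2 * d_model * d_ff      # two linear layers
    total = num_experts * per_expert     # stored in memory
    active = top_k * per_expert          # used per token
    return total, active

# Illustrative dimensions
total, active = moe_ffn_params(d_model=4096, d_ff=14336,
                               num_experts=8, top_k=2)
# total ≈ 0.94B weights per layer, active ≈ 0.23B — a 4x reduction
```

The ratio `total / active` is simply `num_experts / top_k`, which is why top-2 routing over 8 experts cuts FFN compute by 4x.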
## Load Balancing

### The Problem

Without balancing, a few experts get overloaded while the rest starve:

```
Expert 1: ████████████████████ (80% of tokens)
Expert 2: ████ (10%)
Expert 3: ██ (5%)
Expert 4: ██ (5%)
```
### Auxiliary Loss

```python
def load_balancing_loss(router_probs, expert_indices, num_experts):
    # router_probs: [num_tokens, num_experts] - full softmax over ALL
    #   experts (taken from the router's logits, not just the top-k)
    # expert_indices: [num_tokens, top_k] - selected experts

    # Fraction of tokens routed to each expert
    expert_counts = torch.zeros(num_experts, device=router_probs.device)
    for i in range(num_experts):
        expert_counts[i] = (expert_indices == i).float().sum()
    tokens_per_expert = expert_counts / expert_counts.sum()

    # Average routing probability per expert
    avg_prob_per_expert = router_probs.mean(dim=0)

    # The dot product is minimized when both distributions are uniform;
    # scaling by num_experts makes that minimum equal to 1
    aux_loss = num_experts * (tokens_per_expert * avg_prob_per_expert).sum()
    return aux_loss
```
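A quick sanity check of the loss's behavior, using a pure-Python version of the same formula: perfectly uniform routing over 4 experts gives the minimum value of 1, while a skewed distribution scores higher:

```python
def aux_loss(tokens_per_expert, avg_prob_per_expert):
    """num_experts * sum(f_i * P_i), as in the PyTorch version above."""
    n = len(tokens_per_expert)
    return n * sum(f * p for f, p in zip(tokens_per_expert, avg_prob_per_expert))

balanced   = aux_loss([0.25] * 4, [0.25] * 4)                       # 1.0
imbalanced = aux_loss([0.7, 0.1, 0.1, 0.1], [0.7, 0.1, 0.1, 0.1])   # 2.08
```

Minimizing this term therefore pushes the router toward spreading tokens evenly.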
### Capacity Factor

Limit tokens per expert to prevent overflow:

```python
capacity = (num_tokens / num_experts) * capacity_factor
# capacity_factor = 1.0:  each expert handles exactly an equal share
# capacity_factor = 1.25: 25% buffer for imbalance
```
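With concrete (hypothetical) numbers:

```python
import math

def expert_capacity(num_tokens, num_experts, capacity_factor=1.25):
    """Max tokens one expert accepts; overflow tokens are dropped or
    re-routed, depending on the implementation."""
    return math.ceil(num_tokens / num_experts * capacity_factor)

# 4096 tokens in a batch, 8 experts, 25% slack -> 640 tokens per expert
cap = expert_capacity(4096, 8, 1.25)
```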
## MoE in Transformers

Replace the FFN layers with MoE layers:

```
Standard Transformer Block:
  → Attention
  → FFN (dense)

MoE Transformer Block:
  → Attention
  → MoE Layer (sparse FFN)
```
```python
class MoETransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, top_k):
        super().__init__()
        # MultiHeadAttention: any standard attention module, assumed
        # defined elsewhere
        self.attention = MultiHeadAttention(dim, num_heads)
        self.moe = MoELayer(dim, dim * 4, num_experts, top_k)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Pre-norm residual connections
        x = x + self.attention(self.norm1(x))
        x = x + self.moe(self.norm2(x))
        return x
```
## Notable MoE Models

| Model | Total Params | Active Params | Experts |
|---|---|---|---|
| Mixtral 8x7B | 47B | ~13B | 8 |
| GPT-4 (rumored) | ~1.8T | ~220B | 16 |
| Switch Transformer | 1.6T | ~100B | 2048 |
| GShard | 600B | ~7B | 2048 |

(Figures are approximate; the GPT-4 row is unconfirmed.)
## Training Considerations

### 1. Expert Parallelism

```
GPU 0: Experts 1-2 + Router
GPU 1: Experts 3-4
GPU 2: Experts 5-6
GPU 3: Experts 7-8
```

Tokens are routed between GPUs based on expert selection.
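With evenly sharded experts, mapping an expert index to its owning GPU is simple integer arithmetic. A sketch (zero-indexed, and assuming `num_experts` divides evenly by the number of GPUs):

```python
def expert_to_rank(expert_id, num_experts, world_size):
    """Which GPU rank owns a given expert when experts are sharded evenly."""
    experts_per_rank = num_experts // world_size
    return expert_id // experts_per_rank

# 8 experts over 4 GPUs: experts 0-1 on rank 0, 2-3 on rank 1, ...
owners = [expert_to_rank(e, 8, 4) for e in range(8)]  # [0, 0, 1, 1, 2, 2, 3, 3]
```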
### 2. All-to-All Communication

```python
# Pseudocode: tokens must reach their assigned experts across GPUs
def expert_parallel_forward(tokens, router, experts):
    # 1. Route tokens to experts
    expert_assignments = router(tokens)
    # 2. All-to-all: send each token to the GPU that owns its expert
    dispatched = all_to_all(tokens, expert_assignments)
    # 3. Process with this GPU's local experts
    outputs = experts(dispatched)
    # 4. All-to-all: return outputs to the tokens' original GPUs
    return all_to_all(outputs, expert_assignments, reverse=True)
```
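The dispatch/combine bookkeeping around the all-to-all can be illustrated single-process: group tokens by expert, run each group, then scatter results back to their original positions. `dispatch` and `combine` below are illustrative helpers, not a real communication API:

```python
def dispatch(tokens, assignments, num_experts):
    """Group tokens by assigned expert (single-process analogue of
    the forward all-to-all), remembering original positions."""
    buckets = [[] for _ in range(num_experts)]
    positions = [[] for _ in range(num_experts)]
    for pos, (tok, e) in enumerate(zip(tokens, assignments)):
        buckets[e].append(tok)
        positions[e].append(pos)
    return buckets, positions

def combine(outputs_per_expert, positions, num_tokens):
    """Scatter expert outputs back to the original token order
    (analogue of the reverse all-to-all)."""
    out = [None] * num_tokens
    for bucket, pos_list in zip(outputs_per_expert, positions):
        for o, p in zip(bucket, pos_list):
            out[p] = o
    return out

tokens = ["t0", "t1", "t2", "t3"]
assignments = [1, 0, 1, 0]          # top-1 routing for simplicity
buckets, positions = dispatch(tokens, assignments, num_experts=2)
# buckets == [["t1", "t3"], ["t0", "t2"]]

processed = [[t.upper() for t in b] for b in buckets]  # stand-in for expert compute
restored = combine(processed, positions, len(tokens))
# restored == ["T0", "T1", "T2", "T3"] — original order preserved
```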
### 3. Stability Tricks

```python
# Add noise to the gating logits during training for exploration
def noisy_top_k_gating(logits, top_k, training, noise_std=0.1):
    if training:
        noise = torch.randn_like(logits) * noise_std
        logits = logits + noise
    return torch.topk(logits, k=top_k, dim=-1)
```
## Advantages
- Scale capacity without proportional compute increase
- Specialization: experts learn different patterns
- Efficiency: only subset of parameters active per input
- Flexibility: can add experts without full retraining
## Challenges
- Load balancing is crucial and tricky
- Communication overhead in distributed training
- Expert collapse: some experts may never be used
- Memory: all experts must fit in memory
- Inference complexity: routing adds overhead
## Key Takeaways
- MoE uses multiple experts with a router to select active ones
- Enables massive parameter counts with manageable compute
- Typical setup: top-2 routing among 8-2048 experts
- Load balancing loss prevents expert collapse
- Expert parallelism enables training across GPUs
- Powers large models like Mixtral and likely GPT-4