Learn about Mixture of Experts (MoE) architecture, which scales model capacity while keeping computation manageable through sparse activation.

Mixture of Experts

Mixture of Experts (MoE) is an architecture that uses multiple specialized sub-networks (experts) and a gating mechanism to route inputs to the most relevant experts, enabling massive parameter counts with manageable computation.

Core Concept

Input → [Gating Network] → Select top-k experts
              ↓
        ┌─────┼─────┐
        ↓     ↓     ↓
    [Expert1][Expert2]...[ExpertN]
        │     │
        ↓     ↓
    Weighted combination → Output
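The flow above can be sketched in a few lines of PyTorch (a toy illustration with made-up dimensions; each expert is reduced to a single linear layer for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, num_experts, top_k = 8, 4, 2
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
gate = nn.Linear(dim, num_experts)

x = torch.randn(1, dim)                    # one token
logits = gate(x)                           # [1, num_experts]
weights, idx = torch.topk(logits, top_k)   # pick the top-2 experts
weights = F.softmax(weights, dim=-1)       # normalize over the chosen 2

# Weighted combination of the selected experts' outputs
out = sum(weights[0, j] * experts[int(idx[0, j])](x) for j in range(top_k))
print(out.shape)  # torch.Size([1, 8])
```

Only the two selected experts run; the other two contribute nothing to this token's output.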

How It Works

1. Expert Networks

Each expert is typically a feed-forward network:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

2. Gating Network (Router)

Decides which experts to use for each input:

class Router(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k
    
    def forward(self, x):
        # Get expert scores
        logits = self.gate(x)  # [batch, num_experts]
        
        # Select top-k experts and renormalize over just those k
        top_k_values, top_k_indices = torch.topk(logits, self.top_k)
        top_k_weights = F.softmax(top_k_values, dim=-1)
        
        return top_k_weights, top_k_indices
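On a toy batch the routing math looks like this (standalone sketch, no class needed). Note the design choice: the softmax is taken over only the k selected logits, as Mixtral does, rather than over all experts before selection:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4)                      # 3 tokens, 4 experts
top_k_values, top_k_indices = torch.topk(logits, k=2)
top_k_weights = F.softmax(top_k_values, dim=-1)

print(top_k_indices.shape)    # torch.Size([3, 2]) — two expert ids per token
print(top_k_weights.sum(-1))  # each row sums to 1
```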

3. MoE Layer

class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, input_dim)
            for _ in range(num_experts)
        ])
        self.router = Router(input_dim, num_experts, top_k)

    def forward(self, x):
        # x: [batch, seq, dim]
        batch_size, seq_len, dim = x.shape
        x_flat = x.view(-1, dim)  # [batch*seq, dim]

        # Get routing weights and expert ids: both [num_tokens, top_k]
        weights, indices = self.router(x_flat)

        # Compute expert outputs (sparse): each expert sees only its tokens
        output = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            # Token rows and top-k slots routed to this expert
            token_idx, slot_idx = (indices == i).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(x_flat[token_idx])

            # Weight each token's output by its routing probability
            expert_weight = weights[token_idx, slot_idx].unsqueeze(-1)
            output[token_idx] += expert_weight * expert_output

        return output.view(batch_size, seq_len, dim)
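Turning the `[tokens, top_k]` routing indices into per-expert token lists, shown in isolation on hand-picked values:

```python
import torch

# 4 tokens, top-2 routing among 3 experts
indices = torch.tensor([[0, 2],
                        [1, 0],
                        [2, 1],
                        [0, 1]])

# Which (token, slot) pairs chose expert 0?
token_idx, slot_idx = (indices == 0).nonzero(as_tuple=True)
print(token_idx.tolist())  # [0, 1, 3] — tokens routed to expert 0
print(slot_idx.tolist())   # [0, 1, 0] — which top-k slot selected it
```

The slot index is what lets the layer look up the matching routing weight for each token.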

Sparse vs Dense

Aspect       Dense Model               MoE Model
Parameters   All active                Subset active
FLOPs        Proportional to total     Proportional to active params
Example      7B params, 7B active      47B params, ~13B active

Dense 7B:    ████████████████████  (all params used)

MoE 8×7B:    ████░░░░░░░░░░░░████  (2 of 8 experts)
             47B params, ~13B active
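The arithmetic behind those numbers, as a sketch. Here `shared` covers attention, embeddings, and other non-expert weights; the values are illustrative round figures, not exact Mixtral counts:

```python
def moe_param_counts(shared, params_per_expert, num_experts, top_k):
    """Total vs active parameter counts for a top-k MoE (in billions)."""
    total = shared + num_experts * params_per_expert
    active = shared + top_k * params_per_expert
    return total, active

# Roughly Mixtral-shaped: ~2B shared, 8 experts of ~5.6B each, top-2 routing
total, active = moe_param_counts(2.0, 5.6, num_experts=8, top_k=2)
print(round(total, 1), round(active, 1))  # 46.8 13.2
```

The total grows with the number of experts while the active count grows only with k, which is the whole point of the architecture.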

Load Balancing

The Problem

Without balancing, some experts get overloaded:

Expert 1: ████████████████████ (80% of tokens)
Expert 2: ████ (10%)
Expert 3: ██ (5%)
Expert 4: ██ (5%)

Auxiliary Loss

def load_balancing_loss(router_probs, expert_indices, num_experts):
    # router_probs: [num_tokens, num_experts] - full softmax over ALL experts
    #               (not just the top-k renormalized weights)
    # expert_indices: [num_tokens, top_k] - selected experts per token

    # f_i: fraction of routing assignments that went to each expert
    expert_counts = torch.bincount(
        expert_indices.flatten(), minlength=num_experts
    ).float()
    tokens_per_expert = expert_counts / expert_counts.sum()

    # P_i: average routing probability assigned to each expert
    avg_prob_per_expert = router_probs.mean(dim=0)

    # Switch-Transformer-style auxiliary loss: equals 1.0 when routing is
    # perfectly uniform, grows as routing concentrates on few experts
    aux_loss = num_experts * (tokens_per_expert * avg_prob_per_expert).sum()

    return aux_loss
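A quick sanity check: with perfectly uniform probabilities and a perfectly balanced assignment, the auxiliary loss evaluates to exactly 1.0 (the computation is inlined here so the snippet stands alone):

```python
import torch

num_tokens, num_experts = 8, 4
# Uniform router probabilities and a perfectly balanced top-1 assignment
router_probs = torch.full((num_tokens, num_experts), 1.0 / num_experts)
expert_indices = torch.arange(num_tokens).remainder(num_experts).unsqueeze(-1)

counts = torch.bincount(expert_indices.flatten(), minlength=num_experts).float()
f = counts / counts.sum()          # [0.25, 0.25, 0.25, 0.25]
P = router_probs.mean(dim=0)       # [0.25, 0.25, 0.25, 0.25]
aux = num_experts * (f * P).sum()
print(aux.item())  # 1.0
```

Any deviation from uniformity pushes the value above 1, which is what the gradient penalizes.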

Capacity Factor

Limit tokens per expert to prevent overflow:

capacity = (num_tokens / num_experts) * capacity_factor

# capacity_factor = 1.0: each expert handles equal share
# capacity_factor = 1.25: 25% buffer for imbalance
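Putting numbers on the formula (a sketch of how overflow is handled; real implementations typically pass dropped tokens through the residual connection unchanged rather than zeroing them):

```python
num_tokens, num_experts, capacity_factor = 64, 8, 1.25
capacity = int((num_tokens / num_experts) * capacity_factor)
print(capacity)  # 10 — each expert accepts at most 10 tokens

# If 15 tokens pick one expert, only the first `capacity` are processed;
# the rest overflow and fall back to the residual path
assigned = 15
dropped = max(0, assigned - capacity)
print(dropped)  # 5
```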

MoE in Transformers

Replace FFN layers with MoE:

Standard Transformer Block:
  → Attention
  → FFN (dense)

MoE Transformer Block:
  → Attention
  → MoE Layer (sparse FFN)

class MoETransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, top_k):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads)
        self.moe = MoELayer(dim, dim * 4, num_experts, top_k)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    
    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.moe(self.norm2(x))
        return x
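The pre-norm residual pattern in isolation (a standalone sketch: `nn.MultiheadAttention` and a plain dense FFN stand in for the block's `MultiHeadAttention` and `MoELayer` so the snippet runs on its own):

```python
import torch
import torch.nn as nn

dim, num_heads = 16, 4
attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.ReLU(), nn.Linear(dim * 4, dim))
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

x = torch.randn(2, 5, dim)   # [batch, seq, dim]
h = norm1(x)
x = x + attn(h, h, h)[0]     # pre-norm residual attention
x = x + ffn(norm2(x))        # pre-norm residual FFN — the MoE layer slots in here
print(x.shape)               # torch.Size([2, 5, 16])
```

Swapping the dense `ffn` for a sparse MoE layer leaves the rest of the block untouched, which is why MoE retrofits so cleanly onto the transformer.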

Notable MoE Models

Model                Total Params   Active Params   Experts
Mixtral 8x7B         47B            ~13B            8
GPT-4 (rumored)      ~1.8T          ~220B           8×16
Switch Transformer   1.6T           ~100B           2048
GShard               600B           ~7B             2048

Training Considerations

1. Expert Parallelism

GPU 0: Experts 1-2 + Router
GPU 1: Experts 3-4
GPU 2: Experts 5-6
GPU 3: Experts 7-8

Tokens are routed between GPUs based on each token's expert selection.
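With this layout, finding an expert's owning GPU is just integer division (a sketch with 0-indexed expert ids; real frameworks such as DeepSpeed-MoE manage this placement for you):

```python
num_experts, num_gpus = 8, 4
experts_per_gpu = num_experts // num_gpus  # 2 experts per GPU

def owning_gpu(expert_id):
    # Experts are sharded contiguously: 0-1 on GPU 0, 2-3 on GPU 1, ...
    return expert_id // experts_per_gpu

print([owning_gpu(e) for e in range(num_experts)])  # [0, 0, 1, 1, 2, 2, 3, 3]
```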

2. All-to-All Communication

# Tokens need to reach their assigned experts across GPUs.
# Schematic pseudocode: all_to_all and local_experts stand in for
# the framework's collective-communication primitives.
def expert_parallel_forward(tokens, router, experts):
    # 1. Route tokens to experts
    expert_assignments = router(tokens)
    
    # 2. All-to-all: send tokens to expert-owning GPUs
    dispatched = all_to_all(tokens, expert_assignments)
    
    # 3. Process with local experts
    outputs = local_experts(dispatched)
    
    # 4. All-to-all: return outputs
    return all_to_all(outputs, reverse=True)

3. Stability Tricks

# Add noise during training for exploration. Written as a free function,
# so training mode and top_k are passed in explicitly:
def noisy_top_k_gating(logits, top_k, training, noise_std=0.1):
    if training:
        noise = torch.randn_like(logits) * noise_std
        logits = logits + noise
    return torch.topk(logits, k=top_k)
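A quick check of the gating behavior (the function is inlined here with `training` and `top_k` passed explicitly, so the snippet runs standalone):

```python
import torch

def noisy_top_k_gating(logits, top_k, training, noise_std=0.1):
    # Gaussian noise on the logits during training encourages exploration
    if training:
        logits = logits + torch.randn_like(logits) * noise_std
    return torch.topk(logits, k=top_k)

torch.manual_seed(0)
logits = torch.randn(4, 8)  # 4 tokens, 8 experts

values, eval_idx = noisy_top_k_gating(logits, top_k=2, training=False)
print(eval_idx.shape)  # torch.Size([4, 2])

# With training=False the selection is deterministic
_, again = noisy_top_k_gating(logits, top_k=2, training=False)
print(torch.equal(eval_idx, again))  # True
```

With `training=True`, tokens near a routing boundary occasionally flip experts, which keeps rarely-chosen experts receiving gradient signal.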

Advantages

  1. Scale capacity without proportional compute increase
  2. Specialization: experts learn different patterns
  3. Efficiency: only subset of parameters active per input
  4. Flexibility: can add experts without full retraining

Challenges

  1. Load balancing is crucial and tricky
  2. Communication overhead in distributed training
  3. Expert collapse: some experts may never be used
  4. Memory: all experts must fit in memory
  5. Inference complexity: routing adds overhead

Key Takeaways

  1. MoE uses multiple experts with a router to select active ones
  2. Enables massive parameter counts with manageable compute
  3. Typical setup: top-2 routing among 8-2048 experts
  4. Load balancing loss prevents expert collapse
  5. Expert parallelism enables training across GPUs
  6. Powers large models like Mixtral and likely GPT-4