Softmax Function
Softmax converts a vector of raw scores (logits) into a probability distribution. It's the standard output layer for multi-class classification.
Definition
softmax(zᵢ) = exp(zᵢ) / Σⱼ exp(zⱼ)
For a vector z = [z₁, z₂, z₃]:
Input: [2.0, 1.0, 0.1]
↓ exp
[7.39, 2.72, 1.11]
↓ normalize (sum = 11.21)
Output: [0.66, 0.24, 0.10]
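The arithmetic above can be reproduced directly (a quick NumPy check):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
exp_z = np.exp(logits)        # approximately [7.389, 2.718, 1.105]
probs = exp_z / exp_z.sum()   # divide by the sum (~11.21)
print(probs.round(2))         # approximately [0.66, 0.24, 0.10]
```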
Properties
1. Outputs Sum to 1
Σᵢ softmax(zᵢ) = 1
Valid probability distribution.
2. All Outputs Positive
0 < softmax(zᵢ) < 1
Exponentials are always positive.
3. Preserves Ordering
If z₁ > z₂, then softmax(z₁) > softmax(z₂)
Larger logits → larger probabilities.
4. Sensitive to Scale
softmax([1, 2, 3]) = [0.09, 0.24, 0.67]
softmax([10, 20, 30]) = [0.00, 0.00, 1.00] # More extreme
This is the basis for temperature scaling.
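The scale sensitivity is easy to verify numerically (NumPy; the max-subtraction inside the helper is the standard stability trick covered below):

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)   # shift for numerical stability; result is unchanged
    e = np.exp(z)
    return e / e.sum()

print(softmax(np.array([1.0, 2.0, 3.0])))     # ~[0.09, 0.24, 0.67]
print(softmax(np.array([10.0, 20.0, 30.0])))  # ~[0.00, 0.00, 1.00], much sharper
```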
Softmax vs Sigmoid
Sigmoid (Binary)
σ(z) = 1 / (1 + exp(-z))
Output: Single probability for one class.
Softmax (Multi-class)
softmax(z)ᵢ = exp(zᵢ) / Σⱼ exp(zⱼ)
Output: Probability for each class.
When to Use Each
| Scenario | Activation | Output |
|---|---|---|
| Binary classification | Sigmoid | P(class=1) |
| Multi-class (exclusive) | Softmax | P(class=i) for each i |
| Multi-label | Sigmoid per class | Independent probabilities |
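In the binary case the two activations coincide: a 2-way softmax over logits [z₁, z₂] assigns class 1 exactly sigmoid(z₁ − z₂). A quick NumPy check:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

z1, z2 = 2.0, 0.5
p_softmax = softmax(np.array([z1, z2]))[0]   # P(class 1) from 2-way softmax
p_sigmoid = sigmoid(z1 - z2)                 # P(class 1) from sigmoid
# both give the same probability
```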
Temperature Scaling
Definition
softmax(zᵢ/T) = exp(zᵢ/T) / Σⱼ exp(zⱼ/T)
Effect of Temperature
| T | Effect | Distribution |
|---|---|---|
| T → 0 | Argmax (winner-take-all) | [0, 0, 1, 0] |
| T = 1 | Standard softmax | [0.1, 0.2, 0.5, 0.2] |
| T → ∞ | Uniform | [0.25, 0.25, 0.25, 0.25] |
Used in knowledge distillation and LLM sampling.
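The three regimes in the table are easy to reproduce (`softmax_t` below is an illustrative helper, not a library function):

```python
import numpy as np

def softmax_t(z, T):
    z = np.asarray(z, dtype=float) / T   # divide logits by the temperature
    z = z - z.max()                      # stabilize before exponentiating
    e = np.exp(z)
    return e / e.sum()

logits = [1.0, 2.0, 5.0, 2.0]
print(softmax_t(logits, 0.1))    # near one-hot: winner-take-all
print(softmax_t(logits, 1.0))    # standard softmax
print(softmax_t(logits, 100.0))  # near uniform
```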
Numerical Stability
The Problem
exp(1000) = overflow → Inf, so the ratio becomes Inf/Inf = NaN
exp(-1000) = underflow → 0; if every logit is very negative, the sum underflows to 0 and the division fails
The Solution
Subtract max before exponentiating:
softmax(z) = softmax(z - max(z))
def stable_softmax(z):
    z_shifted = z - np.max(z)  # Prevent overflow
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z)
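The shift is safe because softmax is invariant to adding a constant to every logit, and it guarantees the largest exponent is exp(0) = 1. A quick demonstration of naive vs. stable (NumPy):

```python
import numpy as np

def stable_softmax(z):
    z_shifted = z - np.max(z)   # largest exponent becomes exp(0) = 1
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z)

big = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(big) / np.sum(np.exp(big))   # Inf / Inf -> NaN
print(naive)                # [nan nan nan]
print(stable_softmax(big))  # same answer as softmax([0, 1, 2])
```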
Cross-Entropy Loss
Softmax is paired with cross-entropy loss:
Loss = -Σᵢ yᵢ log(softmax(zᵢ))
For one-hot y (only one class is 1):
Loss = -log(softmax(z_correct))
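With a one-hot target the loss is just the negative log-probability of the true class, so a confident correct prediction gives a loss near 0 while a confident wrong one blows up. A minimal NumPy sketch (the helper name is illustrative):

```python
import numpy as np

def cross_entropy(logits, true_class):
    z = logits - np.max(logits)                 # stable softmax
    log_probs = z - np.log(np.sum(np.exp(z)))   # log-softmax
    return -log_probs[true_class]               # -log P(correct class)

logits = np.array([2.0, 1.0, 0.1])
print(cross_entropy(logits, 0))  # ~0.42: the correct class is likely
print(cross_entropy(logits, 2))  # ~2.32: the correct class is unlikely
```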
Combined: LogSoftmax + NLLLoss
# Efficient and stable
nn.CrossEntropyLoss()(logits, labels)
# Equivalent to:
loss = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
Gradient
The gradient of softmax:
∂softmax(zᵢ)/∂zⱼ = softmax(zᵢ)(δᵢⱼ - softmax(zⱼ))
With cross-entropy, simplifies to:
∂Loss/∂zᵢ = softmax(zᵢ) - yᵢ
Beautiful and simple gradient!
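The combined gradient can be confirmed numerically with a central finite difference (a small NumPy sketch; the helper names are illustrative):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def loss(z, true_class):
    return -np.log(softmax(z)[true_class])   # cross-entropy, one-hot target

z = np.array([2.0, 1.0, 0.1])
true_class = 0
y = np.eye(3)[true_class]                    # one-hot target vector

analytic = softmax(z) - y                    # the closed-form gradient
numeric = np.zeros_like(z)
eps = 1e-6
for i in range(3):
    dz = np.zeros_like(z)
    dz[i] = eps
    numeric[i] = (loss(z + dz, true_class) - loss(z - dz, true_class)) / (2 * eps)
# analytic and numeric agree to high precision
```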
Implementation
NumPy
import numpy as np
def softmax(z):
    z = z - np.max(z, axis=-1, keepdims=True)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
# Example
logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(probs) # [0.659, 0.242, 0.099]
PyTorch
import torch
import torch.nn.functional as F
logits = torch.tensor([2.0, 1.0, 0.1])
probs = F.softmax(logits, dim=-1)
# For classification
loss = F.cross_entropy(logits.unsqueeze(0), labels)
Common Use Cases
Multi-class Classification
class Classifier(nn.Module):
    def forward(self, x):
        logits = self.final_layer(x)  # [batch, num_classes]
        return logits  # Let CrossEntropyLoss handle softmax
Attention Mechanism
def attention(Q, K, V):
    scores = Q @ K.transpose(-2, -1) / K.size(-1) ** 0.5  # scale by sqrt(d_k)
    weights = F.softmax(scores, dim=-1)  # Attention weights
    return weights @ V
Mixture Models
def mixture_of_experts(x, gating_logits, experts):
    weights = F.softmax(gating_logits, dim=-1)
    outputs = [expert(x) for expert in experts]
    return sum(w * out for w, out in zip(weights, outputs))
Alternatives
Sparsemax
Can output exact zeros (sparse attention):
sparsemax([1, 2, 2.5]) → [0, 0.25, 0.75] # Can have exact zeros
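Sparsemax (Martins & Astudillo, 2016) projects the logits onto the probability simplex: sort, find the support size, subtract a threshold, clip at zero. A NumPy sketch:

```python
import numpy as np

def sparsemax(z):
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]              # descending order
    k = np.arange(1, len(z) + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum      # which coordinates stay nonzero
    k_max = k[support][-1]                   # size of the support
    tau = (cumsum[k_max - 1] - 1) / k_max    # threshold
    return np.maximum(z - tau, 0.0)          # clip: exact zeros appear here

print(sparsemax([1.0, 2.0, 2.5]))  # [0.   0.25 0.75] -- note the exact zero
```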
Gumbel-Softmax
Differentiable sampling:
gumbel_softmax(logits) ≈ one-hot sample
Used for discrete choices in training.
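A minimal sketch of the forward pass, assuming Gumbel(0, 1) noise drawn as −log(−log(U)) with U uniform; the temperature `tau` and the seed are arbitrary choices here:

```python
import numpy as np

def gumbel_softmax(logits, tau=0.5, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    u = rng.uniform(size=len(logits))
    gumbel = -np.log(-np.log(u))     # Gumbel(0, 1) noise
    z = (logits + gumbel) / tau      # perturb, then sharpen with low tau
    e = np.exp(z - np.max(z))
    return e / e.sum()               # near one-hot for small tau

sample = gumbel_softmax(np.array([1.0, 2.0, 3.0]), tau=0.1)
# sample is close to a one-hot vector but differentiable w.r.t. the logits
```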
Hardmax
hardmax = one-hot(argmax(z))
Not differentiable, used at inference.
Key Takeaways
- Softmax converts logits to probabilities summing to 1
- Use with cross-entropy for multi-class classification
- Temperature controls sharpness of distribution
- Always use numerically stable implementation
- Sigmoid for binary/multi-label, softmax for multi-class
- Don't apply softmax before CrossEntropyLoss in PyTorch