Loss Functions
Loss functions measure how wrong your model's predictions are. They are the objective that training minimizes, so choosing the right loss function is crucial for model performance.
The Role of Loss Functions
Input → Model → Prediction → Loss(prediction, target) → Gradients → Update
The loss function:
- Quantifies prediction error
- Provides gradients for optimization
- Defines what "good" means for your task
Regression Loss Functions
Mean Squared Error (MSE) / L2 Loss
MSE = (1/n) × Σ(yᵢ - ŷᵢ)²
Properties:
- Penalizes large errors heavily (squared)
- Differentiable everywhere
- Sensitive to outliers
Use when:
- Errors are normally distributed
- Large errors are especially bad
- Standard regression tasks
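In PyTorch this is `nn.MSELoss`. A minimal sketch with made-up values:

```python
import torch
import torch.nn as nn

# Illustrative predictions and targets
pred = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])

mse = nn.MSELoss()
loss = mse(pred, target)
# (0.5^2 + 0.5^2 + 0^2) / 3 = 0.5/3 ≈ 0.1667
```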
Mean Absolute Error (MAE) / L1 Loss
MAE = (1/n) × Σ|yᵢ - ŷᵢ|
Properties:
- Linear penalty (not squared)
- More robust to outliers
- Not differentiable at zero
Use when:
- Data has outliers
- All errors equally important
- Median prediction desired
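The outlier sensitivity is easy to see side by side (`nn.L1Loss` is PyTorch's MAE); the values below are illustrative:

```python
import torch
import torch.nn as nn

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 13.0])  # last point is an outlier

mse = nn.MSELoss()(pred, target)  # (0 + 0 + 100) / 3 ≈ 33.3 — dominated by the outlier
mae = nn.L1Loss()(pred, target)   # (0 + 0 + 10) / 3 ≈ 3.3
```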
Huber Loss (Smooth L1)
Huber = {
0.5 × (y - ŷ)² if |y - ŷ| ≤ δ
δ × |y - ŷ| - 0.5 × δ² if |y - ŷ| > δ
}
Properties:
- MSE for small errors (smooth gradients)
- MAE for large errors (robust to outliers)
- Best of both worlds
Use when:
- Want robustness but smooth optimization
- Object detection (bounding box regression)
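PyTorch provides `nn.HuberLoss` (and `nn.SmoothL1Loss`, which with beta=1 matches Huber with δ=1). A quick check of both branches of the piecewise definition:

```python
import torch
import torch.nn as nn

pred = torch.tensor([0.5, 3.0])
target = torch.tensor([0.0, 0.0])

huber = nn.HuberLoss(delta=1.0)
loss = huber(pred, target)
# small error: 0.5 * 0.5^2 = 0.125; large error: 1.0 * 3.0 - 0.5 = 2.5
# mean = (0.125 + 2.5) / 2 = 1.3125
```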
Log-Cosh Loss
Log-Cosh = (1/n) × Σ log(cosh(yᵢ - ŷᵢ))
Properties:
- Approximately MSE for small errors
- Approximately MAE for large errors
- Twice differentiable (good for Newton methods)
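There is no built-in log-cosh loss in PyTorch, but it is a few lines to write. Computing `log(cosh(x))` directly overflows for large `x`; one common workaround uses the identity log(cosh(x)) = x + softplus(-2x) - log(2):

```python
import math
import torch
import torch.nn.functional as F

def log_cosh_loss(pred, target):
    # log(cosh(x)) = x + softplus(-2x) - log(2), stable for large |x|
    diff = pred - target
    return torch.mean(diff + F.softplus(-2.0 * diff) - math.log(2.0))
```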
Classification Loss Functions
Binary Cross-Entropy
BCE = -(1/n) × Σ[yᵢ log(ŷᵢ) + (1-yᵢ) log(1-ŷᵢ)]
Use with: Sigmoid output, binary classification
nn.BCEWithLogitsLoss() # More stable than BCE + Sigmoid
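Note that `BCEWithLogitsLoss` takes raw logits, not sigmoid outputs, and float targets. A sketch with illustrative values:

```python
import torch
import torch.nn as nn

logits = torch.tensor([2.0, -1.0])    # raw scores, before sigmoid
targets = torch.tensor([1.0, 0.0])    # float labels in {0, 1}

loss = nn.BCEWithLogitsLoss()(logits, targets)
# mean of -log(sigmoid(2)) and -log(1 - sigmoid(-1)) ≈ 0.2201
```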
Categorical Cross-Entropy
CCE = -(1/n) × Σ Σ yᵢⱼ log(ŷᵢⱼ)
Use with: Softmax output, multi-class classification
nn.CrossEntropyLoss() # Combines LogSoftmax + NLLLoss
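`CrossEntropyLoss` likewise takes raw logits, with targets given as class indices (not one-hot vectors):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 0.5]])  # one sample, three classes
target = torch.tensor([1])                # index of the correct class

loss = nn.CrossEntropyLoss()(logits, target)
# -log(softmax(logits)[0, 1]) ≈ 0.4644
```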
Focal Loss
FL = -αₜ(1 - pₜ)^γ log(pₜ)
Properties:
- Down-weights easy examples
- Focuses on hard examples
- γ controls focusing (γ=0 is CE)
Use when:
- Class imbalance
- Object detection (many background examples)
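A sketch of the binary (sigmoid) form used in RetinaNet-style detectors; `alpha` and `gamma` defaults follow that paper, and the tensor values in any call are up to you:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Binary focal loss on raw logits; alpha weights the positive class
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)            # prob of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```

With `gamma=0` the modulating factor disappears and this reduces to (alpha-weighted) cross-entropy, matching the γ=0 remark above.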
Hinge Loss (SVM)
Hinge = max(0, 1 - y × ŷ), with labels y ∈ {-1, +1}
Properties:
- Zero loss when correctly classified with margin
- Linear penalty for violations
Use when:
- Maximum margin classification
- SVMs
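PyTorch has no direct binary hinge loss module, but it is a one-liner (a minimal sketch, assuming ±1 labels):

```python
import torch

def hinge_loss(scores, labels):
    # labels must be in {-1, +1}; zero loss once the margin is satisfied
    return torch.clamp(1 - labels * scores, min=0).mean()
```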
Ranking Loss Functions
Triplet Loss
L = max(0, d(anchor, positive) - d(anchor, negative) + margin)
Use when:
- Learning embeddings
- Face recognition
- Similarity learning
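PyTorch ships this as `nn.TripletMarginLoss` (Euclidean distance by default). A toy example with hand-picked embeddings:

```python
import torch
import torch.nn as nn

triplet = nn.TripletMarginLoss(margin=1.0)

anchor = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.0, 0.0]])   # identical to anchor
negative = torch.tensor([[0.5, 0.0]])   # distance 0.5 from anchor

loss = triplet(anchor, positive, negative)
# max(0, d(a, p) - d(a, n) + margin) = max(0, 0 - 0.5 + 1) ≈ 0.5
```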
Contrastive Loss
L = (1-y) × d² + y × max(0, margin - d)², where y = 1 marks a dissimilar pair (conventions vary)
Use when:
- Siamese networks
- Verification tasks
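A minimal sketch, using the y = 1 = dissimilar convention from the formula above:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(emb1, emb2, y, margin=1.0):
    # y = 0 for similar pairs, y = 1 for dissimilar pairs
    d = F.pairwise_distance(emb1, emb2)
    similar = (1 - y) * d.pow(2)                           # pull similar pairs together
    dissimilar = y * torch.clamp(margin - d, min=0).pow(2)  # push dissimilar pairs apart
    return (similar + dissimilar).mean()
```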
Specialized Loss Functions
KL Divergence
KL(P || Q) = Σ P(x) log(P(x) / Q(x))
Use when:
- Comparing distributions
- VAE latent space regularization
- Knowledge distillation
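A common gotcha with `nn.KLDivLoss`: the input must be log-probabilities of Q and the target probabilities of P, giving KL(P || Q). Illustrative distributions:

```python
import torch
import torch.nn as nn

kl = nn.KLDivLoss(reduction="batchmean")

log_q = torch.log(torch.tensor([[0.25, 0.25, 0.5]]))  # input: log-probs of Q
p = torch.tensor([[0.5, 0.25, 0.25]])                 # target: probs of P

loss = kl(log_q, p)  # KL(P || Q) = Σ P log(P/Q) ≈ 0.1733
```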
Dice Loss / IoU Loss
Dice = 1 - 2|A ∩ B| / (|A| + |B|)
Use when:
- Image segmentation
- Imbalanced segmentation masks
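A soft (differentiable) Dice loss is straightforward to write; the small epsilon is an assumption to avoid division by zero on empty masks:

```python
import torch

def dice_loss(pred, target, eps=1e-6):
    # pred: probabilities in [0, 1]; target: binary mask of the same shape
    intersection = (pred * target).sum()
    return 1 - (2 * intersection + eps) / (pred.sum() + target.sum() + eps)
```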
CTC Loss (Connectionist Temporal Classification)
Use when:
- Sequence-to-sequence without alignment
- Speech recognition
- OCR
Choosing the Right Loss
| Task | Loss Function |
|---|---|
| Regression | MSE, MAE, Huber |
| Binary classification | BCE |
| Multi-class classification | Cross-Entropy |
| Multi-label classification | BCE (per label) |
| Imbalanced classification | Focal Loss |
| Segmentation | Cross-Entropy + Dice |
| Embeddings | Triplet, Contrastive |
| Object detection | Focal + Smooth L1 |
| Generative (VAE) | Reconstruction + KL |
Custom Loss Functions
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, pred, target):
        # Your loss computation; weighted MSE shown as a placeholder example
        loss = self.weight * torch.mean((pred - target) ** 2)
        return loss

# Use it
criterion = CustomLoss(weight=2.0)
loss = criterion(predictions, targets)
loss.backward()
Combining Losses
# Weighted combination
loss = alpha * loss1 + beta * loss2 + gamma * loss3
# Example: Segmentation
loss = cross_entropy_loss + 0.5 * dice_loss
# Example: VAE
loss = reconstruction_loss + beta * kl_divergence
Common Pitfalls
1. Class Imbalance
# Wrong: Unweighted CE with imbalanced classes
# Right: Weighted CE or Focal Loss
nn.CrossEntropyLoss(weight=class_weights)
2. Mismatched Loss and Task
# MSE for classification → weak, vanishing gradients with sigmoid/softmax outputs
# Cross-entropy for regression → doesn't make sense (targets aren't class probabilities)
3. Numerical Stability
# Wrong: Manual log(softmax(x))
# Right: Use combined functions
nn.CrossEntropyLoss() # LogSoftmax + NLLLoss internally
nn.BCEWithLogitsLoss() # Sigmoid + BCE internally
Key Takeaways
- Loss function defines what "good" means for optimization
- MSE for regression, cross-entropy for classification
- Huber loss combines MSE + MAE benefits
- Focal loss helps with class imbalance
- Use LogSoftmax+NLL or BCEWithLogits for stability
- Combine losses for multi-objective tasks
- Match loss to your evaluation metric when possible