Early Stopping
Early stopping is a regularization technique that stops training when the model's performance on a validation set stops improving, preventing overfitting.
The Overfitting Problem
```
Loss
│
│ ╲
│  ╲  Training
│   ╲_________________________
│
│ ╲
│  ╲  Validation        ╱
│   ╲                  ╱   ← overfitting
│    ╲________________╱       starts here
│             │
└─────────────┼──────────────► Epochs
              │
           Optimal
           stopping
            point
```
How Early Stopping Works
- Monitor a validation metric each epoch
- Track the best value seen so far
- Count epochs without improvement (patience)
- Stop training when patience is exceeded
- Restore the best model weights
Implementation
Basic Early Stopping
```python
import copy

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_value = None
        self.best_weights = None
        self.should_stop = False

    def __call__(self, value, model):
        if self.best_value is None:
            self.best_value = value
            # Deep-copy: state_dict() returns references to the live tensors,
            # so a shallow .copy() would be silently overwritten by training
            self.best_weights = copy.deepcopy(model.state_dict())
            return False
        if self.mode == 'min':
            improved = value < self.best_value - self.min_delta
        else:
            improved = value > self.best_value + self.min_delta
        if improved:
            self.best_value = value
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
```
Usage
```python
early_stopper = EarlyStopping(patience=10, min_delta=0.001, mode='min')

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    if early_stopper(val_loss, model):
        print(f"Early stopping at epoch {epoch}")
        break

# Restore best weights
model.load_state_dict(early_stopper.best_weights)
```
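The bookkeeping above can be sanity-checked without a model: the sketch below replays the same best-value/patience logic on a synthetic validation curve (the loss values are invented for illustration).

```python
# Synthetic validation losses: real improvement, then a plateau
losses = [0.50, 0.45, 0.44, 0.441, 0.442, 0.443, 0.444]

patience, min_delta = 3, 0.005
best, counter, stopped_at = None, 0, None

for epoch, loss in enumerate(losses):
    if best is None or loss < best - min_delta:
        best, counter = loss, 0     # improvement: record it, reset patience
    else:
        counter += 1                # no meaningful improvement
        if counter >= patience:
            stopped_at = epoch      # patience exhausted
            break

print(best, stopped_at)  # → 0.44 5
```

Training halts at epoch 5 with the best loss still 0.44: the 0.001-sized "improvements" at epochs 3-5 fall inside min_delta and never reset the counter.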
Key Parameters
Patience
Number of epochs to wait without improvement:
| Dataset Size | Typical Patience |
|---|---|
| Small (< 10K) | 5-10 epochs |
| Medium (10K-1M) | 10-20 epochs |
| Large (> 1M) | 20-50 epochs |
min_delta
Minimum change to qualify as an improvement:
```
# Without min_delta: tiny improvements keep resetting patience
val_loss: 0.500 → 0.499 → 0.498   # counter resets each time

# With min_delta=0.01: only a meaningful improvement resets it
val_loss: 0.500 → 0.499 → 0.498   # counter: 0, 1, 2
```
What to Monitor
| Task | Metric | Mode |
|---|---|---|
| Regression | val_loss, MSE | min |
| Classification | val_loss, accuracy | min/max |
| Ranking | NDCG, AUC | max |
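The Mode column corresponds to the direction of the improvement check. A minimal sketch of that comparison (the function name is my own, not from any library):

```python
def improved(value, best, mode='min', min_delta=0.0):
    """'min' for losses (lower is better), 'max' for accuracy/AUC/NDCG."""
    if mode == 'min':
        return value < best - min_delta
    return value > best + min_delta

print(improved(0.48, 0.50))                   # loss fell → True
print(improved(0.91, 0.90, mode='max'))       # accuracy rose → True
print(improved(0.499, 0.50, min_delta=0.01))  # change too small → False
```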
Framework Implementations
PyTorch Lightning
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.001,
    mode='min'
)
trainer = Trainer(callbacks=[early_stop])
```
Keras
```python
from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.001,
    mode='min',
    restore_best_weights=True
)
# val_loss only exists if validation data is provided
model.fit(X, y, validation_split=0.2, callbacks=[early_stop])
```
scikit-learn
```python
# For iterative models like GradientBoosting
from sklearn.ensemble import GradientBoostingClassifier

model = GradientBoostingClassifier(
    n_estimators=1000,
    validation_fraction=0.1,  # held-out fraction used for early stopping
    n_iter_no_change=10,      # patience
    tol=0.001                 # min_delta
)
```
Early Stopping as Regularization
Early stopping implicitly constrains model complexity:
```
Effective capacity = f(training time)

Longer training → more complex model
                → can fit training data better
                → higher risk of overfitting
```
Comparison to Other Regularization
| Technique | How It Regularizes |
|---|---|
| Early Stopping | Limits training time |
| L2/Weight Decay | Penalizes large weights |
| Dropout | Randomly drops units during training |
| Data Augmentation | Increases data diversity |
Best Practices
1. Use a Proper Validation Set
```python
from sklearn.model_selection import train_test_split

# Don't use the test set for early stopping!
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5)
# Train: 70%, Val: 15%, Test: 15%
```
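The two-stage split arithmetic checks out as follows (fractions only, no data needed):

```python
total = 1.0
temp = 0.3 * total        # first split holds out 30%
train = total - temp      # 70% for training
val = 0.5 * temp          # half of the hold-out → 15% validation
test = temp - val         # remaining 15% for the final test
print(train, val, test)
```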
2. Save Best Model Weights
```python
# Always restore best weights after stopping
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pt')

# After training
model.load_state_dict(torch.load('best_model.pt'))
```
3. Combine with Learning Rate Scheduling
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Reduce the LR before stopping completely:
# scheduler patience must be smaller than early-stopping patience
scheduler = ReduceLROnPlateau(optimizer, patience=5)
early_stopper = EarlyStopping(patience=15)

for epoch in range(max_epochs):
    train_loss = train(model)
    val_loss = validate(model)
    scheduler.step(val_loss)            # reduce LR first
    if early_stopper(val_loss, model):  # then stop if needed
        break
```
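To see why the scheduler's patience should be the smaller of the two, here is a stdlib-only toy run (loss values invented): with a plateauing curve, the LR gets cut well before the early stop fires.

```python
# Plateauing validation losses (synthetic)
losses = [0.50, 0.40, 0.39, 0.39, 0.39, 0.39, 0.39, 0.39]

lr, best = 0.1, None
sched_wait, stop_wait = 0, 0          # independent patience counters
lr_drop_epoch, stop_epoch = None, None

for epoch, loss in enumerate(losses):
    if best is None or loss < best:
        best, sched_wait, stop_wait = loss, 0, 0
    else:
        sched_wait += 1
        stop_wait += 1
        if sched_wait >= 2 and lr_drop_epoch is None:
            lr *= 0.1                 # ReduceLROnPlateau-style cut (patience 2)
            lr_drop_epoch = epoch
            sched_wait = 0
        if stop_wait >= 5:            # early-stopping patience 5
            stop_epoch = epoch
            break

print(lr_drop_epoch, stop_epoch)  # → 4 7
```

The LR is reduced at epoch 4, giving the model three more epochs at the lower rate before training stops at epoch 7.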
4. Log Training Curves
```python
import matplotlib.pyplot as plt

plt.plot(train_losses, label='Train')
plt.plot(val_losses, label='Validation')
plt.axvline(best_epoch, color='r', linestyle='--', label='Best epoch')
plt.legend()
plt.savefig('training_curves.png')
```
Common Pitfalls
1. Patience Too Low
```
# With patience=1, a single noisy epoch ends training too early:
Val loss: 0.50 → 0.48 → 0.49 → STOPPED   (next epoch would have been 0.47)
                          ↑
              noise, not real overfitting
```
2. Not Restoring Best Weights
```python
# Wrong: using final weights instead of best
model.eval()  # model is at epoch 100, but the best checkpoint was epoch 80
```
3. Monitoring Training Loss
```python
# Wrong: training loss keeps falling during overfitting, so it never triggers
early_stopper = EarlyStopping(monitor='train_loss')  # Bad!
```
Key Takeaways
- Early stopping prevents overfitting by limiting training time
- Monitor validation loss, not training loss
- Always save and restore the best model weights
- Patience should be 5-20 epochs depending on dataset size
- Use min_delta to require meaningful improvements
- Combine with learning rate scheduling for best results