Beginner · Optimization

Learn how early stopping prevents overfitting by halting training when validation performance stops improving.

Tags: regularization, training, overfitting, optimization

Early Stopping

Early stopping is a regularization technique that stops training when the model's performance on a validation set stops improving, preventing overfitting.

The Overfitting Problem

Loss
  │
  │    Training ─────────────────────
  │         ╲
  │          ╲    Validation
  │           ╲      ╱ ╲
  │            ╲    ╱   ╲ ← Overfitting
  │             ╲  ╱     ╲  starts here
  │              ╲╱       ╲
  │               │        ╲
  └───────────────┼─────────────────► Epochs
                  │
              Optimal
              stopping
              point

How Early Stopping Works

  1. Monitor a validation metric each epoch
  2. Track the best value seen so far
  3. Count epochs without improvement (patience)
  4. Stop training when patience is exceeded
  5. Restore the best model weights
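
Steps 1-4 above can be sketched in plain, framework-free Python (weight restoration is left out here since it needs a real model). The `early_stopping_loop` helper and the loss values are made up for illustration:

```python
def early_stopping_loop(val_losses, patience=2):
    """Return (epoch at which training stops, epoch with the best loss)."""
    best_loss = float('inf')
    best_epoch = 0
    counter = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                break  # patience exceeded: stop training
    return epoch, best_epoch

# Made-up validation losses that bottom out at epoch 3, then rise
losses = [0.50, 0.42, 0.40, 0.38, 0.39, 0.41, 0.45, 0.50]
print(early_stopping_loop(losses))  # → (5, 3): stop at epoch 5, best was epoch 3
```

With patience=2, the two worsening epochs after the epoch-3 minimum exhaust the counter, so training halts at epoch 5 and the epoch-3 model would be restored.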

Implementation

Basic Early Stopping

import copy

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0, mode='min'):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum change that counts as improvement
        self.mode = mode              # 'min' for losses, 'max' for scores
        self.counter = 0
        self.best_value = None
        self.best_weights = None
        self.should_stop = False

    def __call__(self, value, model):
        if self.best_value is None:
            self.best_value = value
            # Deep-copy: state_dict() returns references to live tensors,
            # so a shallow .copy() would be overwritten by further training
            self.best_weights = copy.deepcopy(model.state_dict())
            return False

        if self.mode == 'min':
            improved = value < self.best_value - self.min_delta
        else:
            improved = value > self.best_value + self.min_delta

        if improved:
            self.best_value = value
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop

Usage

early_stopper = EarlyStopping(patience=10, min_delta=0.001, mode='min')

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    
    if early_stopper(val_loss, model):
        print(f"Early stopping at epoch {epoch}")
        break

# Restore best weights
model.load_state_dict(early_stopper.best_weights)

Key Parameters

Patience

Number of epochs to wait without improvement:

Dataset Size     Typical Patience
Small (< 10K)    5-10 epochs
Medium           10-20 epochs
Large (> 1M)     20-50 epochs

min_delta

Minimum change to qualify as an improvement:

# Without min_delta: tiny improvements reset patience
val_loss: 0.500 → 0.499 → 0.498  # Resets each time

# With min_delta=0.01: requires meaningful improvement
val_loss: 0.500 → 0.499 → 0.498  # Counter: 0, 1, 2
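
This effect can be checked numerically. The `count_resets` helper below is illustrative, not a framework API:

```python
def count_resets(val_losses, min_delta=0.0):
    """Count how many epochs reset the patience counter."""
    best = val_losses[0]
    resets = 0
    for value in val_losses[1:]:
        if value < best - min_delta:  # improvement must exceed min_delta
            best = value
            resets += 1
    return resets

losses = [0.500, 0.499, 0.498]
print(count_resets(losses, min_delta=0.0))   # → 2: every tiny drop resets patience
print(count_resets(losses, min_delta=0.01))  # → 0: no drop is meaningful
```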

What to Monitor

Task            Metric              Mode
Regression      val_loss, MSE       min
Classification  val_loss, accuracy  min/max
Ranking         NDCG, AUC           max
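
The min/max distinction just flips the comparison. A minimal sketch (the `improved` helper mirrors the logic of the class shown earlier and is purely illustrative):

```python
def improved(value, best, mode='min', min_delta=0.0):
    """True if `value` beats `best` by more than min_delta."""
    if mode == 'min':
        return value < best - min_delta  # losses: lower is better
    return value > best + min_delta      # scores (accuracy, AUC): higher is better

print(improved(0.48, 0.50, mode='min'))                  # → True: loss fell
print(improved(0.91, 0.90, mode='max', min_delta=0.02))  # → False: gain too small
```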

Framework Implementations

PyTorch Lightning

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.001,
    mode='min'
)

trainer = Trainer(callbacks=[early_stop])

Keras

from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.001,
    mode='min',
    restore_best_weights=True
)

model.fit(X, y, callbacks=[early_stop])

scikit-learn

# For iterative models like GradientBoosting
from sklearn.ensemble import GradientBoostingClassifier

model = GradientBoostingClassifier(
    n_estimators=1000,
    validation_fraction=0.1,
    n_iter_no_change=10,  # patience
    tol=0.001  # min_delta
)

Early Stopping as Regularization

Early stopping implicitly constrains model complexity:

Effective capacity = f(training time)

Longer training → More complex model
                → Can fit training data better
                → Higher risk of overfitting

Comparison to Other Regularization

Technique           How It Regularizes
Early Stopping      Limits training time
L2/Weight Decay     Penalizes large weights
Dropout             Random feature removal
Data Augmentation   Increases data diversity

Best Practices

1. Use a Proper Validation Set

# Don't use the test set for early stopping!
from sklearn.model_selection import train_test_split

X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5)

# Train: 70%, Val: 15%, Test: 15%

2. Save Best Model Weights

# Always restore best weights after stopping
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pt')

# After training
model.load_state_dict(torch.load('best_model.pt'))

3. Combine with Learning Rate Scheduling

# Reduce the LR before stopping completely: give the scheduler a smaller
# patience than the early stopper so the LR drops first
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, patience=5)
early_stopper = EarlyStopping(patience=15)

for epoch in range(max_epochs):
    train_loss = train(model)
    val_loss = validate(model)
    
    scheduler.step(val_loss)  # Reduce LR first
    if early_stopper(val_loss, model):  # Then stop if needed
        break

4. Log Training Curves

import matplotlib.pyplot as plt

plt.plot(train_losses, label='Train')
plt.plot(val_losses, label='Validation')
plt.axvline(best_epoch, color='r', linestyle='--', label='Best epoch')
plt.legend()
plt.savefig('training_curves.png')

Common Pitfalls

1. Patience Too Low

Val loss: 0.50 → 0.48 → 0.49 → STOPPED (patience=1)
                         ↑
          A single bad epoch triggered the stop;
          later epochs might have improved further!
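
The same pitfall in runnable form, using a made-up loss curve whose true minimum comes after a one-epoch blip (`stop_epoch` is an illustrative helper):

```python
def stop_epoch(val_losses, patience):
    """Return the epoch at which early stopping halts training."""
    best = float('inf')
    counter = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, counter = loss, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return len(val_losses) - 1  # never triggered

# Made-up losses: a one-epoch blip at epoch 2, true minimum at epoch 4
losses = [0.50, 0.48, 0.49, 0.47, 0.46, 0.47, 0.48]
print(stop_epoch(losses, patience=1))  # → 2: stops on the blip, misses 0.46
print(stop_epoch(losses, patience=3))  # → 6: rides out the blip, sees 0.46
```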

2. Not Restoring Best Weights

# Wrong: using final weights instead of best
model.eval()  # Model at epoch 100, but best was epoch 80

3. Monitoring Training Loss

# Wrong: monitoring training loss doesn't detect overfitting
early_stopper = EarlyStopping(monitor='train_loss')  # Bad!

Key Takeaways

  1. Early stopping prevents overfitting by limiting training time
  2. Monitor validation loss, not training loss
  3. Always save and restore the best model weights
  4. Patience typically ranges from 5 to 50 epochs, scaled with dataset size
  5. Use min_delta to require meaningful improvements
  6. Combine with learning rate scheduling for best results