# Model Drift and Monitoring

Model drift occurs when the statistical properties of the data a deployed model sees change over time, degrading its performance. Understanding and detecting drift is essential for maintaining ML systems in production.
## Types of Drift

### Data Drift (Covariate Shift)

The input distribution changes:

- Training: users aged 25-40
- Production: users now span ages 18-65
- P(X) changes, but P(Y|X) stays the same
### Concept Drift

The relationship between input and output changes:

- Training: a click signaled genuine interest
- Now: accidental clicks are more common
- P(Y|X) changes
### Label Drift (Prior Probability Shift)

The target distribution changes:

- Training: 5% fraud rate
- Now: 10% fraud rate due to a new attack pattern
- P(Y) changes
### Feature Drift

Feature values or availability change:

- A feature such as "age" is now collected differently
- A new feature becomes available
- A feature is removed for privacy reasons
## Why Drift Happens

### External Changes
- Market conditions
- Competition
- Regulations
- Seasonal effects
- World events (pandemic)
### Internal Changes
- System updates
- Data pipeline changes
- Collection methodology
- User interface changes
### Natural Evolution
- User behavior evolves
- Preferences change
- New products/categories
## Detecting Drift

### Statistical Tests for Data Drift

```python
import numpy as np
from scipy import stats

# Values of a single feature from training and from production (synthetic here)
train_data = np.random.normal(0.0, 1.0, 1000)
prod_data = np.random.normal(0.3, 1.0, 1000)

# Continuous features: two-sample Kolmogorov-Smirnov test
ks_stat, p_value = stats.ks_2samp(train_data, prod_data)
if p_value < 0.05:
    print("Distribution has shifted!")

# Categorical features: chi-square goodness-of-fit test on category counts
# (observed and expected counts must sum to the same total)
chi2, p_value = stats.chisquare([900, 80, 20], f_exp=[850, 120, 30])
```
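One caveat: with large production samples these tests become very sensitive, so a statistically significant p-value may reflect a practically irrelevant shift. Pair them with an effect-size measure such as PSI before alerting.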
### Common Methods
| Method | Type | Use Case |
|---|---|---|
| KS Test | Continuous | Single feature |
| Chi-Square | Categorical | Single feature |
| PSI | Any | Overall shift magnitude |
| KL Divergence | Any | Distribution comparison |
| Wasserstein | Continuous | Distribution distance |
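For the last two rows of the table, SciPy provides ready-made building blocks. A minimal sketch, with illustrative synthetic data and an arbitrary 10-bin discretization for the KL term:

```python
import numpy as np
from scipy import stats
from scipy.special import rel_entr

train = np.random.normal(0.0, 1.0, 1000)  # baseline feature values (synthetic)
prod = np.random.normal(0.4, 1.0, 1000)   # production feature values (synthetic)

# Wasserstein distance works directly on raw samples
w_dist = stats.wasserstein_distance(train, prod)

# KL divergence compares binned probability distributions
edges = np.histogram_bin_edges(train, bins=10)
p = np.histogram(train, bins=edges)[0] / len(train)
q = np.histogram(prod, bins=edges)[0] / len(prod)
kl = rel_entr(p + 1e-6, q + 1e-6).sum()   # small constant avoids log(0)
```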
### Population Stability Index (PSI)

PSI = Σ (actual% − expected%) × ln(actual% / expected%), summed over the bins of a feature

- PSI < 0.1: no significant shift
- 0.1 ≤ PSI < 0.2: moderate shift, investigate
- PSI ≥ 0.2: significant shift, action needed
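The thresholds above assume PSI computed over binned values of a single feature. A minimal sketch (the bin count and clipping constant are arbitrary choices; the complete pipeline example later in this section assumes a helper like this):

```python
import numpy as np

def calculate_psi(expected, actual, n_bins=10):
    """PSI between a baseline sample and a production sample of one feature."""
    edges = np.histogram_bin_edges(expected, bins=n_bins)
    edges[0], edges[-1] = -np.inf, np.inf  # catch values outside the baseline range
    expected_pct = np.histogram(expected, bins=edges)[0] / len(expected)
    actual_pct = np.histogram(actual, bins=edges)[0] / len(actual)
    # Clip empty bins so the log term stays finite
    expected_pct = np.clip(expected_pct, 1e-6, None)
    actual_pct = np.clip(actual_pct, 1e-6, None)
    return float(np.sum((actual_pct - expected_pct) * np.log(actual_pct / expected_pct)))
```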
### Detecting Concept Drift

Because concept drift changes P(Y|X), it typically surfaces as a performance drop once ground-truth labels arrive:

```python
# Monitor prediction accuracy over a recent window of labeled data
accuracy_window = calculate_accuracy(recent_predictions, true_labels)
if accuracy_window < baseline_accuracy - threshold:
    alert("Possible concept drift")
```
## Monitoring Metrics

### Model Performance
- Accuracy, precision, recall, F1
- AUC-ROC
- RMSE, MAE (regression)
### Prediction Distribution
- Mean prediction
- Prediction variance
- Class distribution
### Data Quality
- Missing values
- Outliers
- Feature ranges
### Latency and Infrastructure
- Inference time
- Memory usage
- Error rates
## Monitoring Architecture

```
[Production Model]
        ↓
 [Predictions] → [Prediction Logger] → [Monitoring DB]
        ↓                                    ↓
 [Downstream]                     [Dashboard + Alerts]
        ↓
[Feedback/Labels] → [Label Collector] → [Performance Metrics]
```
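As one possible realization of the Prediction Logger box, a sketch that appends JSON lines keyed by a request ID so the Label Collector can join delayed labels later (a flat file stands in for the monitoring DB):

```python
import json
import time
import uuid

def log_prediction(features, prediction, model_version, path="predictions.jsonl"):
    """Append one prediction record for later joining with ground-truth labels."""
    record = {
        "id": str(uuid.uuid4()),      # join key for the label collector
        "ts": time.time(),
        "model_version": model_version,
        "features": features,
        "prediction": prediction,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")

log_prediction({"age": 34, "amount": 120.5}, prediction=0, model_version="v2.3.1")
```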
## Tools for Monitoring

### Open Source
- Evidently AI
- Alibi Detect
- Great Expectations
- Prometheus + Grafana
### Cloud/Enterprise
- AWS SageMaker Model Monitor
- Azure ML Data Drift
- Google Cloud Vertex AI
- Weights & Biases
## Handling Drift

### Reactive Approaches

#### Retraining

- Scheduled: retrain weekly/monthly
- Triggered: retrain when drift exceeds a threshold
- Incremental: update the existing model with new data
- Full: retrain from scratch
#### Model Updates

```python
# Incremental update (for estimators that support partial_fit)
model.partial_fit(X_new, y_new)

# Or retrain on combined data and A/B test against the current model
new_model = train(historical_data + recent_data)
ab_test(old_model, new_model)
```
### Proactive Approaches

#### Training Data Strategy

- Include diverse data
- Simulate potential shifts
- Use time-weighted sampling (see the sketch below)
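A sketch of time-weighted sampling with pandas, where a row's weight halves every `half_life_days`; the half-life and the column names are illustrative:

```python
import numpy as np
import pandas as pd

def time_weighted_sample(df, ts_col, n, half_life_days=30, seed=0):
    """Sample n rows, biased toward recent data via exponential decay."""
    age_days = (df[ts_col].max() - df[ts_col]).dt.days
    weights = 0.5 ** (age_days / half_life_days)
    rng = np.random.default_rng(seed)
    idx = rng.choice(df.index, size=n, replace=False,
                     p=(weights / weights.sum()).to_numpy())
    return df.loc[idx]

df = pd.DataFrame({"ts": pd.date_range("2024-01-01", periods=90, freq="D"),
                   "y": np.arange(90)})
recent_biased = time_weighted_sample(df, "ts", n=30)
```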
#### Model Design
- Ensemble of models for different scenarios
- Domain adaptation techniques
- Robust training methods
## Best Practices

### 1. Establish Baselines

```python
baseline_metrics = {
    'accuracy': 0.92,
    'feature_means': {...},
    'prediction_distribution': {...}
}
```
### 2. Set Alert Thresholds

```python
alert_rules = {
    'accuracy_drop': 0.05,   # alert if accuracy drops by 5 points
    'psi_threshold': 0.2,
    'latency_p99': 200       # ms
}
```
### 3. Create Feedback Loops

```
Prediction → User action → Implicit label
                                 ↓
                      Performance measurement
```
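Closing the loop concretely might mean joining logged predictions with labels that arrive later, for example against the JSONL log sketched in the architecture section (a labels.csv with `id` and `label` columns is an assumed schema):

```python
import pandas as pd

preds = pd.read_json("predictions.jsonl", lines=True)
labels = pd.read_csv("labels.csv")  # assumed columns: id, label
joined = preds.merge(labels, on="id", how="inner")
observed_accuracy = (joined["prediction"] == joined["label"]).mean()
```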
### 4. Version Everything

```python
experiment = {
    'model_version': 'v2.3.1',
    'data_version': '2024-01-15',
    'features': ['f1', 'f2', 'f3'],
    'hyperparameters': {...}
}
```
### 5. Regular Review
- Weekly model performance review
- Monthly drift analysis
- Quarterly retraining strategy review
## Example: Complete Monitoring Pipeline

```python
class ModelMonitor:
    def __init__(self, baseline_data, baseline_metrics, thresholds):
        self.baseline = baseline_data              # reference feature DataFrame
        self.baseline_metrics = baseline_metrics   # e.g. {'accuracy': 0.92}
        self.thresholds = thresholds               # e.g. {'psi': 0.2, 'accuracy': 0.05}

    def check_data_drift(self, new_data):
        psi_scores = {}
        for feature in self.baseline.columns:
            psi = calculate_psi(self.baseline[feature], new_data[feature])
            psi_scores[feature] = psi
            if psi > self.thresholds['psi']:
                self.alert(f"Data drift in {feature}: PSI={psi:.3f}")
        return psi_scores

    def check_performance(self, predictions, labels):
        metrics = calculate_metrics(predictions, labels)  # e.g. accuracy, F1
        for metric, value in metrics.items():
            if value < self.baseline_metrics[metric] - self.thresholds[metric]:
                self.alert(f"Performance drop: {metric}={value:.3f}")
        return metrics

    def alert(self, message):
        send_slack_alert(message)  # placeholder notification hook
        log_to_db(message)         # placeholder persistence hook
```
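A hypothetical end-to-end wiring, with trivial stand-ins for the helpers the class leaves undeclared (`calculate_psi` is sketched in the PSI section above):

```python
import numpy as np
import pandas as pd

# Stand-ins for the undeclared helpers -- assumptions, not part of the original
def calculate_metrics(preds, labels):
    return {"accuracy": float(np.mean(np.asarray(preds) == np.asarray(labels)))}

send_slack_alert = print  # replace with a real notification hook
log_to_db = print         # replace with real persistence

baseline_df = pd.DataFrame({"age": np.random.normal(35, 5, 1000)})
live_df = pd.DataFrame({"age": np.random.normal(45, 5, 1000)})

monitor = ModelMonitor(baseline_df,
                       baseline_metrics={"accuracy": 0.92},
                       thresholds={"psi": 0.2, "accuracy": 0.05})
monitor.check_data_drift(live_df)                      # large PSI on "age" fires an alert
monitor.check_performance([1, 0, 1, 1], [1, 0, 0, 0])  # accuracy 0.5 fires an alert
```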
## Key Takeaways

- Drift is inevitable: models degrade over time
- Types: data drift, concept drift, label drift
- Use statistical tests (KS, PSI) to detect drift
- Monitor model performance, not just predictions
- Have a retraining strategy ready
- Invest in observability infrastructure early