Drift Explain¶
The Drift Explain module (introduced in v0.3.0) provides detailed insights into why drift was detected and how distributions have shifted.
Overview¶
When drift is detected, you often want to understand:
- How much did the distribution shift?
- Which statistics changed the most?
- What does it look like visually?
DriftWatch provides two main classes for this:
| Class | Purpose |
|---|---|
| `DriftExplainer` | Detailed statistical analysis |
| `DriftVisualizer` | Histogram overlays and plots |
Installation¶
Visualization support is optional: install DriftWatch with the `viz` extra to enable it.
DriftExplainer¶
The DriftExplainer provides detailed statistics for understanding drift.
Basic Usage¶
```python
from driftwatch import Monitor
from driftwatch.explain import DriftExplainer

# Setup (train_df and production_df are existing pandas DataFrames)
monitor = Monitor(reference_data=train_df)
report = monitor.check(production_df)

# Explain drift
explainer = DriftExplainer(train_df, production_df, report)
explanation = explainer.explain()

# Display summary
print(explanation.summary())
```
Statistics Provided¶
For each numeric feature, you get:
| Statistic | Description |
|---|---|
| `mean_shift` | Absolute change in mean |
| `mean_shift_percent` | Relative change (%) |
| `std_change` | Absolute change in standard deviation |
| `std_change_percent` | Relative change (%) |
| `ref_min`, `prod_min` | Minimum values |
| `ref_max`, `prod_max` | Maximum values |
| `quantile_stats` | Q25, Q50, Q75 comparisons |
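To make the table concrete, here is a minimal standard-library sketch of how statistics with these names could be computed from two samples. It is not DriftWatch's implementation: the helper name `describe_shift`, the use of the sample standard deviation, and the sign convention (production minus reference) are all assumptions made for illustration.

```python
import statistics


def describe_shift(reference, production):
    """Compare basic summary statistics of two numeric samples.

    Hypothetical helper, not part of DriftWatch.
    """
    ref_mean = statistics.fmean(reference)
    prod_mean = statistics.fmean(production)
    # Sample standard deviation (assumption); needs at least two values.
    ref_std = statistics.stdev(reference)
    prod_std = statistics.stdev(production)

    def percent(change, base):
        # Relative change in percent; undefined when the baseline is zero.
        return change / abs(base) * 100 if base != 0 else float("nan")

    mean_shift = prod_mean - ref_mean
    std_change = prod_std - ref_std
    return {
        "mean_shift": mean_shift,
        "mean_shift_percent": percent(mean_shift, ref_mean),
        "std_change": std_change,
        "std_change_percent": percent(std_change, ref_std),
        "ref_min": min(reference),
        "prod_min": min(production),
        "ref_max": max(reference),
        "prod_max": max(production),
    }


stats = describe_shift([28.0, 30.0, 32.0], [38.0, 40.0, 42.0])
print(stats["mean_shift"], stats["std_change"])
```

In this toy input the mean moves by 10 while the spread stays the same, which is the same pattern as the `age` feature in the example output: a large mean shift with an almost unchanged standard deviation.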
Example Output¶
```text
━━━ age ━━━
Status: 🔴 DRIFT DETECTED
Score (psi): 2.9555

📊 Central Tendency:
   Mean: 30.0234 → 40.1567 (+33.75%)
📈 Spread:
   Std: 5.0123 → 4.9876 (-0.49%)
📏 Range:
   Min: 15.2341 → 25.1234
   Max: 44.8765 → 55.3421
📐 Quantiles:
   Q25: 26.5432 → 36.4321 (+37.26%)
   Q50: 29.8765 → 40.0123 (+33.93%)
   Q75: 33.2109 → 43.7654 (+31.78%)
```
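The score in the output above is labelled `psi`, which usually stands for the Population Stability Index. The sketch below shows the textbook formula, the sum over bins of `(p - r) * ln(p / r)` where `r` and `p` are the reference and production proportions in each bin. The equal-width binning, the `eps` floor for empty bins, and the function name are assumptions; DriftWatch's own binning and smoothing may differ, so the numbers will not necessarily match its report.

```python
import math


def psi(reference, production, bins=10, eps=1e-6):
    """Population Stability Index over equal-width bins spanning both samples.

    Illustrative sketch only, not DriftWatch's implementation.
    """
    lo = min(min(reference), min(production))
    hi = max(max(reference), max(production))
    width = (hi - lo) / bins or 1.0  # fall back to 1.0 for constant data

    def proportions(values):
        counts = [0] * bins
        for v in values:
            # Clamp so the maximum value lands in the last bin.
            idx = min(int((v - lo) / width), bins - 1)
            counts[idx] += 1
        # eps keeps empty bins from producing log(0) or division by zero.
        return [max(c / len(values), eps) for c in counts]

    ref_p = proportions(reference)
    prod_p = proportions(production)
    return sum((p - r) * math.log(p / r) for r, p in zip(ref_p, prod_p))


ref = [float(v) for v in range(100)]
same = psi(ref, ref)                      # identical samples
shifted = psi(ref, [v + 50 for v in ref])  # production shifted upward
print(same, shifted)
```

Every term of the sum is non-negative, so the index is zero only when the binned proportions agree and grows as they diverge. A commonly quoted rule of thumb treats values above roughly 0.25 as a significant shift, though thresholds are a matter of convention.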
Explain Single Feature¶
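```python
# Get the explanation for a specific feature
age_exp = explainer.explain_feature("age")

# explain_feature returns None when the feature is not found
if age_exp is not None:
    print(f"Mean shifted by {age_exp.mean_shift_percent:.1f}%")
    print(f"Std changed by {age_exp.std_change_percent:.1f}%")
```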
Custom Quantiles¶
```python
# Analyze different quantiles
explainer = DriftExplainer(
    train_df,
    production_df,
    report,
    quantiles=[0.1, 0.5, 0.9],  # 10th, 50th, 90th percentiles
)
```
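As a rough illustration of what a quantile comparison involves, the following standard-library sketch evaluates the same quantiles on both samples and reports the difference. The helper names and the linear-interpolation rule are assumptions; the interpolation method DriftWatch uses is not stated here.

```python
def quantile(values, q):
    """Quantile by linear interpolation between closest ranks (0 <= q <= 1)."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * frac


def compare_quantiles(reference, production, quantiles=(0.25, 0.5, 0.75)):
    """Map each quantile to (reference value, production value, difference)."""
    rows = {}
    for q in quantiles:
        ref_q = quantile(reference, q)
        prod_q = quantile(production, q)
        rows[q] = (ref_q, prod_q, prod_q - ref_q)
    return rows


rows = compare_quantiles(
    [1, 2, 3, 4, 5],
    [2, 4, 6, 8, 10],
    quantiles=[0.1, 0.5, 0.9],
)
for q, (ref_q, prod_q, diff) in rows.items():
    print(f"Q{int(q * 100)}: {ref_q:.2f} -> {prod_q:.2f} ({diff:+.2f})")
```

Tail quantiles such as 0.1 and 0.9 can reveal shifts in the extremes that the default quartiles smooth over.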
Export to JSON¶
```python
import json

# Export for logging/analysis
data = explanation.to_dict()
print(json.dumps(data, indent=2, default=str))
```
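The `default=str` argument matters whenever the dictionary holds values that `json` cannot encode natively. The sketch below demonstrates this with a hypothetical payload; the keys and the `datetime` value are invented for illustration and are not claimed to match what `to_dict()` returns.

```python
import json
from datetime import datetime

# Hypothetical payload; the real keys returned by to_dict() may differ.
payload = {
    "feature": "age",
    "mean_shift": 10.13,
    "generated_at": datetime(2024, 1, 1, 12, 0, 0),
}

try:
    json.dumps(payload)
    failed = False
except TypeError:
    # datetime objects are not JSON serializable by default.
    failed = True

# default=str converts any unsupported value with str() instead of raising.
text = json.dumps(payload, indent=2, default=str)
restored = json.loads(text)
print(failed, restored["generated_at"])
```

The trade-off is that such values come back as plain strings when the JSON is loaded again, so any consumer has to parse them explicitly.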
DriftVisualizer¶
The DriftVisualizer creates histogram overlays to visualize distribution shifts.
Basic Usage¶
```python
from driftwatch.explain import DriftVisualizer
import matplotlib.pyplot as plt

viz = DriftVisualizer(train_df, production_df, report)

# Plot single feature
fig = viz.plot_feature("age")
plt.show()
```
Plot All Features¶
Call `viz.plot_all()` to draw histogram overlays for all numeric features in a single grid figure. The `plot_all` entry in the API Reference lists the available parameters, including `cols` for the number of grid columns.
Customization¶
```python
# Customize appearance
fig = viz.plot_feature(
    "age",
    bins=30,          # Number of histogram bins
    figsize=(12, 8),  # Figure size
    show_stats=True,  # Show stats box
    alpha=0.7,        # Transparency
)
```
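The `bins` setting is easier to reason about with the mechanics of an overlay in mind: two histograms are only comparable bar for bar when both samples are counted against the same bin edges. The sketch below shows that idea with the standard library. It is an illustration under assumptions (equal-width bins spanning both samples, a hypothetical `shared_histogram` helper) and does not describe how `DriftVisualizer` bins data internally.

```python
def shared_histogram(reference, production, bins=30):
    """Count both samples against one set of equal-width bin edges."""
    lo = min(min(reference), min(production))
    hi = max(max(reference), max(production))
    width = (hi - lo) / bins or 1.0  # fall back to 1.0 for constant data
    edges = [lo + i * width for i in range(bins + 1)]

    def count(values):
        counts = [0] * bins
        for v in values:
            # Clamp so the maximum value lands in the last bin.
            counts[min(int((v - lo) / width), bins - 1)] += 1
        return counts

    return edges, count(reference), count(production)


edges, ref_counts, prod_counts = shared_histogram(
    [1, 2, 3, 4], [3, 4, 5, 6], bins=5
)
print(edges)
print(ref_counts)
print(prod_counts)
```

Fewer bins smooth out noise in small samples, while more bins expose finer structure at the cost of sparser bars.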
Save to File¶
```python
# Save single feature
viz.save("age_drift.png", feature_name="age", dpi=150)

# Save all features
viz.save("drift_report.png", dpi=150)
viz.save("drift_report.pdf")  # Vector format
```
Complete Example¶
```python
import pandas as pd
import numpy as np
from driftwatch import Monitor
from driftwatch.explain import DriftExplainer, DriftVisualizer

# Generate data
np.random.seed(42)
train_df = pd.DataFrame({
    'age': np.random.normal(30, 5, 1000),
    'income': np.random.normal(50000, 10000, 1000),
})
prod_df = pd.DataFrame({
    'age': np.random.normal(40, 5, 1000),  # Drift!
    'income': np.random.normal(50000, 10000, 1000),
})

# Detect drift
monitor = Monitor(reference_data=train_df)
report = monitor.check(prod_df)

# Explain
explainer = DriftExplainer(train_df, prod_df, report)
explanation = explainer.explain()
print(explanation.summary())

# Visualize
viz = DriftVisualizer(train_df, prod_df, report)
viz.save("drift_analysis.png")
```
API Reference¶
driftwatch.explain.DriftExplainer¶
DriftExplainer(reference_data: DataFrame, production_data: DataFrame, report: DriftReport, quantiles: list[float] | None = None)
Explains drift detection results with detailed statistics.
The explainer takes reference and production DataFrames along with a DriftReport and provides detailed insights into why drift was detected and how distributions have shifted.
Example

```python
from driftwatch import Monitor
from driftwatch.explain import DriftExplainer

monitor = Monitor(reference_data=train_df)
report = monitor.check(prod_df)

explainer = DriftExplainer(train_df, prod_df, report)
explanation = explainer.explain()
print(explanation.summary())
```
Initialize the DriftExplainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `reference_data` | `DataFrame` | Reference DataFrame (training data) | required |
| `production_data` | `DataFrame` | Production DataFrame to explain | required |
| `report` | `DriftReport` | DriftReport from Monitor.check() | required |
| `quantiles` | `list[float] \| None` | List of quantiles to analyze (default: [0.25, 0.5, 0.75]) | `None` |
explain¶
Generate detailed explanations for all features.
Returns:
| Type | Description |
|---|---|
| `DriftExplanation` | DriftExplanation containing per-feature statistical analysis |
explain_feature¶
Generate detailed explanation for a single feature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `feature_name` | `str` | Name of the feature to explain | required |

Returns:
| Type | Description |
|---|---|
| `FeatureExplanation \| None` | FeatureExplanation or None if feature not found |
driftwatch.explain.DriftVisualizer¶
DriftVisualizer(reference_data: DataFrame, production_data: DataFrame, report: DriftReport, style: str = 'seaborn-v0_8-whitegrid', colors: dict[str, str] | None = None)
Visualizes drift between reference and production distributions.
Creates matplotlib figures showing distribution overlays, helping users understand exactly how data has shifted.
Example

```python
import matplotlib.pyplot as plt

from driftwatch import Monitor
from driftwatch.explain import DriftVisualizer

monitor = Monitor(reference_data=train_df)
report = monitor.check(prod_df)

viz = DriftVisualizer(train_df, prod_df, report)
fig = viz.plot_feature("age")
plt.show()

# Or plot all features
fig = viz.plot_all()
plt.savefig("drift_report.png")
```
Initialize the DriftVisualizer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `reference_data` | `DataFrame` | Reference DataFrame (training data) | required |
| `production_data` | `DataFrame` | Production DataFrame | required |
| `report` | `DriftReport` | DriftReport from Monitor.check() | required |
| `style` | `str` | Matplotlib style to use | `'seaborn-v0_8-whitegrid'` |
| `colors` | `dict[str, str] \| None` | Custom color scheme override | `None` |
plot_feature¶
plot_feature(feature_name: str, bins: int = 50, figsize: tuple[int, int] = (10, 6), show_stats: bool = True, alpha: float = 0.6, colors: dict[str, str] | None = None, title: str | None = None, xlabel: str | None = None, ylabel: str | None = None, hist_kwargs: dict[str, Any] | None = None, stats_kwargs: dict[str, Any] | None = None) -> Any
Plot histogram overlay for a single feature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `feature_name` | `str` | Name of the feature to plot | required |
| `bins` | `int` | Number of histogram bins | `50` |
| `figsize` | `tuple[int, int]` | Figure size (width, height) | `(10, 6)` |
| `show_stats` | `bool` | Whether to show statistical annotations | `True` |
| `alpha` | `float` | Transparency of histograms | `0.6` |
| `colors` | `dict[str, str] \| None` | Custom colors for this plot (keys: reference, production) | `None` |
| `title` | `str \| None` | Custom title override | `None` |
| `xlabel` | `str \| None` | Custom x-axis label | `None` |
| `ylabel` | `str \| None` | Custom y-axis label | `None` |
| `hist_kwargs` | `dict[str, Any] \| None` | Additional arguments passed to ax.hist | `None` |
| `stats_kwargs` | `dict[str, Any] \| None` | Additional arguments passed to the stats text box | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | matplotlib Figure object |

Raises:
| Type | Description |
|---|---|
| `ImportError` | If matplotlib is not installed |
| `ValueError` | If feature not found in data |
plot_all¶
plot_all(cols: int = 2, figsize: tuple[int, int] | None = None, bins: int = 50, alpha: float = 0.6, colors: dict[str, str] | None = None, hist_kwargs: dict[str, Any] | None = None) -> Any
Plot histogram overlays for all numeric features.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cols` | `int` | Number of columns in the grid | `2` |
| `figsize` | `tuple[int, int] \| None` | Figure size (auto-calculated if None) | `None` |
| `bins` | `int` | Number of histogram bins | `50` |
| `alpha` | `float` | Transparency of histograms | `0.6` |
| `colors` | `dict[str, str] \| None` | Custom colors for this plot (keys: reference, production) | `None` |
| `hist_kwargs` | `dict[str, Any] \| None` | Additional arguments passed to ax.hist | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | matplotlib Figure object |
save¶
Save visualization to file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `filename` | `str` | Output filename (supports png, pdf, svg) | required |
| `feature_name` | `str \| None` | Specific feature to plot (or all if None) | `None` |
| `dpi` | `int` | Resolution for raster formats | `150` |
| `**kwargs` | `Any` | Additional arguments passed to savefig | `{}` |
Returns:
| Type | Description |
|---|---|
| `str` | The filename that was saved |