mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-27 07:09:27 +08:00
feat(examples): add anomaly detection and covariates examples
Anomaly Detection Example: - Uses quantile forecasts as prediction intervals - Flags values outside 80%/90% CI as warnings/critical anomalies - Includes visualization with deviation plot Covariates (XReg) Example: - Demonstrates forecast_with_covariates() API - Shows dynamic numerical/categorical covariates - Shows static categorical covariates - Includes synthetic retail sales data with price, promotion, holiday SKILL.md Updates: - Added anomaly detection section with code example - Expanded covariates section with covariate types table - Added XReg modes explanation - Updated 'When not to use' section to note anomaly detection workaround
This commit is contained in:
@@ -0,0 +1,339 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TimesFM Anomaly Detection Example
|
||||
|
||||
This example demonstrates how to use TimesFM's quantile forecasts for
|
||||
anomaly detection. The approach:
|
||||
1. Forecast with quantile intervals (10th-90th percentiles)
|
||||
2. Compare actual values against prediction intervals
|
||||
3. Flag values outside intervals as anomalies
|
||||
|
||||
TimesFM does NOT have built-in anomaly detection, but the quantile
|
||||
forecasts provide natural anomaly detection via prediction intervals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import timesfm
|
||||
|
||||
# Configuration
|
||||
HORIZON = 12 # Forecast horizon
|
||||
ANOMALY_THRESHOLD_WARNING = 0.80 # Outside 80% CI = warning
|
||||
ANOMALY_THRESHOLD_CRITICAL = 0.90 # Outside 90% CI = critical
|
||||
|
||||
EXAMPLE_DIR = Path(__file__).parent
|
||||
DATA_FILE = (
|
||||
Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
|
||||
)
|
||||
OUTPUT_DIR = EXAMPLE_DIR / "output"
|
||||
|
||||
|
||||
def inject_anomalies(
|
||||
values: np.ndarray, n_anomalies: int = 3, seed: int = 42
|
||||
) -> tuple[np.ndarray, list[int]]:
|
||||
"""Inject synthetic anomalies into the data for demonstration."""
|
||||
rng = np.random.default_rng(seed)
|
||||
anomaly_indices = rng.choice(len(values), size=n_anomalies, replace=False).tolist()
|
||||
|
||||
anomalous_values = values.copy()
|
||||
for idx in anomaly_indices:
|
||||
# Inject spike or dip (±40-60% of value)
|
||||
multiplier = rng.choice([0.4, 0.6]) * rng.choice([1, -1])
|
||||
anomalous_values[idx] = values[idx] * (1 + multiplier)
|
||||
|
||||
return anomalous_values, sorted(anomaly_indices)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
print("=" * 60)
|
||||
print(" TIMESFM ANOMALY DETECTION DEMO")
|
||||
print("=" * 60)
|
||||
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Load temperature data
|
||||
print("\n📊 Loading temperature anomaly data...")
|
||||
df = pd.read_csv(DATA_FILE, parse_dates=["date"])
|
||||
df = df.sort_values("date").reset_index(drop=True)
|
||||
|
||||
# Split into context (first 24 months) and test (last 12 months)
|
||||
context_values = df["anomaly_c"].values[:24].astype(np.float32)
|
||||
actual_future = df["anomaly_c"].values[24:36].astype(np.float32)
|
||||
dates_future = df["date"].values[24:36]
|
||||
|
||||
print(f" Context: 24 months (2022-01 to 2023-12)")
|
||||
print(f" Test: 12 months (2024-01 to 2024-12)")
|
||||
|
||||
# Inject anomalies into test data for demonstration
|
||||
print("\n🔬 Injecting synthetic anomalies for demonstration...")
|
||||
test_values_with_anomalies, anomaly_indices = inject_anomalies(
|
||||
actual_future, n_anomalies=3
|
||||
)
|
||||
print(f" Injected anomalies at months: {anomaly_indices}")
|
||||
|
||||
# Load TimesFM
|
||||
print("\n🤖 Loading TimesFM 1.0 (200M) PyTorch...")
|
||||
hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
|
||||
checkpoint = timesfm.TimesFmCheckpoint(
|
||||
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
|
||||
)
|
||||
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
|
||||
|
||||
# Forecast with quantiles
|
||||
print("\n📈 Forecasting with quantile intervals...")
|
||||
point_forecast, quantile_forecast = model.forecast(
|
||||
[context_values],
|
||||
freq=[0],
|
||||
)
|
||||
|
||||
# Extract quantiles
|
||||
# quantile_forecast shape: (1, 12, 10) - [mean, q10, q20, ..., q90]
|
||||
point = point_forecast[0]
|
||||
q10 = quantile_forecast[0, :, 0] # 10th percentile
|
||||
q20 = quantile_forecast[0, :, 1] # 20th percentile
|
||||
q50 = quantile_forecast[0, :, 4] # 50th percentile (median)
|
||||
q80 = quantile_forecast[0, :, 7] # 80th percentile
|
||||
q90 = quantile_forecast[0, :, 8] # 90th percentile
|
||||
|
||||
print(f" Forecast mean: {point.mean():.3f}°C")
|
||||
print(f" 90% CI width: {(q90 - q10).mean():.3f}°C (avg)")
|
||||
|
||||
# Detect anomalies
|
||||
print("\n🔍 Detecting anomalies...")
|
||||
anomalies = []
|
||||
for i, (actual, lower_80, upper_80, lower_90, upper_90) in enumerate(
|
||||
zip(test_values_with_anomalies, q20, q80, q10, q90)
|
||||
):
|
||||
month = dates_future[i]
|
||||
month_str = pd.to_datetime(month).strftime("%Y-%m")
|
||||
|
||||
if actual < lower_90 or actual > upper_90:
|
||||
severity = "CRITICAL"
|
||||
threshold = "90% CI"
|
||||
color = "red"
|
||||
elif actual < lower_80 or actual > upper_80:
|
||||
severity = "WARNING"
|
||||
threshold = "80% CI"
|
||||
color = "orange"
|
||||
else:
|
||||
severity = "NORMAL"
|
||||
threshold = "within bounds"
|
||||
color = "green"
|
||||
|
||||
anomalies.append(
|
||||
{
|
||||
"month": month_str,
|
||||
"actual": float(actual),
|
||||
"forecast": float(point[i]),
|
||||
"lower_80": float(lower_80),
|
||||
"upper_80": float(upper_80),
|
||||
"lower_90": float(lower_90),
|
||||
"upper_90": float(upper_90),
|
||||
"severity": severity,
|
||||
"threshold": threshold,
|
||||
"color": color,
|
||||
}
|
||||
)
|
||||
|
||||
if severity != "NORMAL":
|
||||
deviation = abs(actual - point[i])
|
||||
print(
|
||||
f" [{severity}] {month_str}: {actual:.2f}°C (forecast: {point[i]:.2f}°C, deviation: {deviation:.2f}°C)"
|
||||
)
|
||||
|
||||
# Create visualization
|
||||
print("\n📊 Creating anomaly visualization...")
|
||||
|
||||
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
|
||||
|
||||
# Plot 1: Full time series with forecast and anomalies
|
||||
ax1 = axes[0]
|
||||
|
||||
# Historical data
|
||||
historical_dates = df["date"].values[:24]
|
||||
ax1.plot(
|
||||
historical_dates,
|
||||
context_values,
|
||||
"b-",
|
||||
linewidth=2,
|
||||
label="Historical Data",
|
||||
marker="o",
|
||||
markersize=4,
|
||||
)
|
||||
|
||||
# Actual future (with anomalies)
|
||||
ax1.plot(
|
||||
dates_future,
|
||||
actual_future,
|
||||
"g--",
|
||||
linewidth=1.5,
|
||||
label="Actual (clean)",
|
||||
alpha=0.5,
|
||||
)
|
||||
ax1.plot(
|
||||
dates_future,
|
||||
test_values_with_anomalies,
|
||||
"ko",
|
||||
markersize=8,
|
||||
label="Actual (with anomalies)",
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# Forecast
|
||||
ax1.plot(
|
||||
dates_future,
|
||||
point,
|
||||
"r-",
|
||||
linewidth=2,
|
||||
label="Forecast (median)",
|
||||
marker="s",
|
||||
markersize=6,
|
||||
)
|
||||
|
||||
# 90% CI
|
||||
ax1.fill_between(dates_future, q10, q90, alpha=0.2, color="red", label="90% CI")
|
||||
|
||||
# 80% CI
|
||||
ax1.fill_between(dates_future, q20, q80, alpha=0.3, color="red", label="80% CI")
|
||||
|
||||
# Highlight anomalies
|
||||
for anomaly in anomalies:
|
||||
if anomaly["severity"] != "NORMAL":
|
||||
idx = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future].index(
|
||||
anomaly["month"]
|
||||
)
|
||||
ax1.scatter(
|
||||
[dates_future[idx]],
|
||||
[test_values_with_anomalies[idx]],
|
||||
c=anomaly["color"],
|
||||
s=200,
|
||||
marker="x" if anomaly["severity"] == "CRITICAL" else "^",
|
||||
linewidths=3,
|
||||
zorder=5,
|
||||
)
|
||||
|
||||
ax1.set_xlabel("Date", fontsize=12)
|
||||
ax1.set_ylabel("Temperature Anomaly (°C)", fontsize=12)
|
||||
ax1.set_title(
|
||||
"TimesFM Anomaly Detection: Forecast Intervals Method",
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax1.legend(loc="upper left", fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Add annotation for anomalies
|
||||
ax1.annotate(
|
||||
"× = Critical (outside 90% CI)\n▲ = Warning (outside 80% CI)",
|
||||
xy=(0.98, 0.02),
|
||||
xycoords="axes fraction",
|
||||
ha="right",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
|
||||
)
|
||||
|
||||
# Plot 2: Deviation from forecast with thresholds
|
||||
ax2 = axes[1]
|
||||
|
||||
deviation = test_values_with_anomalies - point
|
||||
lower_90_dev = q10 - point
|
||||
upper_90_dev = q90 - point
|
||||
lower_80_dev = q20 - point
|
||||
upper_80_dev = q80 - point
|
||||
|
||||
months = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future]
|
||||
x = np.arange(len(months))
|
||||
|
||||
# Threshold bands
|
||||
ax2.fill_between(
|
||||
x, lower_90_dev, upper_90_dev, alpha=0.2, color="red", label="90% CI bounds"
|
||||
)
|
||||
ax2.fill_between(
|
||||
x, lower_80_dev, upper_80_dev, alpha=0.3, color="red", label="80% CI bounds"
|
||||
)
|
||||
|
||||
# Deviation bars
|
||||
colors = [
|
||||
"red"
|
||||
if d < lower_90_dev[i] or d > upper_90_dev[i]
|
||||
else "orange"
|
||||
if d < lower_80_dev[i] or d > upper_80_dev[i]
|
||||
else "green"
|
||||
for i, d in enumerate(deviation)
|
||||
]
|
||||
ax2.bar(x, deviation, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5)
|
||||
|
||||
# Zero line
|
||||
ax2.axhline(y=0, color="black", linestyle="-", linewidth=1)
|
||||
|
||||
ax2.set_xlabel("Month", fontsize=12)
|
||||
ax2.set_ylabel("Deviation from Forecast (°C)", fontsize=12)
|
||||
ax2.set_title(
|
||||
"Deviation from Forecast with Anomaly Thresholds",
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax2.set_xticks(x)
|
||||
ax2.set_xticklabels(months, rotation=45, ha="right")
|
||||
ax2.legend(loc="upper right", fontsize=10)
|
||||
ax2.grid(True, alpha=0.3, axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
output_path = OUTPUT_DIR / "anomaly_detection.png"
|
||||
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||
print(f" Saved: {output_path}")
|
||||
plt.close()
|
||||
|
||||
# Save results
|
||||
results = {
|
||||
"method": "quantile_intervals",
|
||||
"description": "Anomaly detection using TimesFM quantile forecasts as prediction intervals",
|
||||
"thresholds": {
|
||||
"warning": f"Outside {ANOMALY_THRESHOLD_WARNING * 100:.0f}% CI (q20-q80)",
|
||||
"critical": f"Outside {ANOMALY_THRESHOLD_CRITICAL * 100:.0f}% CI (q10-q90)",
|
||||
},
|
||||
"anomalies": anomalies,
|
||||
"summary": {
|
||||
"total_points": len(anomalies),
|
||||
"critical": sum(1 for a in anomalies if a["severity"] == "CRITICAL"),
|
||||
"warning": sum(1 for a in anomalies if a["severity"] == "WARNING"),
|
||||
"normal": sum(1 for a in anomalies if a["severity"] == "NORMAL"),
|
||||
},
|
||||
}
|
||||
|
||||
results_path = OUTPUT_DIR / "anomaly_detection.json"
|
||||
with open(results_path, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f" Saved: {results_path}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print(" ✅ ANOMALY DETECTION COMPLETE")
|
||||
print("=" * 60)
|
||||
print(f"\n📊 Summary:")
|
||||
print(f" Total test points: {results['summary']['total_points']}")
|
||||
print(f" Critical anomalies: {results['summary']['critical']} (outside 90% CI)")
|
||||
print(f" Warnings: {results['summary']['warning']} (outside 80% CI)")
|
||||
print(f" Normal: {results['summary']['normal']}")
|
||||
|
||||
print("\n💡 How It Works:")
|
||||
print(" 1. TimesFM forecasts with quantile intervals (q10, q20, ..., q90)")
|
||||
print(" 2. If actual value falls outside 90% CI → CRITICAL anomaly")
|
||||
print(" 3. If actual value falls outside 80% CI → WARNING")
|
||||
print(" 4. Otherwise → NORMAL")
|
||||
|
||||
print("\n📁 Output Files:")
|
||||
print(f" {output_path}")
|
||||
print(f" {results_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user