Files
claude-scientific-skills/scientific-skills/timesfm-forecasting/examples/anomaly-detection/detect_anomalies.py
Clayton Young 0d98fa353c feat(examples): add anomaly detection and covariates examples
Anomaly Detection Example:
- Uses quantile forecasts as prediction intervals
- Flags values outside 80%/90% CI as warnings/critical anomalies
- Includes visualization with deviation plot

Covariates (XReg) Example:
- Demonstrates forecast_with_covariates() API
- Shows dynamic numerical/categorical covariates
- Shows static categorical covariates
- Includes synthetic retail sales data with price, promotion, holiday

SKILL.md Updates:
- Added anomaly detection section with code example
- Expanded covariates section with covariate types table
- Added XReg modes explanation
- Updated 'When not to use' section to note anomaly detection workaround
2026-02-23 07:43:04 -05:00

340 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
TimesFM Anomaly Detection Example
This example demonstrates how to use TimesFM's quantile forecasts for
anomaly detection. The approach:
1. Forecast with quantile intervals (10th-90th percentiles)
2. Compare actual values against prediction intervals
3. Flag values outside intervals as anomalies
TimesFM does NOT have built-in anomaly detection, but the quantile
forecasts provide natural anomaly detection via prediction intervals.
"""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timesfm
# Configuration
HORIZON = 12 # Forecast horizon
ANOMALY_THRESHOLD_WARNING = 0.80 # Outside 80% CI = warning
ANOMALY_THRESHOLD_CRITICAL = 0.90 # Outside 90% CI = critical
EXAMPLE_DIR = Path(__file__).parent
DATA_FILE = (
Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
)
OUTPUT_DIR = EXAMPLE_DIR / "output"
def inject_anomalies(
values: np.ndarray, n_anomalies: int = 3, seed: int = 42
) -> tuple[np.ndarray, list[int]]:
"""Inject synthetic anomalies into the data for demonstration."""
rng = np.random.default_rng(seed)
anomaly_indices = rng.choice(len(values), size=n_anomalies, replace=False).tolist()
anomalous_values = values.copy()
for idx in anomaly_indices:
# Inject spike or dip (±40-60% of value)
multiplier = rng.choice([0.4, 0.6]) * rng.choice([1, -1])
anomalous_values[idx] = values[idx] * (1 + multiplier)
return anomalous_values, sorted(anomaly_indices)
def main() -> None:
print("=" * 60)
print(" TIMESFM ANOMALY DETECTION DEMO")
print("=" * 60)
OUTPUT_DIR.mkdir(exist_ok=True)
# Load temperature data
print("\n📊 Loading temperature anomaly data...")
df = pd.read_csv(DATA_FILE, parse_dates=["date"])
df = df.sort_values("date").reset_index(drop=True)
# Split into context (first 24 months) and test (last 12 months)
context_values = df["anomaly_c"].values[:24].astype(np.float32)
actual_future = df["anomaly_c"].values[24:36].astype(np.float32)
dates_future = df["date"].values[24:36]
print(f" Context: 24 months (2022-01 to 2023-12)")
print(f" Test: 12 months (2024-01 to 2024-12)")
# Inject anomalies into test data for demonstration
print("\n🔬 Injecting synthetic anomalies for demonstration...")
test_values_with_anomalies, anomaly_indices = inject_anomalies(
actual_future, n_anomalies=3
)
print(f" Injected anomalies at months: {anomaly_indices}")
# Load TimesFM
print("\n🤖 Loading TimesFM 1.0 (200M) PyTorch...")
hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
checkpoint = timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
# Forecast with quantiles
print("\n📈 Forecasting with quantile intervals...")
point_forecast, quantile_forecast = model.forecast(
[context_values],
freq=[0],
)
# Extract quantiles
# quantile_forecast shape: (1, 12, 10) - [mean, q10, q20, ..., q90]
point = point_forecast[0]
q10 = quantile_forecast[0, :, 0] # 10th percentile
q20 = quantile_forecast[0, :, 1] # 20th percentile
q50 = quantile_forecast[0, :, 4] # 50th percentile (median)
q80 = quantile_forecast[0, :, 7] # 80th percentile
q90 = quantile_forecast[0, :, 8] # 90th percentile
print(f" Forecast mean: {point.mean():.3f}°C")
print(f" 90% CI width: {(q90 - q10).mean():.3f}°C (avg)")
# Detect anomalies
print("\n🔍 Detecting anomalies...")
anomalies = []
for i, (actual, lower_80, upper_80, lower_90, upper_90) in enumerate(
zip(test_values_with_anomalies, q20, q80, q10, q90)
):
month = dates_future[i]
month_str = pd.to_datetime(month).strftime("%Y-%m")
if actual < lower_90 or actual > upper_90:
severity = "CRITICAL"
threshold = "90% CI"
color = "red"
elif actual < lower_80 or actual > upper_80:
severity = "WARNING"
threshold = "80% CI"
color = "orange"
else:
severity = "NORMAL"
threshold = "within bounds"
color = "green"
anomalies.append(
{
"month": month_str,
"actual": float(actual),
"forecast": float(point[i]),
"lower_80": float(lower_80),
"upper_80": float(upper_80),
"lower_90": float(lower_90),
"upper_90": float(upper_90),
"severity": severity,
"threshold": threshold,
"color": color,
}
)
if severity != "NORMAL":
deviation = abs(actual - point[i])
print(
f" [{severity}] {month_str}: {actual:.2f}°C (forecast: {point[i]:.2f}°C, deviation: {deviation:.2f}°C)"
)
# Create visualization
print("\n📊 Creating anomaly visualization...")
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
# Plot 1: Full time series with forecast and anomalies
ax1 = axes[0]
# Historical data
historical_dates = df["date"].values[:24]
ax1.plot(
historical_dates,
context_values,
"b-",
linewidth=2,
label="Historical Data",
marker="o",
markersize=4,
)
# Actual future (with anomalies)
ax1.plot(
dates_future,
actual_future,
"g--",
linewidth=1.5,
label="Actual (clean)",
alpha=0.5,
)
ax1.plot(
dates_future,
test_values_with_anomalies,
"ko",
markersize=8,
label="Actual (with anomalies)",
alpha=0.7,
)
# Forecast
ax1.plot(
dates_future,
point,
"r-",
linewidth=2,
label="Forecast (median)",
marker="s",
markersize=6,
)
# 90% CI
ax1.fill_between(dates_future, q10, q90, alpha=0.2, color="red", label="90% CI")
# 80% CI
ax1.fill_between(dates_future, q20, q80, alpha=0.3, color="red", label="80% CI")
# Highlight anomalies
for anomaly in anomalies:
if anomaly["severity"] != "NORMAL":
idx = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future].index(
anomaly["month"]
)
ax1.scatter(
[dates_future[idx]],
[test_values_with_anomalies[idx]],
c=anomaly["color"],
s=200,
marker="x" if anomaly["severity"] == "CRITICAL" else "^",
linewidths=3,
zorder=5,
)
ax1.set_xlabel("Date", fontsize=12)
ax1.set_ylabel("Temperature Anomaly (°C)", fontsize=12)
ax1.set_title(
"TimesFM Anomaly Detection: Forecast Intervals Method",
fontsize=14,
fontweight="bold",
)
ax1.legend(loc="upper left", fontsize=10)
ax1.grid(True, alpha=0.3)
# Add annotation for anomalies
ax1.annotate(
"× = Critical (outside 90% CI)\n▲ = Warning (outside 80% CI)",
xy=(0.98, 0.02),
xycoords="axes fraction",
ha="right",
va="bottom",
fontsize=10,
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
)
# Plot 2: Deviation from forecast with thresholds
ax2 = axes[1]
deviation = test_values_with_anomalies - point
lower_90_dev = q10 - point
upper_90_dev = q90 - point
lower_80_dev = q20 - point
upper_80_dev = q80 - point
months = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future]
x = np.arange(len(months))
# Threshold bands
ax2.fill_between(
x, lower_90_dev, upper_90_dev, alpha=0.2, color="red", label="90% CI bounds"
)
ax2.fill_between(
x, lower_80_dev, upper_80_dev, alpha=0.3, color="red", label="80% CI bounds"
)
# Deviation bars
colors = [
"red"
if d < lower_90_dev[i] or d > upper_90_dev[i]
else "orange"
if d < lower_80_dev[i] or d > upper_80_dev[i]
else "green"
for i, d in enumerate(deviation)
]
ax2.bar(x, deviation, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5)
# Zero line
ax2.axhline(y=0, color="black", linestyle="-", linewidth=1)
ax2.set_xlabel("Month", fontsize=12)
ax2.set_ylabel("Deviation from Forecast (°C)", fontsize=12)
ax2.set_title(
"Deviation from Forecast with Anomaly Thresholds",
fontsize=14,
fontweight="bold",
)
ax2.set_xticks(x)
ax2.set_xticklabels(months, rotation=45, ha="right")
ax2.legend(loc="upper right", fontsize=10)
ax2.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
output_path = OUTPUT_DIR / "anomaly_detection.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
print(f" Saved: {output_path}")
plt.close()
# Save results
results = {
"method": "quantile_intervals",
"description": "Anomaly detection using TimesFM quantile forecasts as prediction intervals",
"thresholds": {
"warning": f"Outside {ANOMALY_THRESHOLD_WARNING * 100:.0f}% CI (q20-q80)",
"critical": f"Outside {ANOMALY_THRESHOLD_CRITICAL * 100:.0f}% CI (q10-q90)",
},
"anomalies": anomalies,
"summary": {
"total_points": len(anomalies),
"critical": sum(1 for a in anomalies if a["severity"] == "CRITICAL"),
"warning": sum(1 for a in anomalies if a["severity"] == "WARNING"),
"normal": sum(1 for a in anomalies if a["severity"] == "NORMAL"),
},
}
results_path = OUTPUT_DIR / "anomaly_detection.json"
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
print(f" Saved: {results_path}")
# Print summary
print("\n" + "=" * 60)
print(" ✅ ANOMALY DETECTION COMPLETE")
print("=" * 60)
print(f"\n📊 Summary:")
print(f" Total test points: {results['summary']['total_points']}")
print(f" Critical anomalies: {results['summary']['critical']} (outside 90% CI)")
print(f" Warnings: {results['summary']['warning']} (outside 80% CI)")
print(f" Normal: {results['summary']['normal']}")
print("\n💡 How It Works:")
print(" 1. TimesFM forecasts with quantile intervals (q10, q20, ..., q90)")
print(" 2. If actual value falls outside 90% CI → CRITICAL anomaly")
print(" 3. If actual value falls outside 80% CI → WARNING")
print(" 4. Otherwise → NORMAL")
print("\n📁 Output Files:")
print(f" {output_path}")
print(f" {results_path}")
if __name__ == "__main__":
main()