fix(examples): correct quantile indices, variable shadowing, and test design in anomaly + covariates examples

Anomaly detection fixes:
- Fix critical quantile index bug: index 0 is mean not q10; correct indices are q10=1, q20=2, q80=8, q90=9
- Redesign test: use all 36 months as context, inject 3 synthetic anomalies into future
- Result: 3 CRITICAL detected (was 11/12 — caused by test-set leakage + wrong indices)
- Update severity labels: CRITICAL = outside 80% PI, WARNING = outside 60% PI

Covariates fixes:
- Fix variable-shadowing bug: inner dict comprehension overwrote outer loop store_id
  causing all stores to get identical covariate arrays (store_A's price for everyone)
- Give each store a distinct price baseline (premium $12, standard $10, discount $7.50)
- Trim CONTEXT_LEN from 48 → 24 weeks; CSV now 108 rows (was 180)
- Add NOTE ON REAL DATA comment: temp file pattern for large external datasets

Both scripts regenerated with clean outputs.
This commit is contained in:
Clayton Young
2026-02-21 18:27:45 -05:00
parent 0d98fa353c
commit 509190118f
7 changed files with 612 additions and 697 deletions

View File

@@ -2,14 +2,19 @@
"""
TimesFM Anomaly Detection Example
This example demonstrates how to use TimesFM's quantile forecasts for
anomaly detection. The approach:
1. Forecast with quantile intervals (10th-90th percentiles)
2. Compare actual values against prediction intervals
3. Flag values outside intervals as anomalies
Demonstrates using TimesFM quantile forecasts as prediction intervals
for anomaly detection. Approach:
1. Use 36 months of real data as context
2. Create synthetic 12-month future (natural continuation of trend)
3. Inject 3 clear anomalies into that future
4. Forecast with quantile intervals → flag anomalies by severity
TimesFM does NOT have built-in anomaly detection, but the quantile
forecasts provide natural anomaly detection via prediction intervals.
TimesFM has NO built-in anomaly detection. Quantile forecasts provide
natural prediction intervals — values outside them are statistically unusual.
Quantile index reference (index 0 = mean, 1-9 = q10-q90):
80% PI = q10 (idx 1) to q90 (idx 9)
60% PI = q20 (idx 2) to q80 (idx 8)
"""
from __future__ import annotations
@@ -18,36 +23,51 @@ import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import timesfm
# Configuration
HORIZON = 12 # Forecast horizon
ANOMALY_THRESHOLD_WARNING = 0.80 # Outside 80% CI = warning
ANOMALY_THRESHOLD_CRITICAL = 0.90 # Outside 90% CI = critical
EXAMPLE_DIR = Path(__file__).parent
HORIZON = 12 # Forecast horizon (months)
DATA_FILE = (
Path(__file__).parent.parent / "global-temperature" / "temperature_anomaly.csv"
)
OUTPUT_DIR = EXAMPLE_DIR / "output"
OUTPUT_DIR = Path(__file__).parent / "output"
# Anomaly thresholds using available quantile outputs
# 80% PI = q10-q90 → "critical" if outside
# 60% PI = q20-q80 → "warning" if outside
IDX_Q10, IDX_Q20, IDX_Q80, IDX_Q90 = 1, 2, 8, 9
def inject_anomalies(
values: np.ndarray, n_anomalies: int = 3, seed: int = 42
def build_synthetic_future(
context: np.ndarray, n: int, seed: int = 42
) -> tuple[np.ndarray, list[int]]:
"""Inject synthetic anomalies into the data for demonstration."""
"""Build synthetic future that looks like a natural continuation.
Takes the mean/std of the last 6 context months as the baseline,
then injects 3 clear anomalies (2 high, 1 low) at fixed positions.
"""
rng = np.random.default_rng(seed)
anomaly_indices = rng.choice(len(values), size=n_anomalies, replace=False).tolist()
recent_mean = float(context[-6:].mean())
recent_std = float(context[-6:].std())
anomalous_values = values.copy()
for idx in anomaly_indices:
# Inject spike or dip (±40-60% of value)
multiplier = rng.choice([0.4, 0.6]) * rng.choice([1, -1])
anomalous_values[idx] = values[idx] * (1 + multiplier)
# Natural-looking continuation: small gaussian noise around recent mean
future = recent_mean + rng.normal(0, recent_std * 0.4, n).astype(np.float32)
return anomalous_values, sorted(anomaly_indices)
# Inject 3 unmistakable anomalies
anomaly_cfg = [
(2, +0.55), # month 3 — large spike up
(7, -0.50), # month 8 — large dip down
(10, +0.48), # month 11 — spike up
]
anomaly_indices = []
for idx, delta in anomaly_cfg:
future[idx] = recent_mean + delta
anomaly_indices.append(idx)
return future, sorted(anomaly_indices)
def main() -> None:
@@ -57,27 +77,30 @@ def main() -> None:
OUTPUT_DIR.mkdir(exist_ok=True)
# Load temperature data
print("\n📊 Loading temperature anomaly data...")
# ── Load all 36 months as context ─────────────────────────────
print("\n📊 Loading temperature data (all 36 months as context)...")
df = pd.read_csv(DATA_FILE, parse_dates=["date"])
df = df.sort_values("date").reset_index(drop=True)
context_values = df["anomaly_c"].values.astype(np.float32) # all 36 months
context_dates = df["date"].tolist()
# Split into context (first 24 months) and test (last 12 months)
context_values = df["anomaly_c"].values[:24].astype(np.float32)
actual_future = df["anomaly_c"].values[24:36].astype(np.float32)
dates_future = df["date"].values[24:36]
print(f" Context: 24 months (2022-01 to 2023-12)")
print(f" Test: 12 months (2024-01 to 2024-12)")
# Inject anomalies into test data for demonstration
print("\n🔬 Injecting synthetic anomalies for demonstration...")
test_values_with_anomalies, anomaly_indices = inject_anomalies(
actual_future, n_anomalies=3
print(
f" Context: {len(context_values)} months ({context_dates[0].strftime('%Y-%m')}{context_dates[-1].strftime('%Y-%m')})"
)
print(f" Injected anomalies at months: {anomaly_indices}")
# Load TimesFM
# ── Build synthetic future with known anomalies ────────────────
print("\n🔬 Building synthetic 12-month future with injected anomalies...")
future_values, injected_at = build_synthetic_future(context_values, HORIZON)
future_dates = pd.date_range(
start=context_dates[-1] + pd.DateOffset(months=1),
periods=HORIZON,
freq="MS",
)
print(
f" Anomalies injected at months: {[future_dates[i].strftime('%Y-%m') for i in injected_at]}"
)
# ── Load TimesFM and forecast ──────────────────────────────────
print("\n🤖 Loading TimesFM 1.0 (200M) PyTorch...")
hparams = timesfm.TimesFmHparams(horizon_len=HORIZON)
checkpoint = timesfm.TimesFmCheckpoint(
@@ -85,254 +108,186 @@ def main() -> None:
)
model = timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
# Forecast with quantiles
print("\n📈 Forecasting with quantile intervals...")
point_forecast, quantile_forecast = model.forecast(
[context_values],
freq=[0],
)
print("\n📈 Forecasting...")
point_fc, quant_fc = model.forecast([context_values], freq=[0])
# Extract quantiles
# quantile_forecast shape: (1, 12, 10) - [mean, q10, q20, ..., q90]
point = point_forecast[0]
q10 = quantile_forecast[0, :, 0] # 10th percentile
q20 = quantile_forecast[0, :, 1] # 20th percentile
q50 = quantile_forecast[0, :, 4] # 50th percentile (median)
q80 = quantile_forecast[0, :, 7] # 80th percentile
q90 = quantile_forecast[0, :, 8] # 90th percentile
# quantile_forecast shape: (1, horizon, 10)
# index 0 = mean, index 1 = q10, ..., index 9 = q90
point = point_fc[0] # shape (12,)
q10 = quant_fc[0, :, IDX_Q10] # 10th pct
q20 = quant_fc[0, :, IDX_Q20] # 20th pct
q80 = quant_fc[0, :, IDX_Q80] # 80th pct
q90 = quant_fc[0, :, IDX_Q90] # 90th pct
print(f" Forecast mean: {point.mean():.3f}°C")
print(f" 90% CI width: {(q90 - q10).mean():.3f}°C (avg)")
print(f" 80% PI width: {(q90 - q10).mean():.3f}°C (avg)")
# Detect anomalies
# ── Detect anomalies ───────────────────────────────────────────
print("\n🔍 Detecting anomalies...")
anomalies = []
for i, (actual, lower_80, upper_80, lower_90, upper_90) in enumerate(
zip(test_values_with_anomalies, q20, q80, q10, q90)
records = []
for i, (actual, fcast, lo60, hi60, lo80, hi80) in enumerate(
zip(future_values, point, q20, q80, q10, q90)
):
month = dates_future[i]
month_str = pd.to_datetime(month).strftime("%Y-%m")
month = future_dates[i].strftime("%Y-%m")
if actual < lower_90 or actual > upper_90:
severity = "CRITICAL"
threshold = "90% CI"
color = "red"
elif actual < lower_80 or actual > upper_80:
severity = "WARNING"
threshold = "80% CI"
color = "orange"
if actual < lo80 or actual > hi80:
severity = "CRITICAL" # outside 80% PI
elif actual < lo60 or actual > hi60:
severity = "WARNING" # outside 60% PI
else:
severity = "NORMAL"
threshold = "within bounds"
color = "green"
anomalies.append(
records.append(
{
"month": month_str,
"actual": float(actual),
"forecast": float(point[i]),
"lower_80": float(lower_80),
"upper_80": float(upper_80),
"lower_90": float(lower_90),
"upper_90": float(upper_90),
"month": month,
"actual": round(float(actual), 4),
"forecast": round(float(fcast), 4),
"lower_60pi": round(float(lo60), 4),
"upper_60pi": round(float(hi60), 4),
"lower_80pi": round(float(lo80), 4),
"upper_80pi": round(float(hi80), 4),
"severity": severity,
"threshold": threshold,
"color": color,
"injected": (i in injected_at),
}
)
if severity != "NORMAL":
deviation = abs(actual - point[i])
dev = actual - fcast
print(
f" [{severity}] {month_str}: {actual:.2f}°C (forecast: {point[i]:.2f}°C, deviation: {deviation:.2f}°C)"
f" [{severity}] {month}: actual={actual:.2f} forecast={fcast:.2f} Δ={dev:+.2f}°C"
)
# Create visualization
print("\n📊 Creating anomaly visualization...")
# ── Visualise ─────────────────────────────────────────────────
print("\n📊 Creating visualization...")
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
fig, axes = plt.subplots(2, 1, figsize=(13, 9))
# Plot 1: Full time series with forecast and anomalies
ax1 = axes[0]
clr = {"CRITICAL": "red", "WARNING": "orange", "NORMAL": "steelblue"}
# Historical data
historical_dates = df["date"].values[:24]
ax1.plot(
historical_dates,
# — Panel 1: full series ———————————————————————————————————————
ax = axes[0]
ax.plot(
context_dates,
context_values,
"b-",
linewidth=2,
label="Historical Data",
lw=2,
marker="o",
markersize=4,
ms=4,
label="Context (36 months)",
)
# Actual future (with anomalies)
ax1.plot(
dates_future,
actual_future,
"g--",
linewidth=1.5,
label="Actual (clean)",
ax.fill_between(
future_dates, q10, q90, alpha=0.18, color="tomato", label="80% PI (q10q90)"
)
ax.fill_between(
future_dates, q20, q80, alpha=0.28, color="tomato", label="60% PI (q20q80)"
)
ax.plot(future_dates, point, "r-", lw=2, marker="s", ms=5, label="Forecast")
ax.plot(
future_dates,
future_values,
"k--",
lw=1.3,
alpha=0.5,
)
ax1.plot(
dates_future,
test_values_with_anomalies,
"ko",
markersize=8,
label="Actual (with anomalies)",
alpha=0.7,
label="Synthetic future (clean)",
)
# Forecast
ax1.plot(
dates_future,
point,
"r-",
linewidth=2,
label="Forecast (median)",
marker="s",
markersize=6,
)
# 90% CI
ax1.fill_between(dates_future, q10, q90, alpha=0.2, color="red", label="90% CI")
# 80% CI
ax1.fill_between(dates_future, q20, q80, alpha=0.3, color="red", label="80% CI")
# Highlight anomalies
for anomaly in anomalies:
if anomaly["severity"] != "NORMAL":
idx = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future].index(
anomaly["month"]
)
ax1.scatter(
[dates_future[idx]],
[test_values_with_anomalies[idx]],
c=anomaly["color"],
s=200,
marker="x" if anomaly["severity"] == "CRITICAL" else "^",
linewidths=3,
zorder=5,
# mark anomalies
for rec in records:
if rec["severity"] != "NORMAL":
dt = pd.to_datetime(rec["month"])
c = "red" if rec["severity"] == "CRITICAL" else "orange"
mk = "X" if rec["severity"] == "CRITICAL" else "^"
ax.scatter(
[dt], [rec["actual"]], c=c, s=220, marker=mk, zorder=6, linewidths=2
)
ax1.set_xlabel("Date", fontsize=12)
ax1.set_ylabel("Temperature Anomaly (°C)", fontsize=12)
ax1.set_title(
"TimesFM Anomaly Detection: Forecast Intervals Method",
fontsize=14,
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right")
ax.set_ylabel("Temperature Anomaly (°C)", fontsize=11)
ax.set_title(
"TimesFM Anomaly Detection — Prediction Interval Method",
fontsize=13,
fontweight="bold",
)
ax1.legend(loc="upper left", fontsize=10)
ax1.grid(True, alpha=0.3)
# Add annotation for anomalies
ax1.annotate(
"× = Critical (outside 90% CI)\n▲ = Warning (outside 80% CI)",
xy=(0.98, 0.02),
ax.legend(loc="upper left", fontsize=9, ncol=2)
ax.grid(True, alpha=0.25)
ax.annotate(
"X = Critical (outside 80% PI)\n▲ = Warning (outside 60% PI)",
xy=(0.98, 0.04),
xycoords="axes fraction",
ha="right",
va="bottom",
fontsize=10,
fontsize=9,
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
)
# Plot 2: Deviation from forecast with thresholds
# — Panel 2: deviation bars ———————————————————————————————————
ax2 = axes[1]
deviations = future_values - point
lo80_dev = q10 - point
hi80_dev = q90 - point
lo60_dev = q20 - point
hi60_dev = q80 - point
x = np.arange(HORIZON)
deviation = test_values_with_anomalies - point
lower_90_dev = q10 - point
upper_90_dev = q90 - point
lower_80_dev = q20 - point
upper_80_dev = q80 - point
ax2.fill_between(x, lo80_dev, hi80_dev, alpha=0.15, color="tomato", label="80% PI")
ax2.fill_between(x, lo60_dev, hi60_dev, alpha=0.25, color="tomato", label="60% PI")
bar_colors = [clr[r["severity"]] for r in records]
ax2.bar(x, deviations, color=bar_colors, alpha=0.75, edgecolor="black", lw=0.5)
ax2.axhline(0, color="black", lw=1)
months = [pd.to_datetime(d).strftime("%Y-%m") for d in dates_future]
x = np.arange(len(months))
# Threshold bands
ax2.fill_between(
x, lower_90_dev, upper_90_dev, alpha=0.2, color="red", label="90% CI bounds"
ax2.set_xticks(x)
ax2.set_xticklabels(
[r["month"] for r in records], rotation=45, ha="right", fontsize=9
)
ax2.fill_between(
x, lower_80_dev, upper_80_dev, alpha=0.3, color="red", label="80% CI bounds"
)
# Deviation bars
colors = [
"red"
if d < lower_90_dev[i] or d > upper_90_dev[i]
else "orange"
if d < lower_80_dev[i] or d > upper_80_dev[i]
else "green"
for i, d in enumerate(deviation)
]
ax2.bar(x, deviation, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5)
# Zero line
ax2.axhline(y=0, color="black", linestyle="-", linewidth=1)
ax2.set_xlabel("Month", fontsize=12)
ax2.set_ylabel("Deviation from Forecast (°C)", fontsize=12)
ax2.set_ylabel("Δ from Forecast (°C)", fontsize=11)
ax2.set_title(
"Deviation from Forecast with Anomaly Thresholds",
fontsize=14,
fontsize=13,
fontweight="bold",
)
ax2.set_xticks(x)
ax2.set_xticklabels(months, rotation=45, ha="right")
ax2.legend(loc="upper right", fontsize=10)
ax2.grid(True, alpha=0.3, axis="y")
ax2.legend(loc="upper right", fontsize=9)
ax2.grid(True, alpha=0.25, axis="y")
plt.tight_layout()
output_path = OUTPUT_DIR / "anomaly_detection.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
print(f" Saved: {output_path}")
png_path = OUTPUT_DIR / "anomaly_detection.png"
plt.savefig(png_path, dpi=150, bbox_inches="tight")
plt.close()
print(f" Saved: {png_path}")
# Save results
results = {
"method": "quantile_intervals",
"description": "Anomaly detection using TimesFM quantile forecasts as prediction intervals",
"thresholds": {
"warning": f"Outside {ANOMALY_THRESHOLD_WARNING * 100:.0f}% CI (q20-q80)",
"critical": f"Outside {ANOMALY_THRESHOLD_CRITICAL * 100:.0f}% CI (q10-q90)",
},
"anomalies": anomalies,
"summary": {
"total_points": len(anomalies),
"critical": sum(1 for a in anomalies if a["severity"] == "CRITICAL"),
"warning": sum(1 for a in anomalies if a["severity"] == "WARNING"),
"normal": sum(1 for a in anomalies if a["severity"] == "NORMAL"),
},
# ── Save JSON results ──────────────────────────────────────────
summary = {
"total": len(records),
"critical": sum(1 for r in records if r["severity"] == "CRITICAL"),
"warning": sum(1 for r in records if r["severity"] == "WARNING"),
"normal": sum(1 for r in records if r["severity"] == "NORMAL"),
}
out = {
"method": "quantile_prediction_intervals",
"description": (
"Anomaly detection via TimesFM quantile forecasts. "
"80% PI = q10q90 (CRITICAL if violated). "
"60% PI = q20q80 (WARNING if violated)."
),
"context": "36 months of real NOAA temperature anomaly data (2022-2024)",
"future": "12 synthetic months with 3 injected anomalies",
"quantile_indices": {"q10": 1, "q20": 2, "q80": 8, "q90": 9},
"summary": summary,
"detections": records,
}
json_path = OUTPUT_DIR / "anomaly_detection.json"
with open(json_path, "w") as f:
json.dump(out, f, indent=2)
print(f" Saved: {json_path}")
results_path = OUTPUT_DIR / "anomaly_detection.json"
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
print(f" Saved: {results_path}")
# Print summary
# ── Summary ────────────────────────────────────────────────────
print("\n" + "=" * 60)
print(" ✅ ANOMALY DETECTION COMPLETE")
print("=" * 60)
print(f"\n📊 Summary:")
print(f" Total test points: {results['summary']['total_points']}")
print(f" Critical anomalies: {results['summary']['critical']} (outside 90% CI)")
print(f" Warnings: {results['summary']['warning']} (outside 80% CI)")
print(f" Normal: {results['summary']['normal']}")
print("\n💡 How It Works:")
print(" 1. TimesFM forecasts with quantile intervals (q10, q20, ..., q90)")
print(" 2. If actual value falls outside 90% CI → CRITICAL anomaly")
print(" 3. If actual value falls outside 80% CI → WARNING")
print(" 4. Otherwise → NORMAL")
print("\n📁 Output Files:")
print(f" {output_path}")
print(f" {results_path}")
print(f"\n Total future points : {summary['total']}")
print(f" Critical (80% PI) : {summary['critical']}")
print(f" Warning (60% PI) : {summary['warning']}")
print(f" Normal : {summary['normal']}")
if __name__ == "__main__":