mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-27 07:09:27 +08:00
feat(skill): add timesfm-forecasting skill for time series forecasting
Add comprehensive TimesFM forecasting skill with mandatory system preflight checks (RAM/GPU/disk), end-to-end CSV forecasting script, full API reference, data preparation guide, and hardware requirements documentation. Supports TimesFM 2.5 (200M), 2.0 (500M), and legacy v1.0 with automatic batch size recommendations based on hardware.
This commit is contained in:
520
scientific-skills/timesfm-forecasting/scripts/check_system.py
Normal file
520
scientific-skills/timesfm-forecasting/scripts/check_system.py
Normal file
@@ -0,0 +1,520 @@
|
||||
#!/usr/bin/env python3
|
||||
"""TimesFM System Requirements Preflight Checker.
|
||||
|
||||
MANDATORY: Run this script before loading TimesFM for the first time.
|
||||
It checks RAM, GPU/VRAM, disk space, Python version, and package
|
||||
installation so the agent never crashes a user's machine.
|
||||
|
||||
Usage:
|
||||
python check_system.py
|
||||
python check_system.py --model v2.5 # default
|
||||
python check_system.py --model v2.0 # archived 500M model
|
||||
python check_system.py --model v1.0 # archived 200M model
|
||||
python check_system.py --json # machine-readable output
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import struct
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model requirement profiles
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Hardware/disk requirement profiles, keyed by TimesFM release tag.
# "min_*" values are hard floors (the preflight fails below them), while
# "recommended_*" values give comfortable headroom; "disk_gb" covers the
# Hugging Face download plus cache overhead.  "hf_repo" is the repository
# the weights are fetched from.
MODEL_PROFILES: dict[str, dict[str, Any]] = {
    "v2.5": {
        "name": "TimesFM 2.5 (200M)",
        "params": "200M",
        "min_ram_gb": 2.0,
        "recommended_ram_gb": 4.0,
        "min_vram_gb": 2.0,
        "recommended_vram_gb": 4.0,
        "disk_gb": 2.0,  # model weights + overhead
        "hf_repo": "google/timesfm-2.5-200m-pytorch",
    },
    "v2.0": {
        "name": "TimesFM 2.0 (500M)",
        "params": "500M",
        "min_ram_gb": 8.0,
        "recommended_ram_gb": 16.0,
        "min_vram_gb": 4.0,
        "recommended_vram_gb": 8.0,
        "disk_gb": 4.0,
        "hf_repo": "google/timesfm-2.0-500m-pytorch",
    },
    "v1.0": {
        "name": "TimesFM 1.0 (200M)",
        "params": "200M",
        "min_ram_gb": 4.0,
        "recommended_ram_gb": 8.0,
        "min_vram_gb": 2.0,
        "recommended_vram_gb": 4.0,
        "disk_gb": 2.0,
        "hf_repo": "google/timesfm-1.0-200m-pytorch",
    },
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class CheckResult:
    """Outcome of one preflight check (RAM, GPU, disk, Python, package)."""

    name: str
    status: str  # one of "pass", "warn", "fail"
    detail: str
    value: str = ""

    @property
    def icon(self) -> str:
        """Emoji marker for the status; "❓" for anything unrecognized."""
        icons = {"pass": "✅", "warn": "⚠️", "fail": "🛑"}
        return icons.get(self.status, "❓")

    def __str__(self) -> str:
        """One aligned report line, e.g. "[RAM       ] 8.0 GB ... ✅ PASS"."""
        return f"[{self.name:<10}] {self.value:<40} {self.icon} {self.status.upper()}"
|
||||
|
||||
|
||||
@dataclass
class SystemReport:
    """Aggregate result of all preflight checks for one model profile."""

    model: str
    checks: list[CheckResult] = field(default_factory=list)
    verdict: str = ""
    verdict_detail: str = ""
    recommended_batch_size: int = 1
    mode: str = "cpu"  # execution mode: "cpu", "gpu", or "mps"

    @property
    def passed(self) -> bool:
        """True unless at least one check hard-failed (warnings still pass)."""
        return not any(c.status == "fail" for c in self.checks)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the report for --json (machine-readable) output."""
        check_dicts = [
            {
                "name": c.name,
                "status": c.status,
                "detail": c.detail,
                "value": c.value,
            }
            for c in self.checks
        ]
        return {
            "model": self.model,
            "passed": self.passed,
            "mode": self.mode,
            "recommended_batch_size": self.recommended_batch_size,
            "verdict": self.verdict,
            "verdict_detail": self.verdict_detail,
            "checks": check_dicts,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Individual checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_total_ram_gb() -> float:
|
||||
"""Return total physical RAM in GB, cross-platform."""
|
||||
try:
|
||||
if sys.platform == "linux":
|
||||
with open("/proc/meminfo") as f:
|
||||
for line in f:
|
||||
if line.startswith("MemTotal"):
|
||||
return int(line.split()[1]) / (1024 * 1024)
|
||||
elif sys.platform == "darwin":
|
||||
import subprocess
|
||||
|
||||
result = subprocess.run(
|
||||
["sysctl", "-n", "hw.memsize"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return int(result.stdout.strip()) / (1024**3)
|
||||
elif sys.platform == "win32":
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
|
||||
|
||||
class MEMORYSTATUSEX(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("dwLength", ctypes.c_ulong),
|
||||
("dwMemoryLoad", ctypes.c_ulong),
|
||||
("ullTotalPhys", ctypes.c_ulonglong),
|
||||
("ullAvailPhys", ctypes.c_ulonglong),
|
||||
("ullTotalPageFile", ctypes.c_ulonglong),
|
||||
("ullAvailPageFile", ctypes.c_ulonglong),
|
||||
("ullTotalVirtual", ctypes.c_ulonglong),
|
||||
("ullAvailVirtual", ctypes.c_ulonglong),
|
||||
("sullAvailExtendedVirtual", ctypes.c_ulonglong),
|
||||
]
|
||||
|
||||
stat = MEMORYSTATUSEX()
|
||||
stat.dwLength = ctypes.sizeof(stat)
|
||||
kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))
|
||||
return stat.ullTotalPhys / (1024**3)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: use struct to estimate (unreliable)
|
||||
return struct.calcsize("P") * 8 / 8 # placeholder
|
||||
|
||||
|
||||
def _get_available_ram_gb() -> float:
|
||||
"""Return available RAM in GB."""
|
||||
try:
|
||||
if sys.platform == "linux":
|
||||
with open("/proc/meminfo") as f:
|
||||
for line in f:
|
||||
if line.startswith("MemAvailable"):
|
||||
return int(line.split()[1]) / (1024 * 1024)
|
||||
elif sys.platform == "darwin":
|
||||
import subprocess
|
||||
|
||||
# Use vm_stat for available memory on macOS
|
||||
result = subprocess.run(
|
||||
["vm_stat"], capture_output=True, text=True, check=True
|
||||
)
|
||||
free = 0
|
||||
page_size = 4096
|
||||
for line in result.stdout.split("\n"):
|
||||
if "Pages free" in line or "Pages inactive" in line:
|
||||
val = line.split(":")[1].strip().rstrip(".")
|
||||
free += int(val) * page_size
|
||||
return free / (1024**3)
|
||||
elif sys.platform == "win32":
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
|
||||
|
||||
class MEMORYSTATUSEX(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("dwLength", ctypes.c_ulong),
|
||||
("dwMemoryLoad", ctypes.c_ulong),
|
||||
("ullTotalPhys", ctypes.c_ulonglong),
|
||||
("ullAvailPhys", ctypes.c_ulonglong),
|
||||
("ullTotalPageFile", ctypes.c_ulonglong),
|
||||
("ullAvailPageFile", ctypes.c_ulonglong),
|
||||
("ullTotalVirtual", ctypes.c_ulonglong),
|
||||
("ullAvailVirtual", ctypes.c_ulonglong),
|
||||
("sullAvailExtendedVirtual", ctypes.c_ulonglong),
|
||||
]
|
||||
|
||||
stat = MEMORYSTATUSEX()
|
||||
stat.dwLength = ctypes.sizeof(stat)
|
||||
kernel32.GlobalMemoryStatusEx(ctypes.byref(stat))
|
||||
return stat.ullAvailPhys / (1024**3)
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
def check_ram(profile: dict[str, Any]) -> CheckResult:
    """Check total system RAM against the profile's floor and recommendation.

    Fails below ``min_ram_gb``, warns below ``recommended_ram_gb``, and warns
    (rather than failing) when RAM cannot be measured at all.
    """
    total = _get_total_ram_gb()
    available = _get_available_ram_gb()
    min_ram = profile["min_ram_gb"]
    rec_ram = profile["recommended_ram_gb"]

    value = f"Total: {total:.1f} GB | Available: {available:.1f} GB"

    if total <= 0:
        # Detection failed; don't hard-block the user on a missing measurement.
        return CheckResult(
            name="RAM",
            status="warn",
            detail=(
                f"Could not determine system RAM. {profile['name']} requires "
                f"at least {min_ram:.0f} GB; verify manually before loading "
                f"the model."
            ),
            value=value,
        )
    if total < min_ram:
        return CheckResult(
            name="RAM",
            status="fail",
            detail=(
                f"System has {total:.1f} GB RAM but {profile['name']} requires "
                f"at least {min_ram:.0f} GB. The model will likely fail to load "
                f"or cause the system to swap heavily and become unresponsive."
            ),
            value=value,
        )
    elif total < rec_ram:
        return CheckResult(
            name="RAM",
            status="warn",
            detail=(
                f"System has {total:.1f} GB RAM. {profile['name']} recommends "
                f"{rec_ram:.0f} GB. It may work with small batch sizes but could "
                f"be tight. Use per_core_batch_size=4 or lower."
            ),
            value=value,
        )
    else:
        return CheckResult(
            name="RAM",
            status="pass",
            detail=f"System has {total:.1f} GB RAM, meets {rec_ram:.0f} GB recommendation.",
            value=value,
        )
|
||||
|
||||
|
||||
def check_gpu() -> CheckResult:
    """Detect CUDA or Apple-Silicon (MPS) acceleration via PyTorch."""
    try:
        import torch
    except ImportError:
        # Without torch we cannot probe hardware at all.
        return CheckResult(
            name="GPU",
            status="warn",
            detail="PyTorch not installed — cannot check GPU. Install torch first.",
            value="Unknown (torch not installed)",
        )

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return CheckResult(
            name="GPU",
            status="pass",
            detail=f"{gpu_name} with {vram_gb:.1f} GB VRAM detected.",
            value=f"{gpu_name} | VRAM: {vram_gb:.1f} GB",
        )

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return CheckResult(
            name="GPU",
            status="pass",
            detail="Apple Silicon MPS backend available. Uses unified memory.",
            value="Apple Silicon MPS",
        )

    return CheckResult(
        name="GPU",
        status="warn",
        detail=(
            "No GPU detected. TimesFM will run on CPU (slower but functional). "
            "Install CUDA-enabled PyTorch for GPU acceleration."
        ),
        value="None (CPU only)",
    )
|
||||
|
||||
|
||||
def check_disk(profile: dict[str, Any]) -> CheckResult:
    """Check free disk space where Hugging Face will cache the model weights."""
    # Prefer the HF cache directory; fall back to the home directory when the
    # cache has not been created yet.
    hf_cache = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    cache_dir = Path(hf_cache)
    check_dir = cache_dir if cache_dir.exists() else Path.home()

    free_gb = shutil.disk_usage(str(check_dir)).free / (1024**3)
    required = profile["disk_gb"]

    value = f"Free: {free_gb:.1f} GB (in {check_dir})"

    if free_gb >= required:
        return CheckResult(
            name="Disk",
            status="pass",
            detail=f"{free_gb:.1f} GB available, exceeds {required:.0f} GB requirement.",
            value=value,
        )
    return CheckResult(
        name="Disk",
        status="fail",
        detail=(
            f"Only {free_gb:.1f} GB free in {check_dir}. "
            f"Need at least {required:.0f} GB for model weights. "
            f"Free up space or set HF_HOME to a larger volume."
        ),
        value=value,
    )
|
||||
|
||||
|
||||
def check_python() -> CheckResult:
    """Verify the running interpreter is Python 3.10 or newer."""
    version = sys.version.split()[0]

    if sys.version_info[:2] >= (3, 10):
        return CheckResult(
            name="Python",
            status="pass",
            detail=f"Python {version} meets >= 3.10 requirement.",
            value=version,
        )
    return CheckResult(
        name="Python",
        status="fail",
        detail=f"Python {version} detected. TimesFM requires Python >= 3.10.",
        value=version,
    )
|
||||
|
||||
|
||||
def check_package(pkg_name: str, import_name: str | None = None) -> CheckResult:
    """Report whether *pkg_name* is importable.

    *import_name* covers packages whose import name differs from their
    distribution name; it defaults to *pkg_name*.
    """
    module_name = import_name or pkg_name
    try:
        module = __import__(module_name)
    except ImportError:
        return CheckResult(
            name=pkg_name,
            status="warn",
            detail=f"{pkg_name} is not installed. Run: uv pip install {pkg_name}",
            value="Not installed",
        )

    version = getattr(module, "__version__", "unknown")
    return CheckResult(
        name=pkg_name,
        status="pass",
        detail=f"{pkg_name} {version} is installed.",
        value=f"Installed ({version})",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch size recommendation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def recommend_batch_size(report: SystemReport) -> int:
    """Recommend per_core_batch_size from the detected VRAM / RAM tier."""
    total_ram = _get_total_ram_gb()

    gpu_check = next((c for c in report.checks if c.name == "GPU"), None)

    if gpu_check and gpu_check.status == "pass" and "VRAM" in gpu_check.value:
        # CUDA GPU: parse the VRAM figure back out of the check's value string.
        try:
            vram = float(gpu_check.value.split("VRAM:")[1].strip().split()[0])
        except (ValueError, IndexError):
            return 32
        for threshold, batch in ((24, 256), (16, 128), (8, 64), (4, 32)):
            if vram >= threshold:
                return batch
        return 16

    if gpu_check and "MPS" in gpu_check.value:
        # Apple Silicon shares system RAM with the GPU (unified memory).
        if total_ram >= 32:
            return 64
        if total_ram >= 16:
            return 32
        return 16

    # CPU-only: scale conservatively with system RAM.
    if total_ram >= 32:
        return 64
    if total_ram >= 16:
        return 32
    if total_ram >= 8:
        return 8
    return 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_checks(model_version: str = "v2.5") -> SystemReport:
    """Run every preflight check for *model_version* and build the report."""
    profile = MODEL_PROFILES[model_version]
    report = SystemReport(model=profile["name"])

    report.checks = [
        check_ram(profile),
        check_gpu(),
        check_disk(profile),
        check_python(),
        check_package("timesfm"),
        check_package("torch"),
    ]

    # Execution mode follows the GPU check outcome.
    gpu_check = next((c for c in report.checks if c.name == "GPU"), None)
    if gpu_check is not None and gpu_check.status == "pass":
        report.mode = "mps" if "MPS" in gpu_check.value else "gpu"
    else:
        report.mode = "cpu"

    report.recommended_batch_size = recommend_batch_size(report)

    # Overall verdict: any hard failure blocks the model.
    if report.passed:
        report.verdict = (
            f"✅ System is ready for {profile['name']} ({report.mode.upper()} mode)"
        )
        report.verdict_detail = (
            f"Recommended: per_core_batch_size={report.recommended_batch_size}"
        )
    else:
        failures = [c for c in report.checks if c.status == "fail"]
        report.verdict = f"🛑 System does NOT meet requirements for {profile['name']}"
        report.verdict_detail = "; ".join(c.detail for c in failures)

    return report
|
||||
|
||||
|
||||
def print_report(report: SystemReport) -> None:
    """Print a human-readable report to stdout.

    Emits a header with the model name, one line per check, and the final
    verdict (with detail when present).
    """
    divider = "=" * 50
    print(f"\n{divider}")
    # Plain strings here: the originals were f-strings with no placeholders.
    print(" TimesFM System Requirements Check")
    print(f" Model: {report.model}")
    print(f"{divider}\n")

    for check in report.checks:
        print(f" {check}")
    print()

    print(f" VERDICT: {report.verdict}")
    if report.verdict_detail:
        print(f" {report.verdict_detail}")
    print()
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse args, run checks, report, set the exit code."""
    parser = argparse.ArgumentParser(
        description="Check system requirements for TimesFM."
    )
    parser.add_argument(
        "--model",
        choices=list(MODEL_PROFILES.keys()),
        default="v2.5",
        help="Model version to check requirements for (default: v2.5)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output results as JSON (machine-readable)",
    )
    args = parser.parse_args()

    report = run_checks(args.model)

    if args.json:
        print(json.dumps(report.to_dict(), indent=2))
    else:
        print_report(report)

    # Non-zero exit signals a failed requirement to calling scripts/agents.
    sys.exit(0 if report.passed else 1)


if __name__ == "__main__":
    main()
|
||||
269
scientific-skills/timesfm-forecasting/scripts/forecast_csv.py
Normal file
269
scientific-skills/timesfm-forecasting/scripts/forecast_csv.py
Normal file
@@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python3
|
||||
"""End-to-end CSV forecasting with TimesFM.
|
||||
|
||||
Loads a CSV, runs the system preflight check, loads TimesFM, forecasts
|
||||
the requested columns, and writes results to a new CSV or JSON.
|
||||
|
||||
Usage:
|
||||
python forecast_csv.py input.csv --horizon 24
|
||||
python forecast_csv.py input.csv --horizon 12 --date-col date --value-cols sales,revenue
|
||||
python forecast_csv.py input.csv --horizon 52 --output forecasts.csv
|
||||
python forecast_csv.py input.csv --horizon 30 --output forecasts.json --format json
|
||||
|
||||
The script automatically:
|
||||
1. Runs the system preflight check (exits if it fails).
|
||||
2. Loads TimesFM 2.5 from Hugging Face.
|
||||
3. Reads the CSV and identifies time series columns.
|
||||
4. Forecasts each series with prediction intervals.
|
||||
5. Writes results to the specified output file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def run_preflight() -> dict:
    """Run the system preflight check; exit(1) when the machine fails it."""
    # check_system.py lives alongside this script; make it importable.
    script_dir = Path(__file__).parent
    sys.path.insert(0, str(script_dir))
    from check_system import run_checks

    report = run_checks("v2.5")
    if report.passed:
        return report.to_dict()

    print("\n🛑 System check FAILED. Cannot proceed with forecasting.")
    print(f" {report.verdict_detail}")
    print("\nRun 'python scripts/check_system.py' for details.")
    sys.exit(1)
|
||||
|
||||
|
||||
def load_model(batch_size: int = 32):
    """Download TimesFM 2.5 from Hugging Face and compile it for inference.

    *batch_size* becomes the compiled per_core_batch_size.
    """
    import torch
    import timesfm

    torch.set_float32_matmul_precision("high")

    print("Loading TimesFM 2.5 from Hugging Face...")
    model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
        "google/timesfm-2.5-200m-pytorch"
    )

    print(f"Compiling with per_core_batch_size={batch_size}...")
    config = timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
        per_core_batch_size=batch_size,
    )
    model.compile(config)

    return model
|
||||
|
||||
|
||||
def load_csv(
    path: str,
    date_col: str | None = None,
    value_cols: list[str] | None = None,
) -> tuple[pd.DataFrame, list[str], str | None]:
    """Load a CSV and identify the time series columns to forecast.

    Args:
        path: Input CSV file path.
        date_col: Optional name of the timestamp column (parsed to datetime;
            dropped with a warning when absent).
        value_cols: Optional explicit series columns; defaults to every
            numeric column except *date_col*.

    Returns:
        (dataframe, value_column_names, date_column_name_or_none)
    """
    df = pd.read_csv(path)

    # Resolve the date column, discarding it when it doesn't exist.
    if date_col:
        if date_col in df.columns:
            df[date_col] = pd.to_datetime(df[date_col])
        else:
            print(f"⚠️ Date column '{date_col}' not found. Available: {list(df.columns)}")
            date_col = None

    # Resolve the value columns: explicit list (filtered) or numeric auto-detect.
    if value_cols:
        missing = [c for c in value_cols if c not in df.columns]
        if missing:
            print(f"⚠️ Columns not found: {missing}. Available: {list(df.columns)}")
        value_cols = [c for c in value_cols if c in df.columns]
    else:
        candidates = df.select_dtypes(include=[np.number]).columns.tolist()
        if date_col in candidates:
            candidates.remove(date_col)
        value_cols = candidates

    if not value_cols:
        print("🛑 No numeric columns found to forecast.")
        sys.exit(1)

    print(f"Found {len(value_cols)} series to forecast: {value_cols}")
    return df, value_cols, date_col
|
||||
|
||||
|
||||
def forecast_series(
    model, df: pd.DataFrame, value_cols: list[str], horizon: int
) -> dict[str, dict]:
    """Forecast every column in *value_cols* and collect per-series results.

    Returns a dict mapping column name to point forecast plus quantile bands.
    """
    # NaNs are dropped per-series, so individual histories may differ in length.
    inputs = [df[col].dropna().values.astype(np.float32) for col in value_cols]

    print(f"Forecasting {len(inputs)} series with horizon={horizon}...")
    point, quantiles = model.forecast(horizon=horizon, inputs=inputs)

    # NOTE(review): the quantile axis is assumed to hold deciles at indices
    # 1-9 (index 1 = 10th percentile ... index 9 = 90th) — confirm against
    # the TimesFM quantile-head configuration.
    results = {}
    for idx, col in enumerate(value_cols):
        results[col] = {
            "forecast": point[idx].tolist(),
            "lower_90": quantiles[idx, :, 1].tolist(),  # 10th percentile
            "lower_80": quantiles[idx, :, 2].tolist(),  # 20th percentile
            "median": quantiles[idx, :, 5].tolist(),  # 50th percentile
            "upper_80": quantiles[idx, :, 8].tolist(),  # 80th percentile
            "upper_90": quantiles[idx, :, 9].tolist(),  # 90th percentile
        }

    return results
|
||||
|
||||
|
||||
def write_csv_output(
    results: dict[str, dict],
    output_path: str,
    df: pd.DataFrame,
    date_col: str | None,
    horizon: int,
) -> None:
    """Write forecast results to CSV in long format (one row per series/step).

    When *date_col* is present and its frequency can be inferred, a "date"
    column with extrapolated future timestamps is included; otherwise only
    the integer "step" column identifies the horizon position.
    """
    # The future dates depend only on the input frame, so compute them once
    # up front (the original recomputed them inside the per-series loop).
    future_dates: list = list(range(1, horizon + 1))
    if date_col and date_col in df.columns:
        try:
            dates = df[date_col].dropna()
            freq = pd.infer_freq(dates)
            if freq:
                future_dates = pd.date_range(
                    dates.iloc[-1], periods=horizon + 1, freq=freq
                )[1:].tolist()
        except Exception:
            # Frequency inference failed; fall back to integer steps.
            pass

    # Decide once (not per row) whether real timestamps are available; the
    # emptiness guard also protects the horizon == 0 case.
    have_dates = bool(future_dates) and isinstance(future_dates[0], pd.Timestamp)

    rows = []
    for col, data in results.items():
        for h in range(horizon):
            row = {
                "series": col,
                "step": h + 1,
                "forecast": data["forecast"][h],
                "lower_90": data["lower_90"][h],
                "lower_80": data["lower_80"][h],
                "median": data["median"][h],
                "upper_80": data["upper_80"][h],
                "upper_90": data["upper_90"][h],
            }
            if have_dates:
                row["date"] = future_dates[h]
            rows.append(row)

    out_df = pd.DataFrame(rows)
    out_df.to_csv(output_path, index=False)
    print(f"✅ Wrote {len(rows)} forecast rows to {output_path}")
|
||||
|
||||
|
||||
def write_json_output(results: dict[str, dict], output_path: str) -> None:
    """Dump forecast results to *output_path* as pretty-printed JSON."""
    serialized = json.dumps(results, indent=2)
    with open(output_path, "w") as f:
        f.write(serialized)
    print(f"✅ Wrote forecasts for {len(results)} series to {output_path}")
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: preflight, load model, read CSV, forecast, write output."""
    parser = argparse.ArgumentParser(
        description="Forecast time series from CSV using TimesFM."
    )
    parser.add_argument("input", help="Path to input CSV file")
    parser.add_argument(
        "--horizon", type=int, required=True, help="Number of steps to forecast"
    )
    parser.add_argument("--date-col", help="Name of the date/time column")
    parser.add_argument(
        "--value-cols",
        help="Comma-separated list of value columns to forecast (default: all numeric)",
    )
    parser.add_argument(
        "--output",
        default="forecasts.csv",
        help="Output file path (default: forecasts.csv)",
    )
    parser.add_argument(
        "--format",
        choices=["csv", "json"],
        default=None,
        help="Output format (inferred from --output extension if not set)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Override per_core_batch_size (auto-detected from system check if omitted)",
    )
    parser.add_argument(
        "--skip-check",
        action="store_true",
        help="Skip system preflight check (not recommended)",
    )
    args = parser.parse_args()

    # Explicit column list, or None to auto-detect numeric columns later.
    value_cols = (
        [c.strip() for c in args.value_cols.split(",")] if args.value_cols else None
    )

    # Output format: explicit flag wins, otherwise inferred from the extension.
    out_format = args.format or ("json" if args.output.endswith(".json") else "csv")

    # 1. Preflight check (unless explicitly skipped).
    if args.skip_check:
        print("⚠️ Skipping system check (--skip-check). Proceed with caution.")
        batch_size = args.batch_size or 32
    else:
        print("Running system preflight check...")
        report = run_preflight()
        batch_size = args.batch_size or report.get("recommended_batch_size", 32)

    # 2. Load model
    model = load_model(batch_size=batch_size)

    # 3. Load CSV
    df, cols, date_col = load_csv(args.input, args.date_col, value_cols)

    # 4. Forecast
    results = forecast_series(model, df, cols, args.horizon)

    # 5. Write output
    if out_format == "json":
        write_json_output(results, args.output)
    else:
        write_csv_output(results, args.output, df, date_col, args.horizon)

    print("\nDone! 🎉")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user