mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-28 07:33:45 +08:00
Improve Pytorch Lightning skill
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Helper script to quickly set up a PyTorch Lightning Trainer with common configurations.
|
||||
Quick Trainer Setup Examples for PyTorch Lightning.
|
||||
|
||||
This script provides preset configurations for different training scenarios
|
||||
and makes it easy to create a Trainer with best practices.
|
||||
This script provides ready-to-use Trainer configurations for common use cases.
|
||||
Copy and modify these configurations for your specific needs.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
@@ -10,253 +10,445 @@ from lightning.pytorch.callbacks import (
|
||||
ModelCheckpoint,
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
DeviceStatsMonitor,
|
||||
RichProgressBar,
|
||||
ModelSummary,
|
||||
)
|
||||
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
|
||||
from lightning.pytorch import loggers as pl_loggers
|
||||
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
|
||||
|
||||
|
||||
def create_trainer(
|
||||
preset: str = "default",
|
||||
max_epochs: int = 100,
|
||||
accelerator: str = "auto",
|
||||
devices: int = 1,
|
||||
log_dir: str = "./logs",
|
||||
experiment_name: str = "lightning_experiment",
|
||||
enable_checkpointing: bool = True,
|
||||
enable_early_stopping: bool = True,
|
||||
**kwargs
|
||||
# =============================================================================
|
||||
# 1. BASIC TRAINING (Single GPU/CPU)
|
||||
# =============================================================================
|
||||
|
||||
def basic_trainer():
|
||||
"""
|
||||
Simple trainer for quick prototyping.
|
||||
Use for: Small models, debugging, single GPU training
|
||||
"""
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
accelerator="auto", # Automatically select GPU/CPU
|
||||
devices="auto", # Use all available devices
|
||||
enable_progress_bar=True,
|
||||
logger=True,
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. DEBUGGING CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
def debug_trainer():
|
||||
"""
|
||||
Trainer for debugging with fast dev run and anomaly detection.
|
||||
Use for: Finding bugs, testing code quickly
|
||||
"""
|
||||
trainer = L.Trainer(
|
||||
fast_dev_run=True, # Run 1 batch through train/val/test
|
||||
accelerator="cpu", # Use CPU for easier debugging
|
||||
detect_anomaly=True, # Detect NaN/Inf in gradients
|
||||
log_every_n_steps=1, # Log every step
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. PRODUCTION TRAINING (Single GPU)
|
||||
# =============================================================================
|
||||
|
||||
def production_single_gpu_trainer(
|
||||
max_epochs=100,
|
||||
log_dir="logs",
|
||||
checkpoint_dir="checkpoints"
|
||||
):
|
||||
"""
|
||||
Create a Lightning Trainer with preset configurations.
|
||||
|
||||
Args:
|
||||
preset: Configuration preset - "default", "fast_dev", "production", "distributed"
|
||||
max_epochs: Maximum number of training epochs
|
||||
accelerator: Device to use ("auto", "gpu", "cpu", "tpu")
|
||||
devices: Number of devices to use
|
||||
log_dir: Directory for logs and checkpoints
|
||||
experiment_name: Name for the experiment
|
||||
enable_checkpointing: Whether to enable model checkpointing
|
||||
enable_early_stopping: Whether to enable early stopping
|
||||
**kwargs: Additional arguments to pass to Trainer
|
||||
|
||||
Returns:
|
||||
Configured Lightning Trainer instance
|
||||
Production-ready trainer for single GPU with checkpointing and logging.
|
||||
Use for: Final training runs on single GPU
|
||||
"""
|
||||
|
||||
callbacks = []
|
||||
logger_list = []
|
||||
|
||||
# Configure based on preset
|
||||
if preset == "fast_dev":
|
||||
# Fast development run - minimal epochs, quick debugging
|
||||
config = {
|
||||
"fast_dev_run": False,
|
||||
"max_epochs": 3,
|
||||
"limit_train_batches": 100,
|
||||
"limit_val_batches": 50,
|
||||
"log_every_n_steps": 10,
|
||||
"enable_progress_bar": True,
|
||||
"enable_model_summary": True,
|
||||
}
|
||||
|
||||
elif preset == "production":
|
||||
# Production-ready configuration with all bells and whistles
|
||||
config = {
|
||||
"max_epochs": max_epochs,
|
||||
"precision": "16-mixed",
|
||||
"gradient_clip_val": 1.0,
|
||||
"log_every_n_steps": 50,
|
||||
"enable_progress_bar": True,
|
||||
"enable_model_summary": True,
|
||||
"deterministic": True,
|
||||
"benchmark": True,
|
||||
}
|
||||
|
||||
# Add model checkpointing
|
||||
if enable_checkpointing:
|
||||
callbacks.append(
|
||||
ModelCheckpoint(
|
||||
dirpath=f"{log_dir}/{experiment_name}/checkpoints",
|
||||
filename="{epoch}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
verbose=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Add early stopping
|
||||
if enable_early_stopping:
|
||||
callbacks.append(
|
||||
EarlyStopping(
|
||||
monitor="val_loss",
|
||||
patience=10,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Add learning rate monitor
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
# Add TensorBoard logger
|
||||
logger_list.append(
|
||||
TensorBoardLogger(
|
||||
save_dir=log_dir,
|
||||
name=experiment_name,
|
||||
version=None,
|
||||
)
|
||||
)
|
||||
|
||||
elif preset == "distributed":
|
||||
# Distributed training configuration
|
||||
config = {
|
||||
"max_epochs": max_epochs,
|
||||
"strategy": "ddp",
|
||||
"precision": "16-mixed",
|
||||
"sync_batchnorm": True,
|
||||
"use_distributed_sampler": True,
|
||||
"log_every_n_steps": 50,
|
||||
"enable_progress_bar": True,
|
||||
}
|
||||
|
||||
# Add model checkpointing
|
||||
if enable_checkpointing:
|
||||
callbacks.append(
|
||||
ModelCheckpoint(
|
||||
dirpath=f"{log_dir}/{experiment_name}/checkpoints",
|
||||
filename="{epoch}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
)
|
||||
)
|
||||
|
||||
else: # default
|
||||
# Default configuration - balanced for most use cases
|
||||
config = {
|
||||
"max_epochs": max_epochs,
|
||||
"log_every_n_steps": 50,
|
||||
"enable_progress_bar": True,
|
||||
"enable_model_summary": True,
|
||||
}
|
||||
|
||||
# Add basic checkpointing
|
||||
if enable_checkpointing:
|
||||
callbacks.append(
|
||||
ModelCheckpoint(
|
||||
dirpath=f"{log_dir}/{experiment_name}/checkpoints",
|
||||
filename="{epoch}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
save_last=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Add CSV logger
|
||||
logger_list.append(
|
||||
CSVLogger(
|
||||
save_dir=log_dir,
|
||||
name=experiment_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Add progress bar
|
||||
if config.get("enable_progress_bar", True):
|
||||
callbacks.append(RichProgressBar())
|
||||
|
||||
# Merge with provided kwargs
|
||||
final_config = {
|
||||
**config,
|
||||
"accelerator": accelerator,
|
||||
"devices": devices,
|
||||
"callbacks": callbacks,
|
||||
"logger": logger_list if logger_list else True,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Create and return trainer
|
||||
return L.Trainer(**final_config)
|
||||
|
||||
|
||||
def create_debugging_trainer():
|
||||
"""Create a trainer optimized for debugging."""
|
||||
return create_trainer(
|
||||
preset="fast_dev",
|
||||
max_epochs=1,
|
||||
limit_train_batches=10,
|
||||
limit_val_batches=5,
|
||||
num_sanity_val_steps=2,
|
||||
# Callbacks
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename="{epoch:02d}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
early_stop_callback = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
patience=10,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def create_gpu_trainer(num_gpus: int = 1, precision: str = "16-mixed"):
|
||||
"""Create a trainer optimized for GPU training."""
|
||||
return create_trainer(
|
||||
preset="production",
|
||||
lr_monitor = LearningRateMonitor(logging_interval="step")
|
||||
|
||||
# Logger
|
||||
tb_logger = pl_loggers.TensorBoardLogger(
|
||||
save_dir=log_dir,
|
||||
name="my_model",
|
||||
)
|
||||
|
||||
# Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
precision="16-mixed", # Mixed precision for speed
|
||||
callbacks=[
|
||||
checkpoint_callback,
|
||||
early_stop_callback,
|
||||
lr_monitor,
|
||||
],
|
||||
logger=tb_logger,
|
||||
log_every_n_steps=50,
|
||||
gradient_clip_val=1.0, # Clip gradients
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. MULTI-GPU TRAINING (DDP)
|
||||
# =============================================================================
|
||||
|
||||
def multi_gpu_ddp_trainer(
|
||||
max_epochs=100,
|
||||
num_gpus=4,
|
||||
log_dir="logs",
|
||||
checkpoint_dir="checkpoints"
|
||||
):
|
||||
"""
|
||||
Multi-GPU training with Distributed Data Parallel.
|
||||
Use for: Models <500M parameters, standard deep learning models
|
||||
"""
|
||||
# Callbacks
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename="{epoch:02d}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
)
|
||||
|
||||
early_stop_callback = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
patience=10,
|
||||
mode="min",
|
||||
)
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval="step")
|
||||
|
||||
# Logger
|
||||
wandb_logger = pl_loggers.WandbLogger(
|
||||
project="my-project",
|
||||
save_dir=log_dir,
|
||||
)
|
||||
|
||||
# Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="gpu",
|
||||
devices=num_gpus,
|
||||
precision=precision,
|
||||
strategy=DDPStrategy(
|
||||
find_unused_parameters=False,
|
||||
gradient_as_bucket_view=True,
|
||||
),
|
||||
precision="16-mixed",
|
||||
callbacks=[
|
||||
checkpoint_callback,
|
||||
early_stop_callback,
|
||||
lr_monitor,
|
||||
],
|
||||
logger=wandb_logger,
|
||||
log_every_n_steps=50,
|
||||
gradient_clip_val=1.0,
|
||||
sync_batchnorm=True, # Sync batch norm across GPUs
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
def create_distributed_trainer(num_gpus: int = 2, num_nodes: int = 1):
|
||||
"""Create a trainer for distributed training across multiple GPUs."""
|
||||
return create_trainer(
|
||||
preset="distributed",
|
||||
|
||||
# =============================================================================
|
||||
# 5. LARGE MODEL TRAINING (FSDP)
|
||||
# =============================================================================
|
||||
|
||||
def large_model_fsdp_trainer(
|
||||
max_epochs=100,
|
||||
num_gpus=8,
|
||||
log_dir="logs",
|
||||
checkpoint_dir="checkpoints"
|
||||
):
|
||||
"""
|
||||
Training for large models (500M+ parameters) with FSDP.
|
||||
Use for: Large transformers, models that don't fit in single GPU
|
||||
"""
|
||||
import torch.nn as nn
|
||||
|
||||
# Callbacks
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename="{epoch:02d}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
)
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval="step")
|
||||
|
||||
# Logger
|
||||
wandb_logger = pl_loggers.WandbLogger(
|
||||
project="large-model",
|
||||
save_dir=log_dir,
|
||||
)
|
||||
|
||||
# Trainer with FSDP
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="gpu",
|
||||
devices=num_gpus,
|
||||
num_nodes=num_nodes,
|
||||
strategy="ddp",
|
||||
strategy=FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD",
|
||||
activation_checkpointing_policy={
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
},
|
||||
cpu_offload=False, # Set True if GPU memory insufficient
|
||||
),
|
||||
precision="bf16-mixed", # BFloat16 for A100/H100
|
||||
callbacks=[
|
||||
checkpoint_callback,
|
||||
lr_monitor,
|
||||
],
|
||||
logger=wandb_logger,
|
||||
log_every_n_steps=10,
|
||||
gradient_clip_val=1.0,
|
||||
accumulate_grad_batches=4, # Gradient accumulation
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
print("Creating different trainer configurations...\n")
|
||||
|
||||
# 1. Default trainer
|
||||
print("1. Default trainer:")
|
||||
trainer_default = create_trainer(preset="default", max_epochs=50)
|
||||
print(f" Max epochs: {trainer_default.max_epochs}")
|
||||
print(f" Accelerator: {trainer_default.accelerator}")
|
||||
print(f" Callbacks: {len(trainer_default.callbacks)}")
|
||||
print()
|
||||
# =============================================================================
|
||||
# 6. VERY LARGE MODEL TRAINING (DeepSpeed)
|
||||
# =============================================================================
|
||||
|
||||
# 2. Fast development trainer
|
||||
print("2. Fast development trainer:")
|
||||
trainer_dev = create_trainer(preset="fast_dev")
|
||||
print(f" Max epochs: {trainer_dev.max_epochs}")
|
||||
print(f" Train batches limit: {trainer_dev.limit_train_batches}")
|
||||
print()
|
||||
def deepspeed_trainer(
|
||||
max_epochs=100,
|
||||
num_gpus=8,
|
||||
stage=3,
|
||||
log_dir="logs",
|
||||
checkpoint_dir="checkpoints"
|
||||
):
|
||||
"""
|
||||
Training for very large models with DeepSpeed.
|
||||
Use for: Models >10B parameters, maximum memory efficiency
|
||||
"""
|
||||
# Callbacks
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename="{epoch:02d}-{step:06d}",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
every_n_train_steps=1000, # Save every N steps
|
||||
)
|
||||
|
||||
# 3. Production trainer
|
||||
print("3. Production trainer:")
|
||||
trainer_prod = create_trainer(
|
||||
preset="production",
|
||||
lr_monitor = LearningRateMonitor(logging_interval="step")
|
||||
|
||||
# Logger
|
||||
wandb_logger = pl_loggers.WandbLogger(
|
||||
project="very-large-model",
|
||||
save_dir=log_dir,
|
||||
)
|
||||
|
||||
# Select DeepSpeed stage
|
||||
strategy = f"deepspeed_stage_{stage}"
|
||||
|
||||
# Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="gpu",
|
||||
devices=num_gpus,
|
||||
strategy=strategy,
|
||||
precision="16-mixed",
|
||||
callbacks=[
|
||||
checkpoint_callback,
|
||||
lr_monitor,
|
||||
],
|
||||
logger=wandb_logger,
|
||||
log_every_n_steps=10,
|
||||
gradient_clip_val=1.0,
|
||||
accumulate_grad_batches=4,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. HYPERPARAMETER TUNING
|
||||
# =============================================================================
|
||||
|
||||
def hyperparameter_tuning_trainer(max_epochs=50):
|
||||
"""
|
||||
Lightweight trainer for hyperparameter search.
|
||||
Use for: Quick experiments, hyperparameter tuning
|
||||
"""
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="auto",
|
||||
devices=1,
|
||||
enable_checkpointing=False, # Don't save checkpoints
|
||||
logger=False, # Disable logging
|
||||
enable_progress_bar=False,
|
||||
limit_train_batches=0.5, # Use 50% of training data
|
||||
limit_val_batches=0.5, # Use 50% of validation data
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 8. OVERFITTING TEST
|
||||
# =============================================================================
|
||||
|
||||
def overfit_test_trainer(num_batches=10):
|
||||
"""
|
||||
Trainer for overfitting on small subset to verify model capacity.
|
||||
Use for: Testing if model can learn, debugging
|
||||
"""
|
||||
trainer = L.Trainer(
|
||||
max_epochs=100,
|
||||
experiment_name="my_experiment"
|
||||
accelerator="auto",
|
||||
devices=1,
|
||||
overfit_batches=num_batches, # Overfit on N batches
|
||||
log_every_n_steps=1,
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
print(f" Max epochs: {trainer_prod.max_epochs}")
|
||||
print(f" Precision: {trainer_prod.precision}")
|
||||
print(f" Callbacks: {len(trainer_prod.callbacks)}")
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 9. TIME-LIMITED TRAINING (SLURM)
|
||||
# =============================================================================
|
||||
|
||||
def time_limited_trainer(
|
||||
max_time_hours=23.5,
|
||||
max_epochs=1000,
|
||||
checkpoint_dir="checkpoints"
|
||||
):
|
||||
"""
|
||||
Training with time limit for SLURM clusters.
|
||||
Use for: Cluster jobs with time limits
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
save_top_k=3,
|
||||
save_last=True, # Important for resuming
|
||||
every_n_epochs=5,
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
max_time=timedelta(hours=max_time_hours),
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
callbacks=[checkpoint_callback],
|
||||
log_every_n_steps=50,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 10. REPRODUCIBLE RESEARCH
|
||||
# =============================================================================
|
||||
|
||||
def reproducible_trainer(seed=42, max_epochs=100):
|
||||
"""
|
||||
Fully reproducible trainer for research papers.
|
||||
Use for: Publications, reproducible results
|
||||
"""
|
||||
# Set seed
|
||||
L.seed_everything(seed, workers=True)
|
||||
|
||||
# Callbacks
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath="checkpoints",
|
||||
filename="{epoch:02d}-{val_loss:.2f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
)
|
||||
|
||||
# Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=max_epochs,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
precision="32-true", # Full precision for reproducibility
|
||||
deterministic=True, # Use deterministic algorithms
|
||||
benchmark=False, # Disable cudnn benchmarking
|
||||
callbacks=[checkpoint_callback],
|
||||
log_every_n_steps=50,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# USAGE EXAMPLES
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("PyTorch Lightning Trainer Configurations\n")
|
||||
|
||||
# Example 1: Basic training
|
||||
print("1. Basic Trainer:")
|
||||
trainer = basic_trainer()
|
||||
print(f" - Max epochs: {trainer.max_epochs}")
|
||||
print(f" - Accelerator: {trainer.accelerator}")
|
||||
print()
|
||||
|
||||
# 4. Debugging trainer
|
||||
print("4. Debugging trainer:")
|
||||
trainer_debug = create_debugging_trainer()
|
||||
print(f" Max epochs: {trainer_debug.max_epochs}")
|
||||
print(f" Train batches: {trainer_debug.limit_train_batches}")
|
||||
# Example 2: Debug training
|
||||
print("2. Debug Trainer:")
|
||||
trainer = debug_trainer()
|
||||
print(f" - Fast dev run: {trainer.fast_dev_run}")
|
||||
print(f" - Detect anomaly: {trainer.detect_anomaly}")
|
||||
print()
|
||||
|
||||
# 5. GPU trainer
|
||||
print("5. GPU trainer:")
|
||||
trainer_gpu = create_gpu_trainer(num_gpus=1)
|
||||
print(f" Accelerator: {trainer_gpu.accelerator}")
|
||||
print(f" Precision: {trainer_gpu.precision}")
|
||||
# Example 3: Production single GPU
|
||||
print("3. Production Single GPU Trainer:")
|
||||
trainer = production_single_gpu_trainer(max_epochs=100)
|
||||
print(f" - Max epochs: {trainer.max_epochs}")
|
||||
print(f" - Precision: {trainer.precision}")
|
||||
print(f" - Callbacks: {len(trainer.callbacks)}")
|
||||
print()
|
||||
|
||||
print("All trainer configurations created successfully!")
|
||||
# Example 4: Multi-GPU DDP
|
||||
print("4. Multi-GPU DDP Trainer:")
|
||||
trainer = multi_gpu_ddp_trainer(num_gpus=4)
|
||||
print(f" - Strategy: {trainer.strategy}")
|
||||
print(f" - Devices: {trainer.num_devices}")
|
||||
print()
|
||||
|
||||
# Example 5: FSDP for large models
|
||||
print("5. FSDP Trainer for Large Models:")
|
||||
trainer = large_model_fsdp_trainer(num_gpus=8)
|
||||
print(f" - Strategy: {trainer.strategy}")
|
||||
print(f" - Precision: {trainer.precision}")
|
||||
print()
|
||||
|
||||
print("\nTo use these configurations:")
|
||||
print("1. Import the desired function")
|
||||
print("2. Create trainer: trainer = production_single_gpu_trainer()")
|
||||
print("3. Train model: trainer.fit(model, datamodule=dm)")
|
||||
|
||||
@@ -1,221 +1,328 @@
|
||||
"""
|
||||
Template for creating a PyTorch Lightning DataModule.
|
||||
|
||||
This template includes all common hooks and patterns for organizing
|
||||
data processing workflows with best practices.
|
||||
This template provides a complete boilerplate for building a LightningDataModule
|
||||
with all essential methods and best practices for data handling.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
import torch
|
||||
|
||||
|
||||
class TemplateDataset(Dataset):
|
||||
"""Example dataset - replace with your actual dataset."""
|
||||
class CustomDataset(Dataset):
|
||||
"""
|
||||
Custom Dataset implementation.
|
||||
|
||||
def __init__(self, data, targets, transform=None):
|
||||
self.data = data
|
||||
self.targets = targets
|
||||
Replace this with your actual dataset implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, transform=None):
|
||||
"""
|
||||
Initialize the dataset.
|
||||
|
||||
Args:
|
||||
data_path: Path to data directory
|
||||
transform: Optional transforms to apply
|
||||
"""
|
||||
self.data_path = data_path
|
||||
self.transform = transform
|
||||
|
||||
# Load your data here
|
||||
# self.data = load_data(data_path)
|
||||
# self.labels = load_labels(data_path)
|
||||
|
||||
# Placeholder data
|
||||
self.data = torch.randn(1000, 3, 224, 224)
|
||||
self.labels = torch.randint(0, 10, (1000,))
|
||||
|
||||
def __len__(self):
|
||||
"""Return the size of the dataset."""
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x = self.data[idx]
|
||||
y = self.targets[idx]
|
||||
"""
|
||||
Get a single item from the dataset.
|
||||
|
||||
Args:
|
||||
idx: Index of the item
|
||||
|
||||
Returns:
|
||||
Tuple of (data, label)
|
||||
"""
|
||||
sample = self.data[idx]
|
||||
label = self.labels[idx]
|
||||
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
sample = self.transform(sample)
|
||||
|
||||
return x, y
|
||||
return sample, label
|
||||
|
||||
|
||||
class TemplateDataModule(L.LightningDataModule):
|
||||
"""Template DataModule with all common hooks and patterns."""
|
||||
"""
|
||||
Template LightningDataModule for data handling.
|
||||
|
||||
This class encapsulates all data processing steps:
|
||||
1. Download/prepare data (prepare_data)
|
||||
2. Create datasets (setup)
|
||||
3. Create dataloaders (train/val/test/predict_dataloader)
|
||||
|
||||
Args:
|
||||
data_dir: Directory containing the data
|
||||
batch_size: Batch size for dataloaders
|
||||
num_workers: Number of workers for data loading
|
||||
train_val_split: Train/validation split ratio
|
||||
pin_memory: Whether to pin memory for faster GPU transfer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str = "./data",
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
train_val_split: tuple = (0.8, 0.2),
|
||||
seed: int = 42,
|
||||
train_val_split: float = 0.8,
|
||||
pin_memory: bool = True,
|
||||
persistent_workers: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Save hyperparameters
|
||||
self.save_hyperparameters()
|
||||
|
||||
# Initialize attributes
|
||||
self.data_dir = data_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.train_val_split = train_val_split
|
||||
self.seed = seed
|
||||
self.pin_memory = pin_memory
|
||||
self.persistent_workers = persistent_workers
|
||||
|
||||
# Placeholders for datasets
|
||||
# Initialize as None (will be set in setup)
|
||||
self.train_dataset = None
|
||||
self.val_dataset = None
|
||||
self.test_dataset = None
|
||||
self.predict_dataset = None
|
||||
|
||||
# Placeholder for transforms
|
||||
self.train_transform = None
|
||||
self.val_transform = None
|
||||
self.test_transform = None
|
||||
|
||||
def prepare_data(self):
|
||||
"""
|
||||
Download and prepare data (called only on 1 GPU/TPU in distributed settings).
|
||||
Use this for downloading, tokenizing, etc. Do NOT set state here (no self.x = y).
|
||||
Download and prepare data.
|
||||
|
||||
This method is called only once and on a single process.
|
||||
Do not set state here (e.g., self.x = y) because it's not
|
||||
transferred to other processes.
|
||||
|
||||
Use this for:
|
||||
- Downloading datasets
|
||||
- Tokenizing text
|
||||
- Saving processed data to disk
|
||||
"""
|
||||
# Example: Download datasets
|
||||
# datasets.MNIST(self.data_dir, train=True, download=True)
|
||||
# datasets.MNIST(self.data_dir, train=False, download=True)
|
||||
# Example: Download data if not exists
|
||||
# if not os.path.exists(self.hparams.data_dir):
|
||||
# download_dataset(self.hparams.data_dir)
|
||||
|
||||
# Example: Process and save data
|
||||
# process_and_save(self.hparams.data_dir)
|
||||
|
||||
pass
|
||||
|
||||
def setup(self, stage: str = None):
|
||||
"""
|
||||
Load data and create train/val/test splits (called on every GPU/TPU in distributed).
|
||||
Use this for splitting, creating datasets, etc. Setting state is OK here (self.x = y).
|
||||
Create datasets for each stage.
|
||||
|
||||
This method is called on every process in distributed training.
|
||||
Set state here (e.g., self.train_dataset = ...).
|
||||
|
||||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'
|
||||
stage: Current stage ('fit', 'validate', 'test', or 'predict')
|
||||
"""
|
||||
# Define transforms
|
||||
train_transform = self._get_train_transforms()
|
||||
test_transform = self._get_test_transforms()
|
||||
|
||||
# Fit stage: setup training and validation datasets
|
||||
# Setup for training and validation
|
||||
if stage == "fit" or stage is None:
|
||||
# Load full dataset
|
||||
# Example: full_dataset = datasets.MNIST(self.data_dir, train=True, transform=self.train_transform)
|
||||
|
||||
# Create dummy data for template
|
||||
full_data = torch.randn(1000, 784)
|
||||
full_targets = torch.randint(0, 10, (1000,))
|
||||
full_dataset = TemplateDataset(full_data, full_targets, transform=self.train_transform)
|
||||
full_dataset = CustomDataset(
|
||||
self.hparams.data_dir, transform=train_transform
|
||||
)
|
||||
|
||||
# Split into train and validation
|
||||
train_size = int(len(full_dataset) * self.train_val_split[0])
|
||||
train_size = int(self.hparams.train_val_split * len(full_dataset))
|
||||
val_size = len(full_dataset) - train_size
|
||||
|
||||
self.train_dataset, self.val_dataset = random_split(
|
||||
full_dataset,
|
||||
[train_size, val_size],
|
||||
generator=torch.Generator().manual_seed(self.seed)
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
)
|
||||
|
||||
# Apply validation transform if different from train
|
||||
if self.val_transform:
|
||||
self.val_dataset.dataset.transform = self.val_transform
|
||||
# Apply test transforms to validation set
|
||||
# (Note: random_split doesn't support different transforms,
|
||||
# you may need to implement a custom wrapper)
|
||||
|
||||
# Test stage: setup test dataset
|
||||
# Setup for testing
|
||||
if stage == "test" or stage is None:
|
||||
# Example: self.test_dataset = datasets.MNIST(
|
||||
# self.data_dir, train=False, transform=self.test_transform
|
||||
# )
|
||||
self.test_dataset = CustomDataset(
|
||||
self.hparams.data_dir, transform=test_transform
|
||||
)
|
||||
|
||||
# Create dummy test data for template
|
||||
test_data = torch.randn(200, 784)
|
||||
test_targets = torch.randint(0, 10, (200,))
|
||||
self.test_dataset = TemplateDataset(test_data, test_targets, transform=self.test_transform)
|
||||
# Setup for prediction
|
||||
if stage == "predict":
|
||||
self.predict_dataset = CustomDataset(
|
||||
self.hparams.data_dir, transform=test_transform
|
||||
)
|
||||
|
||||
# Predict stage: setup prediction dataset
|
||||
if stage == "predict" or stage is None:
|
||||
# Example: self.predict_dataset = YourCustomDataset(...)
|
||||
def _get_train_transforms(self):
|
||||
"""
|
||||
Define training transforms/augmentations.
|
||||
|
||||
# Create dummy predict data for template
|
||||
predict_data = torch.randn(100, 784)
|
||||
predict_targets = torch.zeros(100, dtype=torch.long)
|
||||
self.predict_dataset = TemplateDataset(predict_data, predict_targets)
|
||||
Returns:
|
||||
Training transforms
|
||||
"""
|
||||
# Example with torchvision:
|
||||
# from torchvision import transforms
|
||||
# return transforms.Compose([
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
# transforms.RandomRotation(10),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
# std=[0.229, 0.224, 0.225])
|
||||
# ])
|
||||
|
||||
return None
|
||||
|
||||
def _get_test_transforms(self):
|
||||
"""
|
||||
Define test/validation transforms (no augmentation).
|
||||
|
||||
Returns:
|
||||
Test/validation transforms
|
||||
"""
|
||||
# Example with torchvision:
|
||||
# from torchvision import transforms
|
||||
# return transforms.Compose([
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
# std=[0.229, 0.224, 0.225])
|
||||
# ])
|
||||
|
||||
return None
|
||||
|
||||
def train_dataloader(self):
|
||||
"""Return training dataloader."""
|
||||
"""
|
||||
Create training dataloader.
|
||||
|
||||
Returns:
|
||||
Training DataLoader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
num_workers=self.hparams.num_workers,
|
||||
pin_memory=self.hparams.pin_memory,
|
||||
persistent_workers=True if self.hparams.num_workers > 0 else False,
|
||||
drop_last=True, # Drop last incomplete batch
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
"""Return validation dataloader."""
|
||||
"""
|
||||
Create validation dataloader.
|
||||
|
||||
Returns:
|
||||
Validation DataLoader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
num_workers=self.hparams.num_workers,
|
||||
pin_memory=self.hparams.pin_memory,
|
||||
persistent_workers=True if self.hparams.num_workers > 0 else False,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
"""Return test dataloader."""
|
||||
"""
|
||||
Create test dataloader.
|
||||
|
||||
Returns:
|
||||
Test DataLoader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
num_workers=self.hparams.num_workers,
|
||||
)
|
||||
|
||||
def predict_dataloader(self):
|
||||
"""Return prediction dataloader."""
|
||||
"""
|
||||
Create prediction dataloader.
|
||||
|
||||
Returns:
|
||||
Prediction DataLoader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.predict_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
num_workers=self.hparams.num_workers,
|
||||
)
|
||||
|
||||
def teardown(self, stage: str = None):
|
||||
"""Clean up after fit, validate, test, or predict."""
|
||||
# Example: close database connections, clear caches, etc.
|
||||
pass
|
||||
# Optional: State management for checkpointing
|
||||
|
||||
def state_dict(self):
|
||||
"""Save state for checkpointing."""
|
||||
# Return anything you want to save in the checkpoint
|
||||
return {}
|
||||
"""
|
||||
Save DataModule state for checkpointing.
|
||||
|
||||
Returns:
|
||||
State dictionary
|
||||
"""
|
||||
return {"train_val_split": self.hparams.train_val_split}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load state from checkpoint."""
|
||||
# Restore state from checkpoint
|
||||
pass
|
||||
"""
|
||||
Restore DataModule state from checkpoint.
|
||||
|
||||
Args:
|
||||
state_dict: State dictionary
|
||||
"""
|
||||
self.hparams.train_val_split = state_dict["train_val_split"]
|
||||
|
||||
# Optional: Teardown for cleanup
|
||||
|
||||
def teardown(self, stage: str = None):
|
||||
"""
|
||||
Cleanup after training/testing/prediction.
|
||||
|
||||
Args:
|
||||
stage: Current stage ('fit', 'validate', 'test', or 'predict')
|
||||
"""
|
||||
# Clean up resources
|
||||
if stage == "fit":
|
||||
self.train_dataset = None
|
||||
self.val_dataset = None
|
||||
elif stage == "test":
|
||||
self.test_dataset = None
|
||||
elif stage == "predict":
|
||||
self.predict_dataset = None
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Create datamodule
|
||||
datamodule = TemplateDataModule(
|
||||
# Create DataModule
|
||||
dm = TemplateDataModule(
|
||||
data_dir="./data",
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
train_val_split=(0.8, 0.2),
|
||||
batch_size=64,
|
||||
num_workers=8,
|
||||
train_val_split=0.8,
|
||||
)
|
||||
|
||||
# Prepare and setup data
|
||||
datamodule.prepare_data()
|
||||
datamodule.setup("fit")
|
||||
# Setup for training
|
||||
dm.prepare_data()
|
||||
dm.setup(stage="fit")
|
||||
|
||||
# Get dataloaders
|
||||
train_loader = datamodule.train_dataloader()
|
||||
val_loader = datamodule.val_dataloader()
|
||||
train_loader = dm.train_dataloader()
|
||||
val_loader = dm.val_dataloader()
|
||||
|
||||
print("Template DataModule created successfully!")
|
||||
print(f"Train dataset size: {len(dm.train_dataset)}")
|
||||
print(f"Validation dataset size: {len(dm.val_dataset)}")
|
||||
print(f"Train batches: {len(train_loader)}")
|
||||
print(f"Val batches: {len(val_loader)}")
|
||||
print(f"Batch size: {datamodule.batch_size}")
|
||||
print(f"Validation batches: {len(val_loader)}")
|
||||
|
||||
# Test a batch
|
||||
batch = next(iter(train_loader))
|
||||
x, y = batch
|
||||
print(f"Batch shape: {x.shape}, {y.shape}")
|
||||
# Example: Use with Trainer
|
||||
# from template_lightning_module import TemplateLightningModule
|
||||
# model = TemplateLightningModule()
|
||||
# trainer = L.Trainer(max_epochs=10)
|
||||
# trainer.fit(model, datamodule=dm)
|
||||
|
||||
@@ -1,190 +1,197 @@
|
||||
"""
|
||||
Template for creating a PyTorch Lightning LightningModule.
|
||||
Template for creating a PyTorch Lightning Module.
|
||||
|
||||
This template includes all common hooks and patterns for building
|
||||
a Lightning model with best practices.
|
||||
This template provides a complete boilerplate for building a LightningModule
|
||||
with all essential methods and best practices.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam, SGD
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
|
||||
class TemplateLightningModule(L.LightningModule):
|
||||
"""Template LightningModule with all common hooks and patterns."""
|
||||
"""
|
||||
Template LightningModule for building deep learning models.
|
||||
|
||||
Args:
|
||||
learning_rate: Learning rate for optimizer
|
||||
hidden_dim: Hidden dimension size
|
||||
dropout: Dropout probability
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Model architecture parameters
|
||||
input_dim: int = 784,
|
||||
hidden_dim: int = 128,
|
||||
output_dim: int = 10,
|
||||
# Optimization parameters
|
||||
learning_rate: float = 1e-3,
|
||||
optimizer_type: str = "adam",
|
||||
scheduler_type: str = None,
|
||||
# Other hyperparameters
|
||||
learning_rate: float = 0.001,
|
||||
hidden_dim: int = 256,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Save hyperparameters for checkpointing and logging
|
||||
# Save hyperparameters (accessible via self.hparams)
|
||||
self.save_hyperparameters()
|
||||
|
||||
# Define model architecture
|
||||
# Define your model architecture
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.Linear(784, self.hparams.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, output_dim)
|
||||
nn.Dropout(self.hparams.dropout),
|
||||
nn.Linear(self.hparams.hidden_dim, 10),
|
||||
)
|
||||
|
||||
# Define loss function
|
||||
self.criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# For tracking validation outputs (optional)
|
||||
self.validation_step_outputs = []
|
||||
# Optional: Define metrics
|
||||
# from torchmetrics import Accuracy
|
||||
# self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
|
||||
# self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass for inference."""
|
||||
"""
|
||||
Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
|
||||
Returns:
|
||||
Model output
|
||||
"""
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Training step - called for each training batch."""
|
||||
"""
|
||||
Training step (called for each training batch).
|
||||
|
||||
Args:
|
||||
batch: Current batch of data
|
||||
batch_idx: Index of the current batch
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
x, y = batch
|
||||
|
||||
# Forward pass
|
||||
logits = self(x)
|
||||
loss = self.criterion(logits, y)
|
||||
loss = F.cross_entropy(logits, y)
|
||||
|
||||
# Calculate accuracy
|
||||
# Calculate accuracy (optional)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = (preds == y).float().mean()
|
||||
|
||||
# Log metrics
|
||||
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
|
||||
self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
|
||||
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
|
||||
self.log("train/acc", acc, on_step=True, on_epoch=True)
|
||||
self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
"""Validation step - called for each validation batch."""
|
||||
x, y = batch
|
||||
"""
|
||||
Validation step (called for each validation batch).
|
||||
|
||||
# Forward pass (model automatically in eval mode)
|
||||
logits = self(x)
|
||||
loss = self.criterion(logits, y)
|
||||
|
||||
# Calculate accuracy
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = (preds == y).float().mean()
|
||||
|
||||
# Log metrics
|
||||
self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
|
||||
self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
|
||||
|
||||
# Optional: store outputs for epoch-level processing
|
||||
self.validation_step_outputs.append({"loss": loss, "acc": acc})
|
||||
|
||||
return loss
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
"""Called at the end of validation epoch."""
|
||||
# Optional: process all validation outputs
|
||||
if self.validation_step_outputs:
|
||||
avg_loss = torch.stack([x["loss"] for x in self.validation_step_outputs]).mean()
|
||||
avg_acc = torch.stack([x["acc"] for x in self.validation_step_outputs]).mean()
|
||||
|
||||
# Log epoch-level metrics if needed
|
||||
# self.log("val_epoch_loss", avg_loss)
|
||||
# self.log("val_epoch_acc", avg_acc)
|
||||
|
||||
# Clear outputs
|
||||
self.validation_step_outputs.clear()
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
"""Test step - called for each test batch."""
|
||||
Args:
|
||||
batch: Current batch of data
|
||||
batch_idx: Index of the current batch
|
||||
"""
|
||||
x, y = batch
|
||||
|
||||
# Forward pass
|
||||
logits = self(x)
|
||||
loss = self.criterion(logits, y)
|
||||
loss = F.cross_entropy(logits, y)
|
||||
|
||||
# Calculate accuracy
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = (preds == y).float().mean()
|
||||
|
||||
# Log metrics (automatically aggregated across batches)
|
||||
self.log("val/loss", loss, on_epoch=True, prog_bar=True, sync_dist=True)
|
||||
self.log("val/acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
"""
|
||||
Test step (called for each test batch).
|
||||
|
||||
Args:
|
||||
batch: Current batch of data
|
||||
batch_idx: Index of the current batch
|
||||
"""
|
||||
x, y = batch
|
||||
|
||||
# Forward pass
|
||||
logits = self(x)
|
||||
loss = F.cross_entropy(logits, y)
|
||||
|
||||
# Calculate accuracy
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = (preds == y).float().mean()
|
||||
|
||||
# Log metrics
|
||||
self.log("test_loss", loss, on_step=False, on_epoch=True)
|
||||
self.log("test_acc", acc, on_step=False, on_epoch=True)
|
||||
|
||||
return loss
|
||||
self.log("test/loss", loss, on_epoch=True)
|
||||
self.log("test/acc", acc, on_epoch=True)
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
"""Prediction step - called for each prediction batch."""
|
||||
"""
|
||||
Prediction step (called for each prediction batch).
|
||||
|
||||
Args:
|
||||
batch: Current batch of data
|
||||
batch_idx: Index of the current batch
|
||||
dataloader_idx: Index of the dataloader (if multiple)
|
||||
|
||||
Returns:
|
||||
Predictions
|
||||
"""
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
return preds
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Configure optimizer and learning rate scheduler."""
|
||||
# Create optimizer
|
||||
if self.hparams.optimizer_type.lower() == "adam":
|
||||
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
elif self.hparams.optimizer_type.lower() == "sgd":
|
||||
optimizer = SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=0.9)
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer: {self.hparams.optimizer_type}")
|
||||
"""
|
||||
Configure optimizers and learning rate schedulers.
|
||||
|
||||
# Configure with scheduler if specified
|
||||
if self.hparams.scheduler_type:
|
||||
if self.hparams.scheduler_type.lower() == "reduce_on_plateau":
|
||||
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val_loss",
|
||||
"interval": "epoch",
|
||||
"frequency": 1,
|
||||
}
|
||||
}
|
||||
elif self.hparams.scheduler_type.lower() == "step":
|
||||
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"interval": "epoch",
|
||||
"frequency": 1,
|
||||
}
|
||||
}
|
||||
Returns:
|
||||
Optimizer and scheduler configuration
|
||||
"""
|
||||
# Define optimizer
|
||||
optimizer = Adam(
|
||||
self.parameters(),
|
||||
lr=self.hparams.learning_rate,
|
||||
weight_decay=1e-5,
|
||||
)
|
||||
|
||||
return optimizer
|
||||
# Define scheduler
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=5,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Optional: Additional hooks for custom behavior
|
||||
# Return configuration
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val/loss",
|
||||
"interval": "epoch",
|
||||
"frequency": 1,
|
||||
},
|
||||
}
|
||||
|
||||
def on_train_start(self):
|
||||
"""Called at the beginning of training."""
|
||||
pass
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
"""Called at the beginning of each training epoch."""
|
||||
pass
|
||||
# Optional: Add custom methods for model-specific logic
|
||||
|
||||
def on_train_epoch_end(self):
|
||||
"""Called at the end of each training epoch."""
|
||||
# Example: Log custom metrics
|
||||
pass
|
||||
|
||||
def on_train_end(self):
|
||||
"""Called at the end of training."""
|
||||
def on_validation_epoch_end(self):
|
||||
"""Called at the end of each validation epoch."""
|
||||
# Example: Compute epoch-level metrics
|
||||
pass
|
||||
|
||||
|
||||
@@ -192,24 +199,21 @@ class TemplateLightningModule(L.LightningModule):
|
||||
if __name__ == "__main__":
|
||||
# Create model
|
||||
model = TemplateLightningModule(
|
||||
input_dim=784,
|
||||
hidden_dim=128,
|
||||
output_dim=10,
|
||||
learning_rate=1e-3,
|
||||
optimizer_type="adam",
|
||||
scheduler_type="reduce_on_plateau"
|
||||
learning_rate=0.001,
|
||||
hidden_dim=256,
|
||||
dropout=0.1,
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
accelerator="auto",
|
||||
devices=1,
|
||||
log_every_n_steps=50,
|
||||
devices="auto",
|
||||
logger=True,
|
||||
)
|
||||
|
||||
# Note: You would need to provide dataloaders
|
||||
# Train (you need to provide train_dataloader and val_dataloader)
|
||||
# trainer.fit(model, train_dataloader, val_dataloader)
|
||||
|
||||
print("Template LightningModule created successfully!")
|
||||
print(f"Model hyperparameters: {model.hparams}")
|
||||
print(f"Model created with {model.num_parameters:,} parameters")
|
||||
print(f"Hyperparameters: {model.hparams}")
|
||||
|
||||
Reference in New Issue
Block a user