Add more scientific skills

This commit is contained in:
Timothy Kassis
2025-10-19 14:12:02 -07:00
parent 78d5ac2b56
commit 660c8574d0
210 changed files with 88957 additions and 1 deletions

View File

@@ -0,0 +1,660 @@
---
name: pytorch-lightning
description: Comprehensive toolkit for PyTorch Lightning, a deep learning framework for organizing PyTorch code. Use this skill when working with PyTorch Lightning for training deep learning models, implementing LightningModules, configuring Trainers, setting up distributed training, creating DataModules, or converting existing PyTorch code to Lightning format. The skill provides templates, reference documentation, and best practices for efficient deep learning workflows.
---
# PyTorch Lightning
## Overview
PyTorch Lightning is a deep learning framework that organizes PyTorch code to decouple research from engineering. It automates training loop complexity (multi-GPU, mixed precision, checkpointing, logging) while maintaining full flexibility over model architecture and training logic.
**Core Philosophy:** Separate concerns
- **LightningModule** - Research code (model architecture, training logic)
- **Trainer** - Engineering automation (hardware, optimization, logging)
- **DataModule** - Data processing (downloading, loading, transforms)
- **Callbacks** - Non-essential functionality (checkpointing, early stopping)
## When to Use This Skill
Use this skill when:
- Building or training deep learning models with PyTorch
- Converting existing PyTorch code to Lightning structure
- Setting up distributed training across multiple GPUs or nodes
- Implementing custom training loops with validation and testing
- Organizing data processing pipelines
- Configuring experiment logging and model checkpointing
- Optimizing training performance and memory usage
- Working with large models requiring model parallelism
## Quick Start
### Basic Lightning Workflow
1. **Define a LightningModule** (organize your model)
2. **Create a DataModule or DataLoaders** (organize your data)
3. **Configure a Trainer** (automate training)
4. **Train** with `trainer.fit()`
### Minimal Example
```python
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# 1. Define LightningModule
class SimpleModel(L.LightningModule):
def __init__(self, input_dim, output_dim):
super().__init__()
self.save_hyperparameters()
self.model = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# 2. Prepare data
train_data = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
train_loader = DataLoader(train_data, batch_size=32)
# 3. Create Trainer
trainer = L.Trainer(max_epochs=10, accelerator='auto')
# 4. Train
model = SimpleModel(input_dim=10, output_dim=1)
trainer.fit(model, train_loader)
```
## Core Workflows
### 1. Creating a LightningModule
Structure model code by implementing essential hooks:
**Template:** Use `scripts/template_lightning_module.py` as a starting point.
```python
class MyLightningModule(L.LightningModule):
def __init__(self, hyperparameters):
super().__init__()
self.save_hyperparameters() # Save for checkpointing
self.model = YourModel()
def forward(self, x):
"""Inference forward pass."""
return self.model(x)
def training_step(self, batch, batch_idx):
"""Define training loop logic."""
x, y = batch
y_hat = self(x)
loss = self.compute_loss(y_hat, y)
self.log('train_loss', loss, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
"""Define validation logic."""
x, y = batch
y_hat = self(x)
loss = self.compute_loss(y_hat, y)
self.log('val_loss', loss, on_epoch=True)
return loss
def configure_optimizers(self):
"""Return optimizer and optional scheduler."""
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
}
}
```
**Key Points:**
- Use `self.save_hyperparameters()` to automatically save init args
- Use `self.log()` to track metrics across loggers
- Return loss from training_step for automatic optimization
- Keep model architecture separate from training logic
### 2. Creating a DataModule
Organize all data processing in a reusable module:
**Template:** Use `scripts/template_datamodule.py` as a starting point.
```python
class MyDataModule(L.LightningDataModule):
def __init__(self, data_dir, batch_size=32):
super().__init__()
self.save_hyperparameters()
def prepare_data(self):
"""Download data (called once, single process)."""
# Download datasets, tokenize, etc.
pass
def setup(self, stage=None):
"""Create datasets (called on every process)."""
if stage == 'fit' or stage is None:
# Create train/val datasets
self.train_dataset = ...
self.val_dataset = ...
if stage == 'test' or stage is None:
# Create test dataset
self.test_dataset = ...
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size)
```
**Key Points:**
- `prepare_data()` for downloading (single process)
- `setup()` for creating datasets (every process)
- Use `stage` parameter to separate fit/test logic
- Makes data code reusable across projects
### 3. Configuring the Trainer
The Trainer automates training complexity:
**Helper:** Use `scripts/quick_trainer_setup.py` for preset configurations.
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
trainer = L.Trainer(
# Training duration
max_epochs=100,
# Hardware
accelerator='auto', # 'cpu', 'gpu', 'tpu'
devices=1, # Number of devices or specific IDs
# Optimization
precision='16-mixed', # Mixed precision training
gradient_clip_val=1.0,
accumulate_grad_batches=4, # Gradient accumulation
# Validation
check_val_every_n_epoch=1,
val_check_interval=1.0, # Validate every epoch
# Logging
log_every_n_steps=50,
logger=TensorBoardLogger('logs/'),
# Callbacks
callbacks=[
ModelCheckpoint(monitor='val_loss', mode='min'),
EarlyStopping(monitor='val_loss', patience=10),
],
# Debugging
fast_dev_run=False, # Quick test with few batches
enable_progress_bar=True,
)
```
**Common Presets:**
```python
from scripts.quick_trainer_setup import create_trainer
# Development preset (fast debugging)
trainer = create_trainer(preset='fast_dev', max_epochs=3)
# Production preset (full features)
trainer = create_trainer(preset='production', max_epochs=100)
# Distributed preset (multi-GPU)
trainer = create_trainer(preset='distributed', devices=4)
```
### 4. Training and Evaluation
```python
# Training
trainer.fit(model, datamodule=dm)
# Or with dataloaders
trainer.fit(model, train_loader, val_loader)
# Resume from checkpoint
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
# Testing
trainer.test(model, datamodule=dm)
# Or load best checkpoint
trainer.test(ckpt_path='best', datamodule=dm)
# Prediction
predictions = trainer.predict(model, predict_loader)
# Validation only
trainer.validate(model, datamodule=dm)
```
### 5. Distributed Training
Lightning handles distributed training automatically:
```python
# Single machine, multiple GPUs (Distributed Data Parallel)
trainer = L.Trainer(
accelerator='gpu',
devices=4,
strategy='ddp', # DistributedDataParallel
)
# Multiple machines, multiple GPUs
trainer = L.Trainer(
accelerator='gpu',
devices=4, # GPUs per node
num_nodes=8, # Number of machines
strategy='ddp',
)
# Large models (Model Parallel with FSDP)
trainer = L.Trainer(
accelerator='gpu',
devices=4,
strategy='fsdp', # Fully Sharded Data Parallel
)
# Large models (Model Parallel with DeepSpeed)
trainer = L.Trainer(
accelerator='gpu',
devices=4,
strategy='deepspeed_stage_2',
precision='16-mixed',
)
```
**For detailed distributed training guide, see:** `references/distributed_training.md`
**Strategy Selection:**
- Models < 500M params → Use `ddp`
- Models > 500M params → Use `fsdp` or `deepspeed`
- Maximum memory efficiency → Use DeepSpeed Stage 3 with offloading
- Native PyTorch → Use `fsdp`
- Cutting-edge features → Use `deepspeed`
### 6. Callbacks
Extend training with modular functionality:
```python
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
RichProgressBar,
)
callbacks = [
# Save best models
ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=3,
filename='{epoch}-{val_loss:.2f}',
),
# Stop when no improvement
EarlyStopping(
monitor='val_loss',
patience=10,
mode='min',
),
# Log learning rate
LearningRateMonitor(logging_interval='epoch'),
# Rich progress bar
RichProgressBar(),
]
trainer = L.Trainer(callbacks=callbacks)
```
**Custom Callbacks:**
```python
from lightning.pytorch.callbacks import Callback
class MyCustomCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Custom logic at end of each epoch
print(f"Epoch {trainer.current_epoch} completed")
def on_validation_end(self, trainer, pl_module):
val_loss = trainer.callback_metrics.get('val_loss')
# Custom validation logic
pass
```
### 7. Logging
Track experiments with various loggers:
```python
from lightning.pytorch.loggers import (
TensorBoardLogger,
WandbLogger,
CSVLogger,
MLFlowLogger,
)
# Single logger
logger = TensorBoardLogger('logs/', name='my_experiment')
# Multiple loggers
loggers = [
TensorBoardLogger('logs/'),
WandbLogger(project='my_project'),
CSVLogger('logs/'),
]
trainer = L.Trainer(logger=loggers)
```
**Logging in LightningModule:**
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Log single metric
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
# Log multiple metrics
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
self.log_dict(metrics, on_step=True, on_epoch=True)
return loss
```
## Converting Existing PyTorch Code
### Standard PyTorch → Lightning
**Before (PyTorch):**
```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(num_epochs):
for batch in train_loader:
optimizer.zero_grad()
x, y = batch
y_hat = model(x)
loss = F.cross_entropy(y_hat, y)
loss.backward()
optimizer.step()
```
**After (Lightning):**
```python
class MyLightningModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyModel()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
model = MyLightningModel()
trainer = L.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader)
```
**Key Changes:**
1. Wrap model in LightningModule
2. Move training loop logic to `training_step()`
3. Move optimizer setup to `configure_optimizers()`
4. Replace manual loop with `trainer.fit()`
5. Lightning handles: `.zero_grad()`, `.backward()`, `.step()`, device placement
## Common Patterns
### Reproducibility
```python
from lightning.pytorch import seed_everything
# Set seed for reproducibility
seed_everything(42, workers=True)
trainer = L.Trainer(deterministic=True)
```
### Mixed Precision Training
```python
# 16-bit mixed precision
trainer = L.Trainer(precision='16-mixed')
# BFloat16 mixed precision (more stable)
trainer = L.Trainer(precision='bf16-mixed')
```
### Gradient Accumulation
```python
# Effective batch size = 4x actual batch size
trainer = L.Trainer(accumulate_grad_batches=4)
```
### Learning Rate Finding
```python
from lightning.pytorch.tuner import Tuner
trainer = L.Trainer()
tuner = Tuner(trainer)
# Find optimal learning rate
lr_finder = tuner.lr_find(model, train_dataloader)
model.hparams.learning_rate = lr_finder.suggestion()
# Find optimal batch size
tuner.scale_batch_size(model, mode="power")
```
### Checkpointing and Loading
```python
# Save checkpoint
trainer.fit(model, datamodule=dm)
# Checkpoint automatically saved to checkpoints/
# Load from checkpoint
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
# Resume training
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
# Test from checkpoint
trainer.test(ckpt_path='best', datamodule=dm)
```
### Debugging
```python
# Quick test with few batches
trainer = L.Trainer(fast_dev_run=10)
# Overfit on small data (debug model)
trainer = L.Trainer(overfit_batches=100)
# Limit batches for quick iteration
trainer = L.Trainer(
limit_train_batches=100,
limit_val_batches=50,
)
# Profile training
trainer = L.Trainer(profiler='simple') # or 'advanced'
```
## Best Practices
### Code Organization
1. **Separate concerns:**
- Model architecture in `__init__()`
- Training logic in `training_step()`
- Validation logic in `validation_step()`
- Data processing in DataModule
2. **Use `save_hyperparameters()`:**
```python
def __init__(self, lr, hidden_dim, dropout):
super().__init__()
self.save_hyperparameters() # Automatically saves all args
```
3. **Device-agnostic code:**
```python
# Avoid manual device placement
# BAD: tensor.cuda()
# GOOD: Lightning handles this automatically
# Create tensors on model's device
new_tensor = torch.zeros(10, device=self.device)
```
4. **Log comprehensively:**
```python
self.log('metric', value, on_step=True, on_epoch=True, prog_bar=True)
```
### Performance Optimization
1. **Use DataLoader best practices:**
```python
DataLoader(
dataset,
batch_size=32,
num_workers=4, # Multiple workers
pin_memory=True, # Faster GPU transfer
persistent_workers=True, # Keep workers alive
)
```
2. **Enable benchmark mode for fixed input sizes:**
```python
trainer = L.Trainer(benchmark=True)
```
3. **Use gradient clipping:**
```python
trainer = L.Trainer(gradient_clip_val=1.0)
```
4. **Enable mixed precision:**
```python
trainer = L.Trainer(precision='16-mixed')
```
### Distributed Training
1. **Sync metrics across devices:**
```python
self.log('metric', value, sync_dist=True)
```
2. **Rank-specific operations:**
```python
if self.trainer.is_global_zero:
# Only run on main process
self.save_artifacts()
```
3. **Use appropriate strategy:**
- Small models → `ddp`
- Large models → `fsdp` or `deepspeed`
## Resources
### Scripts
Executable templates for quick implementation:
- **`template_lightning_module.py`** - Complete LightningModule template with all hooks, logging, and optimization patterns
- **`template_datamodule.py`** - Complete DataModule template with data loading, splitting, and transformation patterns
- **`quick_trainer_setup.py`** - Helper functions to create Trainers with preset configurations (development, production, distributed)
### References
Comprehensive documentation for deep-dive learning:
- **`api_reference.md`** - Complete API reference covering LightningModule hooks, Trainer parameters, Callbacks, DataModules, Loggers, and common patterns
- **`distributed_training.md`** - In-depth guide for distributed training strategies (DDP, FSDP, DeepSpeed), multi-node setup, memory optimization, and troubleshooting
Load references when needing detailed information:
```python
# Example: Load distributed training reference
# See references/distributed_training.md for comprehensive distributed training guide
```
## Troubleshooting
### Common Issues
**Out of Memory:**
- Reduce batch size
- Use gradient accumulation
- Enable mixed precision (`precision='16-mixed'`)
- Use FSDP or DeepSpeed for large models
- Enable activation checkpointing
**Slow Training:**
- Use multiple DataLoader workers (`num_workers > 0`)
- Enable `pin_memory=True` and `persistent_workers=True`
- Enable `benchmark=True` for fixed input sizes
- Profile with `profiler='simple'`
**Validation Not Running:**
- Check `check_val_every_n_epoch` setting
- Ensure validation data provided
- Verify `validation_step()` implemented
**Checkpoints Not Saving:**
- Ensure `enable_checkpointing=True`
- Check `ModelCheckpoint` callback configuration
- Verify `monitor` metric exists in logs
## Additional Resources
- Official Documentation: https://lightning.ai/docs/pytorch/stable/
- GitHub: https://github.com/Lightning-AI/lightning
- Community: https://lightning.ai/community
When unclear about specific functionality, refer to `references/api_reference.md` for detailed API documentation or `references/distributed_training.md` for distributed training specifics.

View File

@@ -0,0 +1,490 @@
# PyTorch Lightning API Reference
Comprehensive reference for PyTorch Lightning core APIs, hooks, and components.
## LightningModule
The LightningModule is the core abstraction for organizing PyTorch code in Lightning.
### Essential Hooks
#### `__init__(self, *args, **kwargs)`
Initialize the model, define layers, and save hyperparameters.
```python
def __init__(self, learning_rate=1e-3, hidden_dim=128):
super().__init__()
self.save_hyperparameters() # Saves all args to self.hparams
self.model = nn.Sequential(...)
```
#### `forward(self, x)`
Define the forward pass for inference. Called by `predict_step` by default.
```python
def forward(self, x):
return self.model(x)
```
#### `training_step(self, batch, batch_idx)`
Define the training loop logic. Return loss for automatic optimization.
```python
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
```
#### `validation_step(self, batch, batch_idx)`
Define the validation loop logic. Model automatically in eval mode with no gradients.
```python
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss)
return loss
```
#### `test_step(self, batch, batch_idx)`
Define the test loop logic. Only runs when `trainer.test()` is called.
```python
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)
return loss
```
#### `predict_step(self, batch, batch_idx, dataloader_idx=0)`
Define prediction logic for inference. Defaults to calling `forward()`.
```python
def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, y = batch
return self(x)
```
#### `configure_optimizers(self)`
Return optimizer(s) and optional learning rate scheduler(s).
```python
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min')
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss",
"interval": "epoch",
"frequency": 1,
}
}
```
### Lifecycle Hooks
#### Epoch-Level Hooks
- `on_train_epoch_start()` - Called at the start of each training epoch
- `on_train_epoch_end()` - Called at the end of each training epoch
- `on_validation_epoch_start()` - Called at the start of validation epoch
- `on_validation_epoch_end()` - Called at the end of validation epoch
- `on_test_epoch_start()` - Called at the start of test epoch
- `on_test_epoch_end()` - Called at the end of test epoch
#### Batch-Level Hooks
- `on_train_batch_start(batch, batch_idx)` - Called before training batch
- `on_train_batch_end(outputs, batch, batch_idx)` - Called after training batch
- `on_validation_batch_start(batch, batch_idx)` - Called before validation batch
- `on_validation_batch_end(outputs, batch, batch_idx)` - Called after validation batch
#### Training Lifecycle
- `on_fit_start()` - Called at the start of fit
- `on_fit_end()` - Called at the end of fit
- `on_train_start()` - Called at the start of training
- `on_train_end()` - Called at the end of training
### Logging
#### `self.log(name, value, **kwargs)`
Log a metric to all configured loggers.
**Common Parameters:**
- `on_step` (bool) - Log at each batch step
- `on_epoch` (bool) - Log at the end of epoch (automatically aggregated)
- `prog_bar` (bool) - Display in progress bar
- `logger` (bool) - Send to logger
- `sync_dist` (bool) - Synchronize across all distributed processes
- `reduce_fx` (str) - Reduction function for distributed ("mean", "sum", etc.)
```python
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
```
#### `self.log_dict(dictionary, **kwargs)`
Log multiple metrics at once.
```python
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
self.log_dict(metrics, on_step=True, on_epoch=True)
```
### Device Management
- `self.device` - Current device (automatically managed)
- `self.to(device)` - Move model to device (usually handled automatically)
**Best Practice:** Create tensors on model's device:
```python
new_tensor = torch.zeros(10, device=self.device)
```
### Hyperparameter Management
#### `self.save_hyperparameters(*args, **kwargs)`
Automatically save init arguments to `self.hparams` and checkpoints.
```python
def __init__(self, learning_rate, hidden_dim):
super().__init__()
self.save_hyperparameters() # Saves all args
# Access via self.hparams.learning_rate, self.hparams.hidden_dim
```
---
## Trainer
The Trainer automates the training loop and engineering complexity.
### Core Parameters
#### Training Duration
- `max_epochs` (int) - Maximum number of epochs (defaults to 1000 only when neither `max_epochs` nor `max_steps` is set)
- `min_epochs` (int) - Minimum number of epochs
- `max_steps` (int) - Maximum number of optimizer steps
- `min_steps` (int) - Minimum number of optimizer steps
- `max_time` (str/dict) - Maximum training time ("DD:HH:MM:SS" or dict)
#### Hardware Configuration
- `accelerator` (str) - Hardware to use: "cpu", "gpu", "tpu", "auto"
- `devices` (int/list) - Number or specific device IDs: 1, 4, [0,2], "auto"
- `num_nodes` (int) - Number of GPU nodes for distributed training
- `strategy` (str) - Training strategy: "ddp", "fsdp", "deepspeed", etc.
#### Data Management
- `limit_train_batches` (int/float) - Limit training batches (0.0-1.0 for %, int for count)
- `limit_val_batches` (int/float) - Limit validation batches
- `limit_test_batches` (int/float) - Limit test batches
- `limit_predict_batches` (int/float) - Limit prediction batches
#### Validation
- `check_val_every_n_epoch` (int) - Run validation every N epochs
- `val_check_interval` (int/float) - Validate every N batches or fraction
- `num_sanity_val_steps` (int) - Validation steps before training (default: 2)
#### Optimization
- `gradient_clip_val` (float) - Clip gradients by value
- `gradient_clip_algorithm` (str) - "value" or "norm"
- `accumulate_grad_batches` (int) - Accumulate gradients over K batches
- `precision` (str) - Training precision: "32-true", "16-mixed", "bf16-mixed", "64-true"
#### Logging and Checkpointing
- `logger` (Logger/list) - Logger instance(s) or True/False
- `log_every_n_steps` (int) - Logging frequency
- `enable_checkpointing` (bool) - Enable automatic checkpointing
- `callbacks` (list) - List of callback instances
- `default_root_dir` (str) - Default path for logs and checkpoints
#### Debugging
- `fast_dev_run` (bool/int) - Run N batches for quick testing
- `overfit_batches` (int/float) - Overfit on limited data for debugging
- `detect_anomaly` (bool) - Enable PyTorch anomaly detection
- `profiler` (str/Profiler) - Profile training: "simple", "advanced", or custom
#### Performance
- `benchmark` (bool) - Enable cudnn.benchmark for performance
- `deterministic` (bool) - Enable deterministic training for reproducibility
- `sync_batchnorm` (bool) - Synchronize batch norm across GPUs
### Training Methods
#### `trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)`
Run the full training routine.
```python
trainer.fit(model, train_loader, val_loader)
# Or with DataModule
trainer.fit(model, datamodule=dm)
# Resume from checkpoint
trainer.fit(model, train_loader, val_loader, ckpt_path="path/to/checkpoint.ckpt")
```
#### `trainer.validate(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run validation independently.
```python
trainer.validate(model, val_loader)
```
#### `trainer.test(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run test evaluation.
```python
trainer.test(model, test_loader)
# Or load from checkpoint
trainer.test(ckpt_path="best_model.ckpt", datamodule=dm)
```
#### `trainer.predict(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run inference predictions.
```python
predictions = trainer.predict(model, predict_loader)
```
---
## LightningDataModule
Encapsulates all data processing logic in a reusable class.
### Core Methods
#### `prepare_data(self)`
Download and prepare data (called once on single process).
Do NOT set state here (no self.x = y).
```python
def prepare_data(self):
# Download datasets
datasets.MNIST(self.data_dir, train=True, download=True)
datasets.MNIST(self.data_dir, train=False, download=True)
```
#### `setup(self, stage=None)`
Load data and create splits (called on every process/GPU).
Setting state is OK here.
**stage parameter:** "fit", "validate", "test", or "predict"
```python
def setup(self, stage=None):
if stage == "fit" or stage is None:
full_dataset = datasets.MNIST(self.data_dir, train=True)
self.train_dataset, self.val_dataset = random_split(full_dataset, [55000, 5000])
if stage == "test" or stage is None:
self.test_dataset = datasets.MNIST(self.data_dir, train=False)
```
#### DataLoader Methods
- `train_dataloader(self)` - Return training DataLoader
- `val_dataloader(self)` - Return validation DataLoader
- `test_dataloader(self)` - Return test DataLoader
- `predict_dataloader(self)` - Return prediction DataLoader
```python
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=32, shuffle=True)
```
### Optional Methods
- `teardown(stage=None)` - Cleanup after training/testing
- `state_dict()` - Save state for checkpointing
- `load_state_dict(state_dict)` - Load state from checkpoint
---
## Callbacks
Extend training with modular, reusable functionality.
### Built-in Callbacks
#### ModelCheckpoint
Save model checkpoints based on monitored metrics.
```python
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
dirpath='checkpoints/',
filename='{epoch}-{val_loss:.2f}',
monitor='val_loss',
mode='min',
save_top_k=3,
save_last=True,
verbose=True,
)
```
**Key Parameters:**
- `monitor` - Metric to monitor
- `mode` - "min" or "max"
- `save_top_k` - Save top K models
- `save_last` - Always save last checkpoint
- `every_n_epochs` - Save every N epochs
#### EarlyStopping
Stop training when metric stops improving.
```python
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor='val_loss',
patience=10,
mode='min',
verbose=True,
)
```
#### LearningRateMonitor
Log learning rate values.
```python
from lightning.pytorch.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(logging_interval='epoch')
```
#### RichProgressBar
Display rich progress bar with metrics.
```python
from lightning.pytorch.callbacks import RichProgressBar
progress_bar = RichProgressBar()
```
### Custom Callbacks
Create custom callbacks by inheriting from `Callback`.
```python
from lightning.pytorch.callbacks import Callback
class MyCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training starting!")
def on_train_epoch_end(self, trainer, pl_module):
print(f"Epoch {trainer.current_epoch} ended")
def on_validation_end(self, trainer, pl_module):
val_loss = trainer.callback_metrics.get('val_loss')
print(f"Validation loss: {val_loss}")
```
**Common Hooks:**
- `on_train_start/end`
- `on_train_epoch_start/end`
- `on_validation_epoch_start/end`
- `on_test_epoch_start/end`
- `on_before_backward/on_after_backward`
- `on_before_optimizer_step`
---
## Loggers
Track experiments with various logging frameworks.
### TensorBoardLogger
```python
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger(save_dir='logs/', name='my_experiment')
trainer = Trainer(logger=logger)
```
### WandbLogger
```python
from lightning.pytorch.loggers import WandbLogger
logger = WandbLogger(project='my_project', name='experiment_1')
trainer = Trainer(logger=logger)
```
### MLFlowLogger
```python
from lightning.pytorch.loggers import MLFlowLogger
logger = MLFlowLogger(experiment_name='my_exp', tracking_uri='file:./ml-runs')
trainer = Trainer(logger=logger)
```
### CSVLogger
```python
from lightning.pytorch.loggers import CSVLogger
logger = CSVLogger(save_dir='logs/', name='my_experiment')
trainer = Trainer(logger=logger)
```
### Multiple Loggers
```python
loggers = [
TensorBoardLogger('logs/'),
CSVLogger('logs/'),
]
trainer = Trainer(logger=loggers)
```
---
## Common Patterns
### Reproducibility
```python
from lightning.pytorch import seed_everything
seed_everything(42, workers=True)
trainer = Trainer(deterministic=True)
```
### Mixed Precision Training
```python
trainer = Trainer(precision='16-mixed') # or 'bf16-mixed'
```
### Multi-GPU Training
```python
# Data parallel (DDP)
trainer = Trainer(accelerator='gpu', devices=4, strategy='ddp')
# Model parallel (FSDP)
trainer = Trainer(accelerator='gpu', devices=4, strategy='fsdp')
```
### Gradient Accumulation
```python
trainer = Trainer(accumulate_grad_batches=4) # Effective batch size = 4x
```
### Learning Rate Finding
```python
from lightning.pytorch.tuner import Tuner
trainer = Trainer()
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_dataloader)
model.hparams.learning_rate = lr_finder.suggestion()
```
### Loading from Checkpoint
```python
# Load model
model = MyLightningModule.load_from_checkpoint('checkpoint.ckpt')
# Resume training
trainer.fit(model, ckpt_path='checkpoint.ckpt')
```

View File

@@ -0,0 +1,508 @@
# Distributed and Model Parallel Training
Comprehensive guide for distributed training strategies in PyTorch Lightning.
## Overview
PyTorch Lightning provides seamless distributed training across multiple GPUs, machines, and TPUs with minimal code changes. The framework automatically handles the complexity of distributed training while keeping code device-agnostic and readable.
## Training Strategies
### Data Parallel (DDP - DistributedDataParallel)
**Best for:** Most models (< 500M parameters) where the full model fits in GPU memory.
**How it works:** Each GPU holds a complete copy of the model and trains on a different batch subset. Gradients are synchronized across GPUs during backward pass.
```python
# Single-node, multi-GPU
trainer = Trainer(
accelerator='gpu',
devices=4, # Use 4 GPUs
strategy='ddp',
)
# Multi-node, multi-GPU
trainer = Trainer(
accelerator='gpu',
devices=4, # GPUs per node
num_nodes=2, # Number of nodes
strategy='ddp',
)
```
**Advantages:**
- Most widely used and tested
- Works with most PyTorch code
- Good scaling efficiency
- No code changes required in LightningModule
**When to use:** Default choice for most distributed training scenarios.
### FSDP (Fully Sharded Data Parallel)
**Best for:** Large models (500M+ parameters) that don't fit in single GPU memory.
**How it works:** Shards model parameters, gradients, and optimizer states across GPUs. Each GPU only stores a subset of the model.
```python
trainer = Trainer(
accelerator='gpu',
devices=4,
strategy='fsdp',
)
# With configuration
from lightning.pytorch.strategies import FSDPStrategy
strategy = FSDPStrategy(
sharding_strategy="FULL_SHARD", # Full sharding
cpu_offload=False, # Whether to offload parameters to CPU (disabled here)
mixed_precision=torch.float16,
)
trainer = Trainer(
accelerator='gpu',
devices=4,
strategy=strategy,
)
```
**Sharding Strategies:**
- `FULL_SHARD` - Shard parameters, gradients, and optimizer states
- `SHARD_GRAD_OP` - Shard only gradients and optimizer states
- `NO_SHARD` - DDP-like (no sharding)
- `HYBRID_SHARD` - Shard within node, DDP across nodes
**Advanced FSDP Configuration:**
```python
from lightning.pytorch.strategies import FSDPStrategy
strategy = FSDPStrategy(
sharding_strategy="FULL_SHARD",
activation_checkpointing=True, # Save memory
cpu_offload=True, # Offload parameters to CPU
backward_prefetch="BACKWARD_PRE", # Prefetch strategy
forward_prefetch=True,
limit_all_gathers=True,
)
```
**When to use:**
- Models > 500M parameters
- Limited GPU memory
- Native PyTorch solution preferred
- Migrating from standalone PyTorch FSDP
### DeepSpeed
**Best for:** Cutting-edge features, massive models, or existing DeepSpeed users.
**How it works:** Comprehensive optimization library with multiple stages of memory and compute optimization.
```python
# Basic DeepSpeed
trainer = Trainer(
accelerator='gpu',
devices=4,
strategy='deepspeed',
precision='16-mixed',
)
# With configuration
from lightning.pytorch.strategies import DeepSpeedStrategy
strategy = DeepSpeedStrategy(
stage=2, # ZeRO Stage (1, 2, or 3)
offload_optimizer=True,
offload_parameters=True,
)
trainer = Trainer(
accelerator='gpu',
devices=4,
strategy=strategy,
)
```
**ZeRO Stages:**
- **Stage 1:** Shard optimizer states
- **Stage 2:** Shard optimizer states + gradients
- **Stage 3:** Shard optimizer states + gradients + parameters (like FSDP)
**With DeepSpeed Config File:**
```python
strategy = DeepSpeedStrategy(config="deepspeed_config.json")
```
Example `deepspeed_config.json`:
```json
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_bucket_size": 2e8,
"reduce_bucket_size": 2e8
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true
},
"fp16": {
"enabled": true
},
"gradient_clipping": 1.0
}
```
**When to use:**
- Need specific DeepSpeed features
- Maximum memory efficiency required
- Already familiar with DeepSpeed
- Training extremely large models
### DDP Spawn
**Note:** Generally avoid using `ddp_spawn`. Use `ddp` instead.
```python
trainer = Trainer(strategy='ddp_spawn') # Not recommended
```
**Issues with ddp_spawn:**
- Cannot return values from `.fit()`
- Pickling issues with unpicklable objects
- Slower than `ddp`
- More memory overhead
**When to use:** Only for debugging or if `ddp` doesn't work on your system.
## Multi-Node Training
### Basic Multi-Node Setup
```python
# On each node, run the same command
trainer = Trainer(
accelerator='gpu',
devices=4, # GPUs per node
num_nodes=8, # Total number of nodes
strategy='ddp',
)
```
### SLURM Cluster
Lightning automatically detects SLURM environment:
```python
trainer = Trainer(
accelerator='gpu',
devices=4,
num_nodes=8,
strategy='ddp',
)
```
**SLURM Submit Script:**
```bash
#!/bin/bash
#SBATCH --nodes=8
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --job-name=lightning_training
srun python train.py
```
### Manual Cluster Setup
```python
from lightning.pytorch.strategies import DDPStrategy
strategy = DDPStrategy(
cluster_environment='TorchElastic', # or 'SLURM', 'LSF', 'Kubeflow'
)
trainer = Trainer(
accelerator='gpu',
devices=4,
num_nodes=8,
strategy=strategy,
)
```
## Memory Optimization Techniques
### Gradient Accumulation
Simulate larger batch sizes without increasing memory:
```python
trainer = Trainer(
accumulate_grad_batches=4, # Accumulate 4 batches before optimizer step
)
# Variable accumulation by epoch
trainer = Trainer(
accumulate_grad_batches={
0: 8, # Epochs 0-4: accumulate 8 batches
5: 4, # Epochs 5+: accumulate 4 batches
}
)
```
### Activation Checkpointing
Trade computation for memory by recomputing activations during backward pass:
```python
# FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
class MyModule(L.LightningModule):
def configure_model(self):
# Wrap specific layers for activation checkpointing
self.model = MyTransformer()
apply_activation_checkpointing(
self.model,
checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(m, CheckpointImpl.NO_REENTRANT),
check_fn=lambda m: isinstance(m, TransformerBlock),
)
```
### Mixed Precision Training
Reduce memory usage and increase speed with mixed precision:
```python
# 16-bit mixed precision
trainer = Trainer(precision='16-mixed')
# BFloat16 mixed precision (more stable, requires newer GPUs)
trainer = Trainer(precision='bf16-mixed')
```
### CPU Offloading
Offload parameters or optimizer states to CPU:
```python
# FSDP with CPU offload
from lightning.pytorch.strategies import FSDPStrategy
strategy = FSDPStrategy(
cpu_offload=True, # Offload parameters to CPU
)
# DeepSpeed with CPU offload
from lightning.pytorch.strategies import DeepSpeedStrategy
strategy = DeepSpeedStrategy(
stage=3,
offload_optimizer=True,
offload_parameters=True,
)
```
## Performance Optimization
### Synchronize Batch Normalization
Synchronize batch norm statistics across GPUs:
```python
trainer = Trainer(
accelerator='gpu',
devices=4,
strategy='ddp',
sync_batchnorm=True, # Sync batch norm across GPUs
)
```
### Find Optimal Batch Size
```python
from lightning.pytorch.tuner import Tuner
trainer = Trainer()
tuner = Tuner(trainer)
# Auto-scale batch size
tuner.scale_batch_size(model, mode="power") # or "binsearch"
```
### Gradient Clipping
Prevent gradient explosion in distributed training:
```python
trainer = Trainer(
gradient_clip_val=1.0,
gradient_clip_algorithm='norm', # or 'value'
)
```
### Benchmark Mode
Enable cudnn.benchmark for consistent input sizes:
```python
trainer = Trainer(
benchmark=True, # Optimize for consistent input sizes
)
```
## Distributed Data Loading
### Automatic Distributed Sampling
Lightning automatically handles distributed sampling:
```python
# No changes needed - Lightning handles this automatically
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=32,
shuffle=True, # Lightning converts to DistributedSampler
)
```
### Manual Control
```python
# Disable automatic distributed sampler
trainer = Trainer(
use_distributed_sampler=False,
)
# Manual distributed sampler
from torch.utils.data.distributed import DistributedSampler
def train_dataloader(self):
sampler = DistributedSampler(self.train_dataset)
return DataLoader(
self.train_dataset,
batch_size=32,
sampler=sampler,
)
```
### Data Loading Best Practices
```python
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=32,
num_workers=4, # Use multiple workers
pin_memory=True, # Faster CPU-GPU transfer
persistent_workers=True, # Keep workers alive between epochs
)
```
## Common Patterns
### Logging in Distributed Training
```python
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# Automatically syncs across processes
self.log('train_loss', loss, sync_dist=True)
return loss
```
### Rank-Specific Operations
```python
def training_step(self, batch, batch_idx):
# Run only on rank 0 (main process)
if self.trainer.is_global_zero:
print("This only prints once across all processes")
# Get current rank
rank = self.trainer.global_rank
world_size = self.trainer.world_size
return loss
```
### Barrier Synchronization
```python
def on_train_epoch_end(self):
# Wait for all processes
self.trainer.strategy.barrier()
# Now all processes are synchronized
if self.trainer.is_global_zero:
# Save something only once
self.save_artifacts()
```
## Troubleshooting
### Common Issues
**1. Out of Memory:**
- Reduce batch size
- Enable gradient accumulation
- Use FSDP or DeepSpeed
- Enable activation checkpointing
- Use mixed precision
**2. Slow Training:**
- Check data loading (use `num_workers > 0`)
- Enable `pin_memory=True` and `persistent_workers=True`
- Use `benchmark=True` for consistent input sizes
- Profile with `profiler='simple'`
**3. Hanging:**
- Ensure all processes execute same collectives
- Check for `if` statements that differ across ranks
- Use barrier synchronization when needed
**4. Inconsistent Results:**
- Set `deterministic=True`
- Use `seed_everything()`
- Ensure proper gradient synchronization
### Debugging Distributed Training
```python
# Test with single GPU first
trainer = Trainer(accelerator='gpu', devices=1)
# Then test with 2 GPUs
trainer = Trainer(accelerator='gpu', devices=2, strategy='ddp')
# Use fast_dev_run for quick testing
trainer = Trainer(
accelerator='gpu',
devices=2,
strategy='ddp',
fast_dev_run=10, # Run 10 batches only
)
```
## Strategy Selection Guide
| Model Size | Available Memory | Recommended Strategy |
|-----------|------------------|---------------------|
| < 500M params | Fits in 1 GPU | Single GPU |
| < 500M params | Fits across GPUs | DDP |
| 500M - 3B params | Limited memory | FSDP or DeepSpeed Stage 2 |
| 3B+ params | Very limited memory | FSDP or DeepSpeed Stage 3 |
| Any size | Maximum efficiency | DeepSpeed with offloading |
| Multiple nodes | Any | DDP (< 500M) or FSDP/DeepSpeed (> 500M) |

View File

@@ -0,0 +1,262 @@
"""
Helper script to quickly set up a PyTorch Lightning Trainer with common configurations.
This script provides preset configurations for different training scenarios
and makes it easy to create a Trainer with best practices.
"""
import lightning as L
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
RichProgressBar,
ModelSummary,
)
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
def _checkpoint_callback(log_dir: str, experiment_name: str, **overrides):
    """Build the standard val_loss-monitoring ModelCheckpoint.

    All presets share the same dirpath/filename/monitor/save_last setup;
    ``overrides`` supplies per-preset extras (mode, save_top_k, verbose, ...).
    """
    return ModelCheckpoint(
        dirpath=f"{log_dir}/{experiment_name}/checkpoints",
        filename="{epoch}-{val_loss:.2f}",
        monitor="val_loss",
        save_last=True,
        **overrides,
    )


def create_trainer(
    preset: str = "default",
    max_epochs: int = 100,
    accelerator: str = "auto",
    devices: int = 1,
    log_dir: str = "./logs",
    experiment_name: str = "lightning_experiment",
    enable_checkpointing: bool = True,
    enable_early_stopping: bool = True,
    **kwargs
):
    """
    Create a Lightning Trainer with preset configurations.

    Args:
        preset: Configuration preset - "default", "fast_dev", "production", "distributed"
        max_epochs: Maximum number of training epochs
        accelerator: Device to use ("auto", "gpu", "cpu", "tpu")
        devices: Number of devices to use
        log_dir: Directory for logs and checkpoints
        experiment_name: Name for the experiment
        enable_checkpointing: Whether to enable model checkpointing
        enable_early_stopping: Whether to enable early stopping
            (only honored by the "production" preset)
        **kwargs: Additional arguments passed to Trainer; these are merged
            last and therefore override any preset value

    Returns:
        Configured Lightning Trainer instance
    """
    callbacks = []
    logger_list = []

    # Configure based on preset
    if preset == "fast_dev":
        # Fast development run - minimal epochs, quick debugging
        config = {
            "fast_dev_run": False,
            # Honor an explicitly reduced epoch budget (create_debugging_trainer
            # passes max_epochs=1, which was previously ignored by this preset);
            # cap at 3 so the preset stays fast by default.
            "max_epochs": min(max_epochs, 3),
            "limit_train_batches": 100,
            "limit_val_batches": 50,
            "log_every_n_steps": 10,
            "enable_progress_bar": True,
            "enable_model_summary": True,
        }
    elif preset == "production":
        # Production-ready configuration with all bells and whistles
        config = {
            "max_epochs": max_epochs,
            "precision": "16-mixed",
            "gradient_clip_val": 1.0,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
            "enable_model_summary": True,
            "deterministic": True,
            "benchmark": True,
        }
        if enable_checkpointing:
            callbacks.append(
                _checkpoint_callback(
                    log_dir, experiment_name, mode="min", save_top_k=3, verbose=True
                )
            )
        if enable_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=10,
                    mode="min",
                    verbose=True,
                )
            )
        # Track the learning-rate schedule alongside the metrics
        callbacks.append(LearningRateMonitor(logging_interval="epoch"))
        logger_list.append(
            TensorBoardLogger(
                save_dir=log_dir,
                name=experiment_name,
                version=None,
            )
        )
    elif preset == "distributed":
        # Distributed training configuration
        config = {
            "max_epochs": max_epochs,
            "strategy": "ddp",
            "precision": "16-mixed",
            "sync_batchnorm": True,
            "use_distributed_sampler": True,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
        }
        if enable_checkpointing:
            callbacks.append(
                _checkpoint_callback(log_dir, experiment_name, mode="min", save_top_k=3)
            )
    else:  # default
        # Default configuration - balanced for most use cases
        config = {
            "max_epochs": max_epochs,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
            "enable_model_summary": True,
        }
        if enable_checkpointing:
            callbacks.append(_checkpoint_callback(log_dir, experiment_name))
        logger_list.append(
            CSVLogger(
                save_dir=log_dir,
                name=experiment_name,
            )
        )

    # Add a rich progress bar whenever the preset enables progress output
    if config.get("enable_progress_bar", True):
        callbacks.append(RichProgressBar())

    # kwargs are merged last so callers can override any preset value
    final_config = {
        **config,
        "accelerator": accelerator,
        "devices": devices,
        "callbacks": callbacks,
        "logger": logger_list if logger_list else True,
        **kwargs,
    }
    return L.Trainer(**final_config)
def create_debugging_trainer():
    """Create a trainer optimized for debugging."""
    # Tiny run: one epoch, a handful of batches, and a short sanity check.
    debug_overrides = {
        "max_epochs": 1,
        "limit_train_batches": 10,
        "limit_val_batches": 5,
        "num_sanity_val_steps": 2,
    }
    return create_trainer(preset="fast_dev", **debug_overrides)
def create_gpu_trainer(num_gpus: int = 1, precision: str = "16-mixed"):
    """Create a trainer optimized for GPU training."""
    # Production preset plus explicit GPU placement; precision is forwarded
    # so callers can opt into e.g. "bf16-mixed".
    gpu_settings = {
        "accelerator": "gpu",
        "devices": num_gpus,
        "precision": precision,
    }
    return create_trainer(preset="production", **gpu_settings)
def create_distributed_trainer(num_gpus: int = 2, num_nodes: int = 1):
    """Create a trainer for distributed training across multiple GPUs."""
    # DDP across `num_gpus` devices on each of `num_nodes` machines.
    return create_trainer(
        preset="distributed",
        strategy="ddp",
        accelerator="gpu",
        num_nodes=num_nodes,
        devices=num_gpus,
    )
# Example usage
if __name__ == "__main__":
    print("Creating different trainer configurations...\n")

    # 1. Default trainer
    print("1. Default trainer:")
    t_default = create_trainer(preset="default", max_epochs=50)
    print(f" Max epochs: {t_default.max_epochs}")
    print(f" Accelerator: {t_default.accelerator}")
    print(f" Callbacks: {len(t_default.callbacks)}")
    print()

    # 2. Fast development trainer
    print("2. Fast development trainer:")
    t_dev = create_trainer(preset="fast_dev")
    print(f" Max epochs: {t_dev.max_epochs}")
    print(f" Train batches limit: {t_dev.limit_train_batches}")
    print()

    # 3. Production trainer
    print("3. Production trainer:")
    t_prod = create_trainer(
        preset="production",
        max_epochs=100,
        experiment_name="my_experiment",
    )
    print(f" Max epochs: {t_prod.max_epochs}")
    print(f" Precision: {t_prod.precision}")
    print(f" Callbacks: {len(t_prod.callbacks)}")
    print()

    # 4. Debugging trainer
    print("4. Debugging trainer:")
    t_debug = create_debugging_trainer()
    print(f" Max epochs: {t_debug.max_epochs}")
    print(f" Train batches: {t_debug.limit_train_batches}")
    print()

    # 5. GPU trainer
    print("5. GPU trainer:")
    t_gpu = create_gpu_trainer(num_gpus=1)
    print(f" Accelerator: {t_gpu.accelerator}")
    print(f" Precision: {t_gpu.precision}")
    print()

    print("All trainer configurations created successfully!")

View File

@@ -0,0 +1,221 @@
"""
Template for creating a PyTorch Lightning DataModule.
This template includes all common hooks and patterns for organizing
data processing workflows with best practices.
"""
import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split
import torch
class TemplateDataset(Dataset):
    """Example dataset - replace with your actual dataset."""

    def __init__(self, data, targets, transform=None):
        # Parallel sequences of inputs and labels, plus an optional
        # per-sample transform applied lazily on access.
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        # One sample per entry in ``data``.
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (possibly transformed) sample/label pair at ``idx``."""
        sample = self.data[idx]
        label = self.targets[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
class TemplateDataModule(L.LightningDataModule):
    """Template DataModule with all common hooks and patterns.

    Organizes data download (``prepare_data``), dataset/split creation
    (``setup``) and dataloader construction for the train/val/test/predict
    stages.
    """

    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: tuple = (0.8, 0.2),
        seed: int = 42,
        pin_memory: bool = True,
        persistent_workers: bool = True,
    ):
        """Store configuration; datasets are created lazily in ``setup``.

        Args:
            data_dir: Root directory for raw/downloaded data.
            batch_size: Samples per batch for every dataloader.
            num_workers: Dataloader worker processes (0 disables multiprocessing).
            train_val_split: (train_fraction, val_fraction) used by ``setup``.
            seed: Seed that makes the train/val split deterministic.
            pin_memory: Pin host memory for faster CPU->GPU transfer.
            persistent_workers: Keep workers alive between epochs
                (only applied when ``num_workers > 0``).
        """
        super().__init__()
        # Save hyperparameters so they are logged and checkpointed
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.seed = seed
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        # Placeholders for datasets (populated by setup())
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None
        # Optional per-stage transforms; assign before calling setup()
        self.train_transform = None
        self.val_transform = None
        self.test_transform = None

    def prepare_data(self):
        """
        Download and prepare data (called only on 1 GPU/TPU in distributed settings).
        Use this for downloading, tokenizing, etc. Do NOT set state here (no self.x = y).
        """
        # Example: Download datasets
        # datasets.MNIST(self.data_dir, train=True, download=True)
        # datasets.MNIST(self.data_dir, train=False, download=True)
        pass

    def setup(self, stage: str = None):
        """
        Load data and create train/val/test splits (called on every GPU/TPU in distributed).
        Use this for splitting, creating datasets, etc. Setting state is OK here (self.x = y).

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'
        """
        # Fit stage: setup training and validation datasets
        if stage == "fit" or stage is None:
            # Dummy data for the template - replace with real loading, e.g.
            # datasets.MNIST(self.data_dir, train=True, ...)
            full_data = torch.randn(1000, 784)
            full_targets = torch.randint(0, 10, (1000,))
            train_size = int(len(full_data) * self.train_val_split[0])
            val_size = len(full_data) - train_size
            # Build two dataset views over the SAME tensors, each with its own
            # transform. NOTE: splitting a single dataset and then reassigning
            # ``subset.dataset.transform`` (the previous approach) is a bug:
            # both subsets returned by random_split share one underlying
            # dataset object, so the train subset would silently pick up the
            # validation transform as well.
            train_view = TemplateDataset(
                full_data, full_targets, transform=self.train_transform
            )
            val_view = TemplateDataset(
                full_data,
                full_targets,
                # Fall back to the train transform when no val transform is set
                transform=self.val_transform if self.val_transform else self.train_transform,
            )
            # Identical seeds produce identical permutations, so the two
            # splits are consistent: train and val indices never overlap.
            self.train_dataset, _ = random_split(
                train_view,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.seed),
            )
            _, self.val_dataset = random_split(
                val_view,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.seed),
            )
        # Test stage: setup test dataset
        if stage == "test" or stage is None:
            # Dummy test data for the template
            test_data = torch.randn(200, 784)
            test_targets = torch.randint(0, 10, (200,))
            self.test_dataset = TemplateDataset(
                test_data, test_targets, transform=self.test_transform
            )
        # Predict stage: setup prediction dataset
        if stage == "predict" or stage is None:
            # Dummy predict data for the template (targets are placeholders)
            predict_data = torch.randn(100, 784)
            predict_targets = torch.zeros(100, dtype=torch.long)
            self.predict_dataset = TemplateDataset(predict_data, predict_targets)

    def _make_dataloader(self, dataset, shuffle: bool):
        """Shared DataLoader construction for all stages."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # persistent_workers requires num_workers > 0
            persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
        )

    def train_dataloader(self):
        """Return training dataloader (shuffled)."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return validation dataloader."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return test dataloader."""
        return self._make_dataloader(self.test_dataset, shuffle=False)

    def predict_dataloader(self):
        """Return prediction dataloader."""
        return self._make_dataloader(self.predict_dataset, shuffle=False)

    def teardown(self, stage: str = None):
        """Clean up after fit, validate, test, or predict."""
        # Example: close database connections, clear caches, etc.
        pass

    def state_dict(self):
        """Save state for checkpointing."""
        # Return anything you want to save in the checkpoint
        return {}

    def load_state_dict(self, state_dict):
        """Load state from checkpoint."""
        # Restore state from checkpoint
        pass
# Example usage
if __name__ == "__main__":
    # Build and exercise the template datamodule end to end.
    dm = TemplateDataModule(
        data_dir="./data",
        batch_size=32,
        num_workers=4,
        train_val_split=(0.8, 0.2),
    )
    dm.prepare_data()
    dm.setup("fit")

    # Pull the dataloaders for the fit stage
    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()
    print("Template DataModule created successfully!")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Batch size: {dm.batch_size}")

    # Sanity-check a single batch
    inputs, labels = next(iter(train_loader))
    print(f"Batch shape: {inputs.shape}, {labels.shape}")

View File

@@ -0,0 +1,215 @@
"""
Template for creating a PyTorch Lightning LightningModule.
This template includes all common hooks and patterns for building
a Lightning model with best practices.
"""
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
class TemplateLightningModule(L.LightningModule):
    """Template LightningModule with all common hooks and patterns.

    A small MLP classifier (input -> hidden -> hidden -> output) with
    cross-entropy loss, demonstrating the standard train/val/test/predict
    steps and optimizer/scheduler configuration.
    """

    def __init__(
        self,
        # Model architecture parameters
        input_dim: int = 784,
        hidden_dim: int = 128,
        output_dim: int = 10,
        # Optimization parameters
        learning_rate: float = 1e-3,
        optimizer_type: str = "adam",
        scheduler_type: str = None,
        # Other hyperparameters
        dropout: float = 0.1,
    ):
        """Build the model.

        Args:
            input_dim: Flattened input feature size.
            hidden_dim: Width of the two hidden layers.
            output_dim: Number of classes.
            learning_rate: Optimizer learning rate.
            optimizer_type: "adam" or "sgd".
            scheduler_type: None, "reduce_on_plateau", or "step".
            dropout: Dropout probability applied after each hidden layer.
        """
        super().__init__()
        # Save hyperparameters for checkpointing and logging (exposed as self.hparams)
        self.save_hyperparameters()
        # Define model architecture
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )
        # Define loss function
        self.criterion = nn.CrossEntropyLoss()
        # For tracking validation outputs (optional)
        self.validation_step_outputs = []

    def forward(self, x):
        """Forward pass for inference."""
        return self.model(x)

    def _shared_metrics(self, batch):
        """Compute (loss, accuracy) for a labelled batch.

        Shared by the training, validation, and test steps, which previously
        duplicated this logic verbatim.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Training step - called for each training batch."""
        loss, acc = self._shared_metrics(batch)
        # Log per-step and per-epoch metrics to the progress bar
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step - called for each validation batch (model in eval mode)."""
        loss, acc = self._shared_metrics(batch)
        # Epoch-level logging only
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        # Optional: store outputs for epoch-level processing
        self.validation_step_outputs.append({"loss": loss, "acc": acc})
        return loss

    def on_validation_epoch_end(self):
        """Called at the end of validation epoch."""
        # Optional: process all validation outputs
        if self.validation_step_outputs:
            avg_loss = torch.stack([x["loss"] for x in self.validation_step_outputs]).mean()
            avg_acc = torch.stack([x["acc"] for x in self.validation_step_outputs]).mean()
            # Log epoch-level metrics if needed
            # self.log("val_epoch_loss", avg_loss)
            # self.log("val_epoch_acc", avg_acc)
            # Clear outputs so they do not accumulate across epochs
            self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        """Test step - called for each test batch."""
        loss, acc = self._shared_metrics(batch)
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Prediction step - returns the predicted class indices.

        NOTE(review): assumes predict batches are (x, y) pairs like the other
        stages; adjust the unpacking if your predict dataloader yields x only.
        """
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        return preds

    def configure_optimizers(self):
        """Configure optimizer and (optionally) a learning-rate scheduler.

        Returns either a bare optimizer, or an optimizer/scheduler dict per
        the Lightning convention.

        Raises:
            ValueError: If ``optimizer_type`` is not "adam" or "sgd".
        """
        # Create optimizer
        if self.hparams.optimizer_type.lower() == "adam":
            optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
        elif self.hparams.optimizer_type.lower() == "sgd":
            optimizer = SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=0.9)
        else:
            raise ValueError(f"Unknown optimizer: {self.hparams.optimizer_type}")
        # Configure with scheduler if specified
        if self.hparams.scheduler_type:
            if self.hparams.scheduler_type.lower() == "reduce_on_plateau":
                # Plateau scheduler needs a metric to monitor
                scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
                return {
                    "optimizer": optimizer,
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        "monitor": "val_loss",
                        "interval": "epoch",
                        "frequency": 1,
                    }
                }
            elif self.hparams.scheduler_type.lower() == "step":
                scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
                return {
                    "optimizer": optimizer,
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        "interval": "epoch",
                        "frequency": 1,
                    }
                }
        return optimizer

    # Optional: Additional hooks for custom behavior
    def on_train_start(self):
        """Called at the beginning of training."""
        pass

    def on_train_epoch_start(self):
        """Called at the beginning of each training epoch."""
        pass

    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""
        pass

    def on_train_end(self):
        """Called at the end of training."""
        pass
# Example usage
if __name__ == "__main__":
    # Instantiate the template module with explicit hyperparameters
    hparams = {
        "input_dim": 784,
        "hidden_dim": 128,
        "output_dim": 10,
        "learning_rate": 1e-3,
        "optimizer_type": "adam",
        "scheduler_type": "reduce_on_plateau",
    }
    model = TemplateLightningModule(**hparams)

    # A minimal trainer to pair with the module
    trainer = L.Trainer(
        max_epochs=10,
        accelerator="auto",
        devices=1,
        log_every_n_steps=50,
    )

    # Note: You would need to provide dataloaders
    # trainer.fit(model, train_dataloader, val_dataloader)
    print("Template LightningModule created successfully!")
    print(f"Model hyperparameters: {model.hparams}")