mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-28 07:33:45 +08:00
Add more scientific skills
This commit is contained in:
660
scientific-packages/pytorch-lightning/SKILL.md
Normal file
660
scientific-packages/pytorch-lightning/SKILL.md
Normal file
@@ -0,0 +1,660 @@
|
||||
---
|
||||
name: pytorch-lightning
|
||||
description: Comprehensive toolkit for PyTorch Lightning, a deep learning framework for organizing PyTorch code. Use this skill when working with PyTorch Lightning for training deep learning models, implementing LightningModules, configuring Trainers, setting up distributed training, creating DataModules, or converting existing PyTorch code to Lightning format. The skill provides templates, reference documentation, and best practices for efficient deep learning workflows.
|
||||
---
|
||||
|
||||
# PyTorch Lightning
|
||||
|
||||
## Overview
|
||||
|
||||
PyTorch Lightning is a deep learning framework that organizes PyTorch code to decouple research from engineering. It automates training loop complexity (multi-GPU, mixed precision, checkpointing, logging) while maintaining full flexibility over model architecture and training logic.
|
||||
|
||||
**Core Philosophy:** Separate concerns
|
||||
- **LightningModule** - Research code (model architecture, training logic)
|
||||
- **Trainer** - Engineering automation (hardware, optimization, logging)
|
||||
- **DataModule** - Data processing (downloading, loading, transforms)
|
||||
- **Callbacks** - Non-essential functionality (checkpointing, early stopping)
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use this skill when:
|
||||
- Building or training deep learning models with PyTorch
|
||||
- Converting existing PyTorch code to Lightning structure
|
||||
- Setting up distributed training across multiple GPUs or nodes
|
||||
- Implementing custom training loops with validation and testing
|
||||
- Organizing data processing pipelines
|
||||
- Configuring experiment logging and model checkpointing
|
||||
- Optimizing training performance and memory usage
|
||||
- Working with large models requiring model parallelism
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Lightning Workflow
|
||||
|
||||
1. **Define a LightningModule** (organize your model)
|
||||
2. **Create a DataModule or DataLoaders** (organize your data)
|
||||
3. **Configure a Trainer** (automate training)
|
||||
4. **Train** with `trainer.fit()`
|
||||
|
||||
### Minimal Example
|
||||
|
||||
```python
|
||||
import lightning as L
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
# 1. Define LightningModule
|
||||
class SimpleModel(L.LightningModule):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = nn.functional.mse_loss(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# 2. Prepare data
|
||||
train_data = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
|
||||
train_loader = DataLoader(train_data, batch_size=32)
|
||||
|
||||
# 3. Create Trainer
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='auto')
|
||||
|
||||
# 4. Train
|
||||
model = SimpleModel(input_dim=10, output_dim=1)
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### 1. Creating a LightningModule
|
||||
|
||||
Structure model code by implementing essential hooks:
|
||||
|
||||
**Template:** Use `scripts/template_lightning_module.py` as a starting point.
|
||||
|
||||
```python
|
||||
class MyLightningModule(L.LightningModule):
|
||||
def __init__(self, hyperparameters):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Save for checkpointing
|
||||
self.model = YourModel()
|
||||
|
||||
def forward(self, x):
|
||||
"""Inference forward pass."""
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Define training loop logic."""
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = self.compute_loss(y_hat, y)
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
"""Define validation logic."""
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = self.compute_loss(y_hat, y)
|
||||
self.log('val_loss', loss, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Return optimizer and optional scheduler."""
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val_loss",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Use `self.save_hyperparameters()` to automatically save init args
|
||||
- Use `self.log()` to track metrics across loggers
|
||||
- Return loss from training_step for automatic optimization
|
||||
- Keep model architecture separate from training logic
|
||||
|
||||
### 2. Creating a DataModule
|
||||
|
||||
Organize all data processing in a reusable module:
|
||||
|
||||
**Template:** Use `scripts/template_datamodule.py` as a starting point.
|
||||
|
||||
```python
|
||||
class MyDataModule(L.LightningDataModule):
|
||||
def __init__(self, data_dir, batch_size=32):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
def prepare_data(self):
|
||||
"""Download data (called once, single process)."""
|
||||
# Download datasets, tokenize, etc.
|
||||
pass
|
||||
|
||||
def setup(self, stage=None):
|
||||
"""Create datasets (called on every process)."""
|
||||
if stage == 'fit' or stage is None:
|
||||
# Create train/val datasets
|
||||
self.train_dataset = ...
|
||||
self.val_dataset = ...
|
||||
|
||||
if stage == 'test' or stage is None:
|
||||
# Create test dataset
|
||||
self.test_dataset = ...
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `prepare_data()` for downloading (single process)
|
||||
- `setup()` for creating datasets (every process)
|
||||
- Use `stage` parameter to separate fit/test logic
|
||||
- Makes data code reusable across projects
|
||||
|
||||
### 3. Configuring the Trainer
|
||||
|
||||
The Trainer automates training complexity:
|
||||
|
||||
**Helper:** Use `scripts/quick_trainer_setup.py` for preset configurations.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
trainer = L.Trainer(
|
||||
# Training duration
|
||||
max_epochs=100,
|
||||
|
||||
# Hardware
|
||||
accelerator='auto', # 'cpu', 'gpu', 'tpu'
|
||||
devices=1, # Number of devices or specific IDs
|
||||
|
||||
# Optimization
|
||||
precision='16-mixed', # Mixed precision training
|
||||
gradient_clip_val=1.0,
|
||||
accumulate_grad_batches=4, # Gradient accumulation
|
||||
|
||||
# Validation
|
||||
check_val_every_n_epoch=1,
|
||||
val_check_interval=1.0, # Validate every epoch
|
||||
|
||||
# Logging
|
||||
log_every_n_steps=50,
|
||||
logger=TensorBoardLogger('logs/'),
|
||||
|
||||
# Callbacks
|
||||
callbacks=[
|
||||
ModelCheckpoint(monitor='val_loss', mode='min'),
|
||||
EarlyStopping(monitor='val_loss', patience=10),
|
||||
],
|
||||
|
||||
# Debugging
|
||||
fast_dev_run=False, # Quick test with few batches
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
```
|
||||
|
||||
**Common Presets:**
|
||||
|
||||
```python
|
||||
from scripts.quick_trainer_setup import create_trainer
|
||||
|
||||
# Development preset (fast debugging)
|
||||
trainer = create_trainer(preset='fast_dev', max_epochs=3)
|
||||
|
||||
# Production preset (full features)
|
||||
trainer = create_trainer(preset='production', max_epochs=100)
|
||||
|
||||
# Distributed preset (multi-GPU)
|
||||
trainer = create_trainer(preset='distributed', devices=4)
|
||||
```
|
||||
|
||||
### 4. Training and Evaluation
|
||||
|
||||
```python
|
||||
# Training
|
||||
trainer.fit(model, datamodule=dm)
|
||||
# Or with dataloaders
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Resume from checkpoint
|
||||
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
|
||||
|
||||
# Testing
|
||||
trainer.test(model, datamodule=dm)
|
||||
# Or load best checkpoint
|
||||
trainer.test(ckpt_path='best', datamodule=dm)
|
||||
|
||||
# Prediction
|
||||
predictions = trainer.predict(model, predict_loader)
|
||||
|
||||
# Validation only
|
||||
trainer.validate(model, datamodule=dm)
|
||||
```
|
||||
|
||||
### 5. Distributed Training
|
||||
|
||||
Lightning handles distributed training automatically:
|
||||
|
||||
```python
|
||||
# Single machine, multiple GPUs (Data Parallel)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='ddp', # DistributedDataParallel
|
||||
)
|
||||
|
||||
# Multiple machines, multiple GPUs
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4, # GPUs per node
|
||||
num_nodes=8, # Number of machines
|
||||
strategy='ddp',
|
||||
)
|
||||
|
||||
# Large models (Model Parallel with FSDP)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='fsdp', # Fully Sharded Data Parallel
|
||||
)
|
||||
|
||||
# Large models (Model Parallel with DeepSpeed)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='deepspeed_stage_2',
|
||||
precision='16-mixed',
|
||||
)
|
||||
```
|
||||
|
||||
**For detailed distributed training guide, see:** `references/distributed_training.md`
|
||||
|
||||
**Strategy Selection:**
|
||||
- Models < 500M params → Use `ddp`
|
||||
- Models > 500M params → Use `fsdp` or `deepspeed`
|
||||
- Maximum memory efficiency → Use DeepSpeed Stage 3 with offloading
|
||||
- Native PyTorch → Use `fsdp`
|
||||
- Cutting-edge features → Use `deepspeed`
|
||||
|
||||
### 6. Callbacks
|
||||
|
||||
Extend training with modular functionality:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import (
|
||||
ModelCheckpoint,
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
RichProgressBar,
|
||||
)
|
||||
|
||||
callbacks = [
|
||||
# Save best models
|
||||
ModelCheckpoint(
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
filename='{epoch}-{val_loss:.2f}',
|
||||
),
|
||||
|
||||
# Stop when no improvement
|
||||
EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
mode='min',
|
||||
),
|
||||
|
||||
# Log learning rate
|
||||
LearningRateMonitor(logging_interval='epoch'),
|
||||
|
||||
# Rich progress bar
|
||||
RichProgressBar(),
|
||||
]
|
||||
|
||||
trainer = L.Trainer(callbacks=callbacks)
|
||||
```
|
||||
|
||||
**Custom Callbacks:**
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
class MyCustomCallback(Callback):
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
# Custom logic at end of each epoch
|
||||
print(f"Epoch {trainer.current_epoch} completed")
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
val_loss = trainer.callback_metrics.get('val_loss')
|
||||
# Custom validation logic
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. Logging
|
||||
|
||||
Track experiments with various loggers:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.loggers import (
|
||||
TensorBoardLogger,
|
||||
WandbLogger,
|
||||
CSVLogger,
|
||||
MLFlowLogger,
|
||||
)
|
||||
|
||||
# Single logger
|
||||
logger = TensorBoardLogger('logs/', name='my_experiment')
|
||||
|
||||
# Multiple loggers
|
||||
loggers = [
|
||||
TensorBoardLogger('logs/'),
|
||||
WandbLogger(project='my_project'),
|
||||
CSVLogger('logs/'),
|
||||
]
|
||||
|
||||
trainer = L.Trainer(logger=loggers)
|
||||
```
|
||||
|
||||
**Logging in LightningModule:**
|
||||
|
||||
```python
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
# Log single metric
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
|
||||
|
||||
# Log multiple metrics
|
||||
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
|
||||
self.log_dict(metrics, on_step=True, on_epoch=True)
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
## Converting Existing PyTorch Code
|
||||
|
||||
### Standard PyTorch → Lightning
|
||||
|
||||
**Before (PyTorch):**
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
x, y = batch
|
||||
y_hat = model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**After (Lightning):**
|
||||
```python
|
||||
class MyLightningModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters())
|
||||
|
||||
trainer = L.Trainer(max_epochs=num_epochs)
|
||||
model = MyLightningModel()
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**Key Changes:**
|
||||
1. Wrap model in LightningModule
|
||||
2. Move training loop logic to `training_step()`
|
||||
3. Move optimizer setup to `configure_optimizers()`
|
||||
4. Replace manual loop with `trainer.fit()`
|
||||
5. Lightning handles: `.zero_grad()`, `.backward()`, `.step()`, device placement
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Reproducibility
|
||||
|
||||
```python
|
||||
from lightning.pytorch import seed_everything
|
||||
|
||||
# Set seed for reproducibility
|
||||
seed_everything(42, workers=True)
|
||||
|
||||
trainer = L.Trainer(deterministic=True)
|
||||
```
|
||||
|
||||
### Mixed Precision Training
|
||||
|
||||
```python
|
||||
# 16-bit mixed precision
|
||||
trainer = L.Trainer(precision='16-mixed')
|
||||
|
||||
# BFloat16 mixed precision (more stable)
|
||||
trainer = L.Trainer(precision='bf16-mixed')
|
||||
```
|
||||
|
||||
### Gradient Accumulation
|
||||
|
||||
```python
|
||||
# Effective batch size = 4x actual batch size
|
||||
trainer = L.Trainer(accumulate_grad_batches=4)
|
||||
```
|
||||
|
||||
### Learning Rate Finding
|
||||
|
||||
```python
|
||||
from lightning.pytorch.tuner import Tuner
|
||||
|
||||
trainer = L.Trainer()
|
||||
tuner = Tuner(trainer)
|
||||
|
||||
# Find optimal learning rate
|
||||
lr_finder = tuner.lr_find(model, train_dataloader)
|
||||
model.hparams.learning_rate = lr_finder.suggestion()
|
||||
|
||||
# Find optimal batch size
|
||||
tuner.scale_batch_size(model, mode="power")
|
||||
```
|
||||
|
||||
### Checkpointing and Loading
|
||||
|
||||
```python
|
||||
# Save checkpoint
|
||||
trainer.fit(model, datamodule=dm)
|
||||
# Checkpoint automatically saved to checkpoints/
|
||||
|
||||
# Load from checkpoint
|
||||
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
|
||||
|
||||
# Resume training
|
||||
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
|
||||
|
||||
# Test from checkpoint
|
||||
trainer.test(ckpt_path='best', datamodule=dm)
|
||||
```
|
||||
|
||||
### Debugging
|
||||
|
||||
```python
|
||||
# Quick test with few batches
|
||||
trainer = L.Trainer(fast_dev_run=10)
|
||||
|
||||
# Overfit on small data (debug model)
|
||||
trainer = L.Trainer(overfit_batches=100)
|
||||
|
||||
# Limit batches for quick iteration
|
||||
trainer = L.Trainer(
|
||||
limit_train_batches=100,
|
||||
limit_val_batches=50,
|
||||
)
|
||||
|
||||
# Profile training
|
||||
trainer = L.Trainer(profiler='simple') # or 'advanced'
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Code Organization
|
||||
|
||||
1. **Separate concerns:**
|
||||
- Model architecture in `__init__()`
|
||||
- Training logic in `training_step()`
|
||||
- Validation logic in `validation_step()`
|
||||
- Data processing in DataModule
|
||||
|
||||
2. **Use `save_hyperparameters()`:**
|
||||
```python
|
||||
def __init__(self, lr, hidden_dim, dropout):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Automatically saves all args
|
||||
```
|
||||
|
||||
3. **Device-agnostic code:**
|
||||
```python
|
||||
# Avoid manual device placement
|
||||
# BAD: tensor.cuda()
|
||||
# GOOD: Lightning handles this automatically
|
||||
|
||||
# Create tensors on model's device
|
||||
new_tensor = torch.zeros(10, device=self.device)
|
||||
```
|
||||
|
||||
4. **Log comprehensively:**
|
||||
```python
|
||||
self.log('metric', value, on_step=True, on_epoch=True, prog_bar=True)
|
||||
```
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
1. **Use DataLoader best practices:**
|
||||
```python
|
||||
DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Multiple workers
|
||||
pin_memory=True, # Faster GPU transfer
|
||||
persistent_workers=True, # Keep workers alive
|
||||
)
|
||||
```
|
||||
|
||||
2. **Enable benchmark mode for fixed input sizes:**
|
||||
```python
|
||||
trainer = L.Trainer(benchmark=True)
|
||||
```
|
||||
|
||||
3. **Use gradient clipping:**
|
||||
```python
|
||||
trainer = L.Trainer(gradient_clip_val=1.0)
|
||||
```
|
||||
|
||||
4. **Enable mixed precision:**
|
||||
```python
|
||||
trainer = L.Trainer(precision='16-mixed')
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
1. **Sync metrics across devices:**
|
||||
```python
|
||||
self.log('metric', value, sync_dist=True)
|
||||
```
|
||||
|
||||
2. **Rank-specific operations:**
|
||||
```python
|
||||
if self.trainer.is_global_zero:
|
||||
# Only run on main process
|
||||
self.save_artifacts()
|
||||
```
|
||||
|
||||
3. **Use appropriate strategy:**
|
||||
- Small models → `ddp`
|
||||
- Large models → `fsdp` or `deepspeed`
|
||||
|
||||
## Resources
|
||||
|
||||
### Scripts
|
||||
|
||||
Executable templates for quick implementation:
|
||||
|
||||
- **`template_lightning_module.py`** - Complete LightningModule template with all hooks, logging, and optimization patterns
|
||||
- **`template_datamodule.py`** - Complete DataModule template with data loading, splitting, and transformation patterns
|
||||
- **`quick_trainer_setup.py`** - Helper functions to create Trainers with preset configurations (development, production, distributed)
|
||||
|
||||
### References
|
||||
|
||||
Comprehensive documentation for deep-dive learning:
|
||||
|
||||
- **`api_reference.md`** - Complete API reference covering LightningModule hooks, Trainer parameters, Callbacks, DataModules, Loggers, and common patterns
|
||||
- **`distributed_training.md`** - In-depth guide for distributed training strategies (DDP, FSDP, DeepSpeed), multi-node setup, memory optimization, and troubleshooting
|
||||
|
||||
Load references when needing detailed information:
|
||||
```python
|
||||
# Example: Load distributed training reference
|
||||
# See references/distributed_training.md for comprehensive distributed training guide
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Out of Memory:**
|
||||
- Reduce batch size
|
||||
- Use gradient accumulation
|
||||
- Enable mixed precision (`precision='16-mixed'`)
|
||||
- Use FSDP or DeepSpeed for large models
|
||||
- Enable activation checkpointing
|
||||
|
||||
**Slow Training:**
|
||||
- Use multiple DataLoader workers (`num_workers > 0`)
|
||||
- Enable `pin_memory=True` and `persistent_workers=True`
|
||||
- Enable `benchmark=True` for fixed input sizes
|
||||
- Profile with `profiler='simple'`
|
||||
|
||||
**Validation Not Running:**
|
||||
- Check `check_val_every_n_epoch` setting
|
||||
- Ensure validation data provided
|
||||
- Verify `validation_step()` implemented
|
||||
|
||||
**Checkpoints Not Saving:**
|
||||
- Ensure `enable_checkpointing=True`
|
||||
- Check `ModelCheckpoint` callback configuration
|
||||
- Verify `monitor` metric exists in logs
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Official Documentation: https://lightning.ai/docs/pytorch/stable/
|
||||
- GitHub: https://github.com/Lightning-AI/lightning
|
||||
- Community: https://lightning.ai/community
|
||||
|
||||
When unclear about specific functionality, refer to `references/api_reference.md` for detailed API documentation or `references/distributed_training.md` for distributed training specifics.
|
||||
@@ -0,0 +1,490 @@
|
||||
# PyTorch Lightning API Reference
|
||||
|
||||
Comprehensive reference for PyTorch Lightning core APIs, hooks, and components.
|
||||
|
||||
## LightningModule
|
||||
|
||||
The LightningModule is the core abstraction for organizing PyTorch code in Lightning.
|
||||
|
||||
### Essential Hooks
|
||||
|
||||
#### `__init__(self, *args, **kwargs)`
|
||||
Initialize the model, define layers, and save hyperparameters.
|
||||
|
||||
```python
|
||||
def __init__(self, learning_rate=1e-3, hidden_dim=128):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Saves all args to self.hparams
|
||||
self.model = nn.Sequential(...)
|
||||
```
|
||||
|
||||
#### `forward(self, x)`
|
||||
Define the forward pass for inference. Called by `predict_step` by default.
|
||||
|
||||
```python
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
```
|
||||
|
||||
#### `training_step(self, batch, batch_idx)`
|
||||
Define the training loop logic. Return loss for automatic optimization.
|
||||
|
||||
```python
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
```
|
||||
|
||||
#### `validation_step(self, batch, batch_idx)`
|
||||
Define the validation loop logic. Model automatically in eval mode with no gradients.
|
||||
|
||||
```python
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log('val_loss', loss)
|
||||
return loss
|
||||
```
|
||||
|
||||
#### `test_step(self, batch, batch_idx)`
|
||||
Define the test loop logic. Only runs when `trainer.test()` is called.
|
||||
|
||||
```python
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log('test_loss', loss)
|
||||
return loss
|
||||
```
|
||||
|
||||
#### `predict_step(self, batch, batch_idx, dataloader_idx=0)`
|
||||
Define prediction logic for inference. Defaults to calling `forward()`.
|
||||
|
||||
```python
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
x, y = batch
|
||||
return self(x)
|
||||
```
|
||||
|
||||
#### `configure_optimizers(self)`
|
||||
Return optimizer(s) and optional learning rate scheduler(s).
|
||||
|
||||
```python
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
scheduler = ReduceLROnPlateau(optimizer, mode='min')
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val_loss",
|
||||
"interval": "epoch",
|
||||
"frequency": 1,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Lifecycle Hooks
|
||||
|
||||
#### Epoch-Level Hooks
|
||||
- `on_train_epoch_start()` - Called at the start of each training epoch
|
||||
- `on_train_epoch_end()` - Called at the end of each training epoch
|
||||
- `on_validation_epoch_start()` - Called at the start of validation epoch
|
||||
- `on_validation_epoch_end()` - Called at the end of validation epoch
|
||||
- `on_test_epoch_start()` - Called at the start of test epoch
|
||||
- `on_test_epoch_end()` - Called at the end of test epoch
|
||||
|
||||
#### Batch-Level Hooks
|
||||
- `on_train_batch_start(batch, batch_idx)` - Called before training batch
|
||||
- `on_train_batch_end(outputs, batch, batch_idx)` - Called after training batch
|
||||
- `on_validation_batch_start(batch, batch_idx)` - Called before validation batch
|
||||
- `on_validation_batch_end(outputs, batch, batch_idx)` - Called after validation batch
|
||||
|
||||
#### Training Lifecycle
|
||||
- `on_fit_start()` - Called at the start of fit
|
||||
- `on_fit_end()` - Called at the end of fit
|
||||
- `on_train_start()` - Called at the start of training
|
||||
- `on_train_end()` - Called at the end of training
|
||||
|
||||
### Logging
|
||||
|
||||
#### `self.log(name, value, **kwargs)`
|
||||
Log a metric to all configured loggers.
|
||||
|
||||
**Common Parameters:**
|
||||
- `on_step` (bool) - Log at each batch step
|
||||
- `on_epoch` (bool) - Log at the end of epoch (automatically aggregated)
|
||||
- `prog_bar` (bool) - Display in progress bar
|
||||
- `logger` (bool) - Send to logger
|
||||
- `sync_dist` (bool) - Synchronize across all distributed processes
|
||||
- `reduce_fx` (str) - Reduction function for distributed ("mean", "sum", etc.)
|
||||
|
||||
```python
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
|
||||
```
|
||||
|
||||
#### `self.log_dict(dictionary, **kwargs)`
|
||||
Log multiple metrics at once.
|
||||
|
||||
```python
|
||||
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
|
||||
self.log_dict(metrics, on_step=True, on_epoch=True)
|
||||
```
|
||||
|
||||
### Device Management
|
||||
|
||||
- `self.device` - Current device (automatically managed)
|
||||
- `self.to(device)` - Move model to device (usually handled automatically)
|
||||
|
||||
**Best Practice:** Create tensors on model's device:
|
||||
```python
|
||||
new_tensor = torch.zeros(10, device=self.device)
|
||||
```
|
||||
|
||||
### Hyperparameter Management
|
||||
|
||||
#### `self.save_hyperparameters(*args, **kwargs)`
|
||||
Automatically save init arguments to `self.hparams` and checkpoints.
|
||||
|
||||
```python
|
||||
def __init__(self, learning_rate, hidden_dim):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Saves all args
|
||||
# Access via self.hparams.learning_rate, self.hparams.hidden_dim
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Trainer
|
||||
|
||||
The Trainer automates the training loop and engineering complexity.
|
||||
|
||||
### Core Parameters
|
||||
|
||||
#### Training Duration
|
||||
- `max_epochs` (int) - Maximum number of epochs (default: 1000)
|
||||
- `min_epochs` (int) - Minimum number of epochs
|
||||
- `max_steps` (int) - Maximum number of optimizer steps
|
||||
- `min_steps` (int) - Minimum number of optimizer steps
|
||||
- `max_time` (str/dict) - Maximum training time ("DD:HH:MM:SS" or dict)
|
||||
|
||||
#### Hardware Configuration
|
||||
- `accelerator` (str) - Hardware to use: "cpu", "gpu", "tpu", "auto"
|
||||
- `devices` (int/list) - Number or specific device IDs: 1, 4, [0,2], "auto"
|
||||
- `num_nodes` (int) - Number of GPU nodes for distributed training
|
||||
- `strategy` (str) - Training strategy: "ddp", "fsdp", "deepspeed", etc.
|
||||
|
||||
#### Data Management
|
||||
- `limit_train_batches` (int/float) - Limit training batches (0.0-1.0 for %, int for count)
|
||||
- `limit_val_batches` (int/float) - Limit validation batches
|
||||
- `limit_test_batches` (int/float) - Limit test batches
|
||||
- `limit_predict_batches` (int/float) - Limit prediction batches
|
||||
|
||||
#### Validation
|
||||
- `check_val_every_n_epoch` (int) - Run validation every N epochs
|
||||
- `val_check_interval` (int/float) - Validate every N batches or fraction
|
||||
- `num_sanity_val_steps` (int) - Validation steps before training (default: 2)
|
||||
|
||||
#### Optimization
|
||||
- `gradient_clip_val` (float) - Clip gradients by value
|
||||
- `gradient_clip_algorithm` (str) - "value" or "norm"
|
||||
- `accumulate_grad_batches` (int) - Accumulate gradients over K batches
|
||||
- `precision` (str) - Training precision: "32-true", "16-mixed", "bf16-mixed", "64-true"
|
||||
|
||||
#### Logging and Checkpointing
|
||||
- `logger` (Logger/list) - Logger instance(s) or True/False
|
||||
- `log_every_n_steps` (int) - Logging frequency
|
||||
- `enable_checkpointing` (bool) - Enable automatic checkpointing
|
||||
- `callbacks` (list) - List of callback instances
|
||||
- `default_root_dir` (str) - Default path for logs and checkpoints
|
||||
|
||||
#### Debugging
|
||||
- `fast_dev_run` (bool/int) - Run N batches for quick testing
|
||||
- `overfit_batches` (int/float) - Overfit on limited data for debugging
|
||||
- `detect_anomaly` (bool) - Enable PyTorch anomaly detection
|
||||
- `profiler` (str/Profiler) - Profile training: "simple", "advanced", or custom
|
||||
|
||||
#### Performance
|
||||
- `benchmark` (bool) - Enable cudnn.benchmark for performance
|
||||
- `deterministic` (bool) - Enable deterministic training for reproducibility
|
||||
- `sync_batchnorm` (bool) - Synchronize batch norm across GPUs
|
||||
|
||||
### Training Methods
|
||||
|
||||
#### `trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)`
|
||||
Run the full training routine.
|
||||
|
||||
```python
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
# Or with DataModule
|
||||
trainer.fit(model, datamodule=dm)
|
||||
# Resume from checkpoint
|
||||
trainer.fit(model, train_loader, val_loader, ckpt_path="path/to/checkpoint.ckpt")
|
||||
```
|
||||
|
||||
#### `trainer.validate(model, dataloaders=None, datamodule=None, ckpt_path=None)`
|
||||
Run validation independently.
|
||||
|
||||
```python
|
||||
trainer.validate(model, val_loader)
|
||||
```
|
||||
|
||||
#### `trainer.test(model, dataloaders=None, datamodule=None, ckpt_path=None)`
|
||||
Run test evaluation.
|
||||
|
||||
```python
|
||||
trainer.test(model, test_loader)
|
||||
# Or load from checkpoint
|
||||
trainer.test(ckpt_path="best_model.ckpt", datamodule=dm)
|
||||
```
|
||||
|
||||
#### `trainer.predict(model, dataloaders=None, datamodule=None, ckpt_path=None)`
|
||||
Run inference predictions.
|
||||
|
||||
```python
|
||||
predictions = trainer.predict(model, predict_loader)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## LightningDataModule
|
||||
|
||||
Encapsulates all data processing logic in a reusable class.
|
||||
|
||||
### Core Methods
|
||||
|
||||
#### `prepare_data(self)`
|
||||
Download and prepare data (called once on single process).
|
||||
Do NOT set state here (no self.x = y).
|
||||
|
||||
```python
|
||||
def prepare_data(self):
|
||||
# Download datasets
|
||||
datasets.MNIST(self.data_dir, train=True, download=True)
|
||||
datasets.MNIST(self.data_dir, train=False, download=True)
|
||||
```
|
||||
|
||||
#### `setup(self, stage=None)`
|
||||
Load data and create splits (called on every process/GPU).
|
||||
Setting state is OK here.
|
||||
|
||||
**stage parameter:** "fit", "validate", "test", or "predict"
|
||||
|
||||
```python
|
||||
def setup(self, stage=None):
|
||||
if stage == "fit" or stage is None:
|
||||
full_dataset = datasets.MNIST(self.data_dir, train=True)
|
||||
self.train_dataset, self.val_dataset = random_split(full_dataset, [55000, 5000])
|
||||
|
||||
if stage == "test" or stage is None:
|
||||
self.test_dataset = datasets.MNIST(self.data_dir, train=False)
|
||||
```
|
||||
|
||||
#### DataLoader Methods
|
||||
- `train_dataloader(self)` - Return training DataLoader
|
||||
- `val_dataloader(self)` - Return validation DataLoader
|
||||
- `test_dataloader(self)` - Return test DataLoader
|
||||
- `predict_dataloader(self)` - Return prediction DataLoader
|
||||
|
||||
```python
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_dataset, batch_size=32, shuffle=True)
|
||||
```
|
||||
|
||||
### Optional Methods
|
||||
- `teardown(stage=None)` - Cleanup after training/testing
|
||||
- `state_dict()` - Save state for checkpointing
|
||||
- `load_state_dict(state_dict)` - Load state from checkpoint
|
||||
|
||||
---
|
||||
|
||||
## Callbacks
|
||||
|
||||
Extend training with modular, reusable functionality.
|
||||
|
||||
### Built-in Callbacks
|
||||
|
||||
#### ModelCheckpoint
|
||||
Save model checkpoints based on monitored metrics.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath='checkpoints/',
|
||||
filename='{epoch}-{val_loss:.2f}',
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
save_last=True,
|
||||
verbose=True,
|
||||
)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `monitor` - Metric to monitor
|
||||
- `mode` - "min" or "max"
|
||||
- `save_top_k` - Save top K models
|
||||
- `save_last` - Always save last checkpoint
|
||||
- `every_n_epochs` - Save every N epochs
|
||||
|
||||
#### EarlyStopping
|
||||
Stop training when metric stops improving.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import EarlyStopping
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
mode='min',
|
||||
verbose=True,
|
||||
)
|
||||
```
|
||||
|
||||
#### LearningRateMonitor
|
||||
Log learning rate values.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import LearningRateMonitor
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
```
|
||||
|
||||
#### RichProgressBar
|
||||
Display rich progress bar with metrics.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import RichProgressBar
|
||||
|
||||
progress_bar = RichProgressBar()
|
||||
```
|
||||
|
||||
### Custom Callbacks
|
||||
|
||||
Create custom callbacks by inheriting from `Callback`.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
class MyCallback(Callback):
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
print("Training starting!")
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
print(f"Epoch {trainer.current_epoch} ended")
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
val_loss = trainer.callback_metrics.get('val_loss')
|
||||
print(f"Validation loss: {val_loss}")
|
||||
```
|
||||
|
||||
**Common Hooks:**
|
||||
- `on_train_start/end`
|
||||
- `on_train_epoch_start/end`
|
||||
- `on_validation_epoch_start/end`
|
||||
- `on_test_epoch_start/end`
|
||||
- `on_before_backward/on_after_backward`
|
||||
- `on_before_optimizer_step`
|
||||
|
||||
---
|
||||
|
||||
## Loggers
|
||||
|
||||
Track experiments with various logging frameworks.
|
||||
|
||||
### TensorBoardLogger
|
||||
```python
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
logger = TensorBoardLogger(save_dir='logs/', name='my_experiment')
|
||||
trainer = Trainer(logger=logger)
|
||||
```
|
||||
|
||||
### WandbLogger
|
||||
```python
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(project='my_project', name='experiment_1')
|
||||
trainer = Trainer(logger=logger)
|
||||
```
|
||||
|
||||
### MLFlowLogger
|
||||
```python
|
||||
from lightning.pytorch.loggers import MLFlowLogger
|
||||
|
||||
logger = MLFlowLogger(experiment_name='my_exp', tracking_uri='file:./ml-runs')
|
||||
trainer = Trainer(logger=logger)
|
||||
```
|
||||
|
||||
### CSVLogger
|
||||
```python
|
||||
from lightning.pytorch.loggers import CSVLogger
|
||||
|
||||
logger = CSVLogger(save_dir='logs/', name='my_experiment')
|
||||
trainer = Trainer(logger=logger)
|
||||
```
|
||||
|
||||
### Multiple Loggers
|
||||
```python
|
||||
loggers = [
|
||||
TensorBoardLogger('logs/'),
|
||||
CSVLogger('logs/'),
|
||||
]
|
||||
trainer = Trainer(logger=loggers)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Reproducibility
|
||||
```python
|
||||
from lightning.pytorch import seed_everything
|
||||
|
||||
seed_everything(42, workers=True)
|
||||
trainer = Trainer(deterministic=True)
|
||||
```
|
||||
|
||||
### Mixed Precision Training
|
||||
```python
|
||||
trainer = Trainer(precision='16-mixed') # or 'bf16-mixed'
|
||||
```
|
||||
|
||||
### Multi-GPU Training
|
||||
```python
|
||||
# Data parallel (DDP)
|
||||
trainer = Trainer(accelerator='gpu', devices=4, strategy='ddp')
|
||||
|
||||
# Model parallel (FSDP)
|
||||
trainer = Trainer(accelerator='gpu', devices=4, strategy='fsdp')
|
||||
```
|
||||
|
||||
### Gradient Accumulation
|
||||
```python
|
||||
trainer = Trainer(accumulate_grad_batches=4) # Effective batch size = 4x
|
||||
```
|
||||
|
||||
### Learning Rate Finding
|
||||
```python
|
||||
from lightning.pytorch.tuner import Tuner
|
||||
|
||||
trainer = Trainer()
|
||||
tuner = Tuner(trainer)
|
||||
lr_finder = tuner.lr_find(model, train_dataloader)
|
||||
model.hparams.learning_rate = lr_finder.suggestion()
|
||||
```
|
||||
|
||||
### Loading from Checkpoint
|
||||
```python
|
||||
# Load model
|
||||
model = MyLightningModule.load_from_checkpoint('checkpoint.ckpt')
|
||||
|
||||
# Resume training
|
||||
trainer.fit(model, ckpt_path='checkpoint.ckpt')
|
||||
```
|
||||
@@ -0,0 +1,508 @@
|
||||
# Distributed and Model Parallel Training
|
||||
|
||||
Comprehensive guide for distributed training strategies in PyTorch Lightning.
|
||||
|
||||
## Overview
|
||||
|
||||
PyTorch Lightning provides seamless distributed training across multiple GPUs, machines, and TPUs with minimal code changes. The framework automatically handles the complexity of distributed training while keeping code device-agnostic and readable.
|
||||
|
||||
## Training Strategies
|
||||
|
||||
### Data Parallel (DDP - DistributedDataParallel)
|
||||
|
||||
**Best for:** Most models (< 500M parameters) where the full model fits in GPU memory.
|
||||
|
||||
**How it works:** Each GPU holds a complete copy of the model and trains on a different batch subset. Gradients are synchronized across GPUs during backward pass.
|
||||
|
||||
```python
|
||||
# Single-node, multi-GPU
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4, # Use 4 GPUs
|
||||
strategy='ddp',
|
||||
)
|
||||
|
||||
# Multi-node, multi-GPU
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4, # GPUs per node
|
||||
num_nodes=2, # Number of nodes
|
||||
strategy='ddp',
|
||||
)
|
||||
```
|
||||
|
||||
**Advantages:**
|
||||
- Most widely used and tested
|
||||
- Works with most PyTorch code
|
||||
- Good scaling efficiency
|
||||
- No code changes required in LightningModule
|
||||
|
||||
**When to use:** Default choice for most distributed training scenarios.
|
||||
|
||||
### FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Best for:** Large models (500M+ parameters) that don't fit in single GPU memory.
|
||||
|
||||
**How it works:** Shards model parameters, gradients, and optimizer states across GPUs. Each GPU only stores a subset of the model.
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='fsdp',
|
||||
)
|
||||
|
||||
# With configuration
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD", # Full sharding
|
||||
cpu_offload=False, # Offload to CPU
|
||||
    mixed_precision=MixedPrecision(param_dtype=torch.float16),  # expects a torch.distributed.fsdp.MixedPrecision policy, not a bare dtype
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy=strategy,
|
||||
)
|
||||
```
|
||||
|
||||
**Sharding Strategies:**
|
||||
- `FULL_SHARD` - Shard parameters, gradients, and optimizer states
|
||||
- `SHARD_GRAD_OP` - Shard only gradients and optimizer states
|
||||
- `NO_SHARD` - DDP-like (no sharding)
|
||||
- `HYBRID_SHARD` - Shard within node, DDP across nodes
|
||||
|
||||
**Advanced FSDP Configuration:**
|
||||
```python
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD",
|
||||
    activation_checkpointing=[TransformerBlock],  # save memory by recomputing these layers' activations (expects layer class(es), not a bool)
|
||||
cpu_offload=True, # Offload parameters to CPU
|
||||
backward_prefetch="BACKWARD_PRE", # Prefetch strategy
|
||||
forward_prefetch=True,
|
||||
limit_all_gathers=True,
|
||||
)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Models > 500M parameters
|
||||
- Limited GPU memory
|
||||
- Native PyTorch solution preferred
|
||||
- Migrating from standalone PyTorch FSDP
|
||||
|
||||
### DeepSpeed
|
||||
|
||||
**Best for:** Cutting-edge features, massive models, or existing DeepSpeed users.
|
||||
|
||||
**How it works:** Comprehensive optimization library with multiple stages of memory and compute optimization.
|
||||
|
||||
```python
|
||||
# Basic DeepSpeed
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='deepspeed',
|
||||
precision='16-mixed',
|
||||
)
|
||||
|
||||
# With configuration
|
||||
from lightning.pytorch.strategies import DeepSpeedStrategy
|
||||
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=2, # ZeRO Stage (1, 2, or 3)
|
||||
offload_optimizer=True,
|
||||
offload_parameters=True,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy=strategy,
|
||||
)
|
||||
```
|
||||
|
||||
**ZeRO Stages:**
|
||||
- **Stage 1:** Shard optimizer states
|
||||
- **Stage 2:** Shard optimizer states + gradients
|
||||
- **Stage 3:** Shard optimizer states + gradients + parameters (like FSDP)
|
||||
|
||||
**With DeepSpeed Config File:**
|
||||
```python
|
||||
strategy = DeepSpeedStrategy(config="deepspeed_config.json")
|
||||
```
|
||||
|
||||
Example `deepspeed_config.json`:
|
||||
```json
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_bucket_size": 2e8,
|
||||
"reduce_bucket_size": 2e8
|
||||
},
|
||||
"activation_checkpointing": {
|
||||
"partition_activations": true,
|
||||
"cpu_checkpointing": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true
|
||||
},
|
||||
"gradient_clipping": 1.0
|
||||
}
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Need specific DeepSpeed features
|
||||
- Maximum memory efficiency required
|
||||
- Already familiar with DeepSpeed
|
||||
- Training extremely large models
|
||||
|
||||
### DDP Spawn
|
||||
|
||||
**Note:** Generally avoid using `ddp_spawn`. Use `ddp` instead.
|
||||
|
||||
```python
|
||||
trainer = Trainer(strategy='ddp_spawn') # Not recommended
|
||||
```
|
||||
|
||||
**Issues with ddp_spawn:**
|
||||
- Cannot return values from `.fit()`
|
||||
- Pickling issues with unpicklable objects
|
||||
- Slower than `ddp`
|
||||
- More memory overhead
|
||||
|
||||
**When to use:** Only for debugging or if `ddp` doesn't work on your system.
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
### Basic Multi-Node Setup
|
||||
|
||||
```python
|
||||
# On each node, run the same command
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4, # GPUs per node
|
||||
num_nodes=8, # Total number of nodes
|
||||
strategy='ddp',
|
||||
)
|
||||
```
|
||||
|
||||
### SLURM Cluster
|
||||
|
||||
Lightning automatically detects SLURM environment:
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
num_nodes=8,
|
||||
strategy='ddp',
|
||||
)
|
||||
```
|
||||
|
||||
**SLURM Submit Script:**
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=8
|
||||
#SBATCH --gres=gpu:4
|
||||
#SBATCH --ntasks-per-node=4
|
||||
#SBATCH --job-name=lightning_training
|
||||
|
||||
python train.py
|
||||
```
|
||||
|
||||
### Manual Cluster Setup
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.plugins.environments import TorchElasticEnvironment

strategy = DDPStrategy(
    cluster_environment=TorchElasticEnvironment(),  # or SLURMEnvironment(), LSFEnvironment(), KubeflowEnvironment()
)
|
||||
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
num_nodes=8,
|
||||
strategy=strategy,
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Optimization Techniques
|
||||
|
||||
### Gradient Accumulation
|
||||
|
||||
Simulate larger batch sizes without increasing memory:
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
accumulate_grad_batches=4, # Accumulate 4 batches before optimizer step
|
||||
)
|
||||
|
||||
# Variable accumulation by epoch
|
||||
trainer = Trainer(
|
||||
accumulate_grad_batches={
|
||||
0: 8, # Epochs 0-4: accumulate 8 batches
|
||||
5: 4, # Epochs 5+: accumulate 4 batches
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Activation Checkpointing
|
||||
|
||||
Trade computation for memory by recomputing activations during backward pass:
|
||||
|
||||
```python
|
||||
# FSDP
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
checkpoint_wrapper,
|
||||
CheckpointImpl,
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
|
||||
class MyModule(L.LightningModule):
|
||||
def configure_model(self):
|
||||
# Wrap specific layers for activation checkpointing
|
||||
self.model = MyTransformer()
|
||||
apply_activation_checkpointing(
|
||||
self.model,
|
||||
checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(m, CheckpointImpl.NO_REENTRANT),
|
||||
check_fn=lambda m: isinstance(m, TransformerBlock),
|
||||
)
|
||||
```
|
||||
|
||||
### Mixed Precision Training
|
||||
|
||||
Reduce memory usage and increase speed with mixed precision:
|
||||
|
||||
```python
|
||||
# 16-bit mixed precision
|
||||
trainer = Trainer(precision='16-mixed')
|
||||
|
||||
# BFloat16 mixed precision (more stable, requires newer GPUs)
|
||||
trainer = Trainer(precision='bf16-mixed')
|
||||
```
|
||||
|
||||
### CPU Offloading
|
||||
|
||||
Offload parameters or optimizer states to CPU:
|
||||
|
||||
```python
|
||||
# FSDP with CPU offload
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
cpu_offload=True, # Offload parameters to CPU
|
||||
)
|
||||
|
||||
# DeepSpeed with CPU offload
|
||||
from lightning.pytorch.strategies import DeepSpeedStrategy
|
||||
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=3,
|
||||
offload_optimizer=True,
|
||||
offload_parameters=True,
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Synchronize Batch Normalization
|
||||
|
||||
Synchronize batch norm statistics across GPUs:
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='ddp',
|
||||
sync_batchnorm=True, # Sync batch norm across GPUs
|
||||
)
|
||||
```
|
||||
|
||||
### Find Optimal Batch Size
|
||||
|
||||
```python
|
||||
from lightning.pytorch.tuner import Tuner
|
||||
|
||||
trainer = Trainer()
|
||||
tuner = Tuner(trainer)
|
||||
|
||||
# Auto-scale batch size
|
||||
tuner.scale_batch_size(model, mode="power") # or "binsearch"
|
||||
```
|
||||
|
||||
### Gradient Clipping
|
||||
|
||||
Prevent gradient explosion in distributed training:
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
gradient_clip_val=1.0,
|
||||
gradient_clip_algorithm='norm', # or 'value'
|
||||
)
|
||||
```
|
||||
|
||||
### Benchmark Mode
|
||||
|
||||
Enable cudnn.benchmark for consistent input sizes:
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
benchmark=True, # Optimize for consistent input sizes
|
||||
)
|
||||
```
|
||||
|
||||
## Distributed Data Loading
|
||||
|
||||
### Automatic Distributed Sampling
|
||||
|
||||
Lightning automatically handles distributed sampling:
|
||||
|
||||
```python
|
||||
# No changes needed - Lightning handles this automatically
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=32,
|
||||
shuffle=True, # Lightning converts to DistributedSampler
|
||||
)
|
||||
```
|
||||
|
||||
### Manual Control
|
||||
|
||||
```python
|
||||
# Disable automatic distributed sampler
|
||||
trainer = Trainer(
|
||||
use_distributed_sampler=False,
|
||||
)
|
||||
|
||||
# Manual distributed sampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
def train_dataloader(self):
|
||||
sampler = DistributedSampler(self.train_dataset)
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=32,
|
||||
sampler=sampler,
|
||||
)
|
||||
```
|
||||
|
||||
### Data Loading Best Practices
|
||||
|
||||
```python
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Use multiple workers
|
||||
pin_memory=True, # Faster CPU-GPU transfer
|
||||
persistent_workers=True, # Keep workers alive between epochs
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Logging in Distributed Training
|
||||
|
||||
```python
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
# Automatically syncs across processes
|
||||
self.log('train_loss', loss, sync_dist=True)
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
### Rank-Specific Operations
|
||||
|
||||
```python
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Run only on rank 0 (main process)
|
||||
if self.trainer.is_global_zero:
|
||||
print("This only prints once across all processes")
|
||||
|
||||
# Get current rank
|
||||
rank = self.trainer.global_rank
|
||||
world_size = self.trainer.world_size
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
### Barrier Synchronization
|
||||
|
||||
```python
|
||||
def on_train_epoch_end(self):
|
||||
# Wait for all processes
|
||||
self.trainer.strategy.barrier()
|
||||
|
||||
# Now all processes are synchronized
|
||||
if self.trainer.is_global_zero:
|
||||
# Save something only once
|
||||
self.save_artifacts()
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**1. Out of Memory:**
|
||||
- Reduce batch size
|
||||
- Enable gradient accumulation
|
||||
- Use FSDP or DeepSpeed
|
||||
- Enable activation checkpointing
|
||||
- Use mixed precision
|
||||
|
||||
**2. Slow Training:**
|
||||
- Check data loading (use `num_workers > 0`)
|
||||
- Enable `pin_memory=True` and `persistent_workers=True`
|
||||
- Use `benchmark=True` for consistent input sizes
|
||||
- Profile with `profiler='simple'`
|
||||
|
||||
**3. Hanging:**
|
||||
- Ensure all processes execute same collectives
|
||||
- Check for `if` statements that differ across ranks
|
||||
- Use barrier synchronization when needed
|
||||
|
||||
**4. Inconsistent Results:**
|
||||
- Set `deterministic=True`
|
||||
- Use `seed_everything()`
|
||||
- Ensure proper gradient synchronization
|
||||
|
||||
### Debugging Distributed Training
|
||||
|
||||
```python
|
||||
# Test with single GPU first
|
||||
trainer = Trainer(accelerator='gpu', devices=1)
|
||||
|
||||
# Then test with 2 GPUs
|
||||
trainer = Trainer(accelerator='gpu', devices=2, strategy='ddp')
|
||||
|
||||
# Use fast_dev_run for quick testing
|
||||
trainer = Trainer(
|
||||
accelerator='gpu',
|
||||
devices=2,
|
||||
strategy='ddp',
|
||||
fast_dev_run=10, # Run 10 batches only
|
||||
)
|
||||
```
|
||||
|
||||
## Strategy Selection Guide
|
||||
|
||||
| Model Size | Available Memory | Recommended Strategy |
|
||||
|-----------|------------------|---------------------|
|
||||
| < 500M params | Fits in 1 GPU | Single GPU |
|
||||
| < 500M params | Fits across GPUs | DDP |
|
||||
| 500M - 3B params | Limited memory | FSDP or DeepSpeed Stage 2 |
|
||||
| 3B+ params | Very limited memory | FSDP or DeepSpeed Stage 3 |
|
||||
| Any size | Maximum efficiency | DeepSpeed with offloading |
|
||||
| Multiple nodes | Any | DDP (< 500M) or FSDP/DeepSpeed (> 500M) |
|
||||
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Helper script to quickly set up a PyTorch Lightning Trainer with common configurations.
|
||||
|
||||
This script provides preset configurations for different training scenarios
|
||||
and makes it easy to create a Trainer with best practices.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
from lightning.pytorch.callbacks import (
|
||||
ModelCheckpoint,
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
RichProgressBar,
|
||||
ModelSummary,
|
||||
)
|
||||
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
|
||||
|
||||
|
||||
def create_trainer(
    preset: str = "default",
    max_epochs: int = 100,
    accelerator: str = "auto",
    devices: int = 1,
    log_dir: str = "./logs",
    experiment_name: str = "lightning_experiment",
    enable_checkpointing: bool = True,
    enable_early_stopping: bool = True,
    **kwargs
):
    """
    Create a Lightning Trainer with preset configurations.

    Args:
        preset: Configuration preset - "default", "fast_dev", "production", "distributed"
        max_epochs: Maximum number of training epochs. The "fast_dev" preset
            caps this at 3 epochs but honors explicitly smaller values.
        accelerator: Device to use ("auto", "gpu", "cpu", "tpu")
        devices: Number of devices to use
        log_dir: Directory for logs and checkpoints
        experiment_name: Name for the experiment
        enable_checkpointing: Whether to enable model checkpointing
            (not used by the "fast_dev" preset)
        enable_early_stopping: Whether to enable early stopping
            (only used by the "production" preset)
        **kwargs: Additional arguments to pass to Trainer; these override
            any preset setting of the same name.

    Returns:
        Configured Lightning Trainer instance
    """

    callbacks = []
    logger_list = []

    # Checkpoints for every preset land in the same per-experiment directory.
    checkpoint_dir = f"{log_dir}/{experiment_name}/checkpoints"

    # Configure based on preset
    if preset == "fast_dev":
        # Fast development run - minimal epochs, quick debugging.
        # Fix: the epoch cap used to be hard-coded to 3, silently ignoring a
        # smaller explicit request (e.g. max_epochs=1 from
        # create_debugging_trainer). min() honors smaller values while still
        # capping the default (100) at 3 for quick runs.
        config = {
            "fast_dev_run": False,
            "max_epochs": min(max_epochs, 3),
            "limit_train_batches": 100,
            "limit_val_batches": 50,
            "log_every_n_steps": 10,
            "enable_progress_bar": True,
            "enable_model_summary": True,
        }

    elif preset == "production":
        # Production-ready configuration with all bells and whistles
        config = {
            "max_epochs": max_epochs,
            "precision": "16-mixed",
            "gradient_clip_val": 1.0,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
            "enable_model_summary": True,
            "deterministic": True,
            "benchmark": True,
        }

        # Add model checkpointing (keep top-3 by validation loss)
        if enable_checkpointing:
            callbacks.append(
                ModelCheckpoint(
                    dirpath=checkpoint_dir,
                    filename="{epoch}-{val_loss:.2f}",
                    monitor="val_loss",
                    mode="min",
                    save_top_k=3,
                    save_last=True,
                    verbose=True,
                )
            )

        # Add early stopping
        if enable_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=10,
                    mode="min",
                    verbose=True,
                )
            )

        # Add learning rate monitor
        callbacks.append(LearningRateMonitor(logging_interval="epoch"))

        # Add TensorBoard logger
        logger_list.append(
            TensorBoardLogger(
                save_dir=log_dir,
                name=experiment_name,
                version=None,
            )
        )

    elif preset == "distributed":
        # Distributed training configuration
        config = {
            "max_epochs": max_epochs,
            "strategy": "ddp",
            "precision": "16-mixed",
            "sync_batchnorm": True,
            "use_distributed_sampler": True,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
        }

        # Add model checkpointing
        if enable_checkpointing:
            callbacks.append(
                ModelCheckpoint(
                    dirpath=checkpoint_dir,
                    filename="{epoch}-{val_loss:.2f}",
                    monitor="val_loss",
                    mode="min",
                    save_top_k=3,
                    save_last=True,
                )
            )

    else:  # default
        # Default configuration - balanced for most use cases
        config = {
            "max_epochs": max_epochs,
            "log_every_n_steps": 50,
            "enable_progress_bar": True,
            "enable_model_summary": True,
        }

        # Add basic checkpointing
        if enable_checkpointing:
            callbacks.append(
                ModelCheckpoint(
                    dirpath=checkpoint_dir,
                    filename="{epoch}-{val_loss:.2f}",
                    monitor="val_loss",
                    save_last=True,
                )
            )

        # Add CSV logger
        logger_list.append(
            CSVLogger(
                save_dir=log_dir,
                name=experiment_name,
            )
        )

    # Add progress bar (RichProgressBar replaces the default bar)
    if config.get("enable_progress_bar", True):
        callbacks.append(RichProgressBar())

    # Merge with provided kwargs; kwargs come last so they win over presets.
    final_config = {
        **config,
        "accelerator": accelerator,
        "devices": devices,
        "callbacks": callbacks,
        "logger": logger_list if logger_list else True,
        **kwargs,
    }

    # Create and return trainer
    return L.Trainer(**final_config)
|
||||
|
||||
|
||||
def create_debugging_trainer():
    """Build a Trainer tuned for quick debugging runs (tiny epoch/batch limits)."""
    debug_overrides = {
        "max_epochs": 1,
        "limit_train_batches": 10,
        "limit_val_batches": 5,
        "num_sanity_val_steps": 2,
    }
    return create_trainer(preset="fast_dev", **debug_overrides)
|
||||
|
||||
|
||||
def create_gpu_trainer(num_gpus: int = 1, precision: str = "16-mixed"):
    """Build a production-preset Trainer that runs on ``num_gpus`` GPUs."""
    gpu_settings = dict(
        accelerator="gpu",
        devices=num_gpus,
        precision=precision,
    )
    return create_trainer(preset="production", **gpu_settings)
|
||||
|
||||
|
||||
def create_distributed_trainer(num_gpus: int = 2, num_nodes: int = 1):
    """Build a DDP Trainer spanning ``num_gpus`` GPUs on each of ``num_nodes`` nodes."""
    return create_trainer(
        "distributed",
        accelerator="gpu",
        devices=num_gpus,
        num_nodes=num_nodes,
        strategy="ddp",
    )
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # Smoke-test every factory and show the key settings each one produces.
    print("Creating different trainer configurations...\n")

    # 1. Default trainer
    print("1. Default trainer:")
    default_trainer = create_trainer(preset="default", max_epochs=50)
    print(f"   Max epochs: {default_trainer.max_epochs}")
    print(f"   Accelerator: {default_trainer.accelerator}")
    print(f"   Callbacks: {len(default_trainer.callbacks)}")
    print()

    # 2. Fast development trainer
    print("2. Fast development trainer:")
    dev_trainer = create_trainer(preset="fast_dev")
    print(f"   Max epochs: {dev_trainer.max_epochs}")
    print(f"   Train batches limit: {dev_trainer.limit_train_batches}")
    print()

    # 3. Production trainer
    print("3. Production trainer:")
    prod_trainer = create_trainer(
        preset="production",
        max_epochs=100,
        experiment_name="my_experiment",
    )
    print(f"   Max epochs: {prod_trainer.max_epochs}")
    print(f"   Precision: {prod_trainer.precision}")
    print(f"   Callbacks: {len(prod_trainer.callbacks)}")
    print()

    # 4. Debugging trainer
    print("4. Debugging trainer:")
    debug_trainer = create_debugging_trainer()
    print(f"   Max epochs: {debug_trainer.max_epochs}")
    print(f"   Train batches: {debug_trainer.limit_train_batches}")
    print()

    # 5. GPU trainer
    print("5. GPU trainer:")
    gpu_trainer = create_gpu_trainer(num_gpus=1)
    print(f"   Accelerator: {gpu_trainer.accelerator}")
    print(f"   Precision: {gpu_trainer.precision}")
    print()

    print("All trainer configurations created successfully!")
|
||||
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Template for creating a PyTorch Lightning DataModule.
|
||||
|
||||
This template includes all common hooks and patterns for organizing
|
||||
data processing workflows with best practices.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
import torch
|
||||
|
||||
|
||||
class TemplateDataset(Dataset):
    """Example dataset - replace with your actual dataset."""

    def __init__(self, data, targets, transform=None):
        # Raw samples, matching labels, and an optional per-sample transform.
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        # One label per sample; length is the number of samples.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]
        # Apply the transform (if one was supplied) to the sample only.
        if self.transform:
            sample = self.transform(sample)
        return sample, label
|
||||
|
||||
|
||||
class TemplateDataModule(L.LightningDataModule):
|
||||
"""Template DataModule with all common hooks and patterns."""
|
||||
|
||||
def __init__(
    self,
    data_dir: str = "./data",
    batch_size: int = 32,
    num_workers: int = 4,
    train_val_split: tuple = (0.8, 0.2),
    seed: int = 42,
    pin_memory: bool = True,
    persistent_workers: bool = True,
):
    """Store the configuration and initialize dataset/transform placeholders."""
    super().__init__()

    # Record every constructor argument in self.hparams (Lightning feature).
    self.save_hyperparameters()

    # Mirror the configuration onto plain attributes for direct access.
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.train_val_split = train_val_split
    self.seed = seed
    self.pin_memory = pin_memory
    self.persistent_workers = persistent_workers

    # Datasets are created lazily in setup(); None until then.
    self.train_dataset = None
    self.val_dataset = None
    self.test_dataset = None
    self.predict_dataset = None

    # Per-split transforms; assign before setup() runs if needed.
    self.train_transform = None
    self.val_transform = None
    self.test_transform = None
|
||||
|
||||
def prepare_data(self):
    """
    Download and prepare data (called only on 1 GPU/TPU in distributed settings).
    Intended for one-off work such as downloading or tokenizing.
    Do NOT set state here (no self.x = y) - use setup() for that.
    """
    # Example: download the datasets once, e.g.
    #   datasets.MNIST(self.data_dir, train=True, download=True)
    #   datasets.MNIST(self.data_dir, train=False, download=True)
    pass
|
||||
|
||||
def setup(self, stage: str = None):
|
||||
"""
|
||||
Load data and create train/val/test splits (called on every GPU/TPU in distributed).
|
||||
Use this for splitting, creating datasets, etc. Setting state is OK here (self.x = y).
|
||||
|
||||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'
|
||||
"""
|
||||
|
||||
# Fit stage: setup training and validation datasets
|
||||
if stage == "fit" or stage is None:
|
||||
# Load full dataset
|
||||
# Example: full_dataset = datasets.MNIST(self.data_dir, train=True, transform=self.train_transform)
|
||||
|
||||
# Create dummy data for template
|
||||
full_data = torch.randn(1000, 784)
|
||||
full_targets = torch.randint(0, 10, (1000,))
|
||||
full_dataset = TemplateDataset(full_data, full_targets, transform=self.train_transform)
|
||||
|
||||
# Split into train and validation
|
||||
train_size = int(len(full_dataset) * self.train_val_split[0])
|
||||
val_size = len(full_dataset) - train_size
|
||||
|
||||
self.train_dataset, self.val_dataset = random_split(
|
||||
full_dataset,
|
||||
[train_size, val_size],
|
||||
generator=torch.Generator().manual_seed(self.seed)
|
||||
)
|
||||
|
||||
# Apply validation transform if different from train
|
||||
if self.val_transform:
|
||||
self.val_dataset.dataset.transform = self.val_transform
|
||||
|
||||
# Test stage: setup test dataset
|
||||
if stage == "test" or stage is None:
|
||||
# Example: self.test_dataset = datasets.MNIST(
|
||||
# self.data_dir, train=False, transform=self.test_transform
|
||||
# )
|
||||
|
||||
# Create dummy test data for template
|
||||
test_data = torch.randn(200, 784)
|
||||
test_targets = torch.randint(0, 10, (200,))
|
||||
self.test_dataset = TemplateDataset(test_data, test_targets, transform=self.test_transform)
|
||||
|
||||
# Predict stage: setup prediction dataset
|
||||
if stage == "predict" or stage is None:
|
||||
# Example: self.predict_dataset = YourCustomDataset(...)
|
||||
|
||||
# Create dummy predict data for template
|
||||
predict_data = torch.randn(100, 784)
|
||||
predict_targets = torch.zeros(100, dtype=torch.long)
|
||||
self.predict_dataset = TemplateDataset(predict_data, predict_targets)
|
||||
|
||||
def train_dataloader(self):
|
||||
"""Return training dataloader."""
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
"""Return validation dataloader."""
|
||||
return DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
"""Return test dataloader."""
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
)
|
||||
|
||||
def predict_dataloader(self):
|
||||
"""Return prediction dataloader."""
|
||||
return DataLoader(
|
||||
self.predict_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
persistent_workers=self.persistent_workers if self.num_workers > 0 else False,
|
||||
)
|
||||
|
||||
def teardown(self, stage: str = None):
|
||||
"""Clean up after fit, validate, test, or predict."""
|
||||
# Example: close database connections, clear caches, etc.
|
||||
pass
|
||||
|
||||
def state_dict(self):
|
||||
"""Save state for checkpointing."""
|
||||
# Return anything you want to save in the checkpoint
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load state from checkpoint."""
|
||||
# Restore state from checkpoint
|
||||
pass
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # Build the datamodule with explicit settings
    datamodule = TemplateDataModule(
        data_dir="./data",
        batch_size=32,
        num_workers=4,
        train_val_split=(0.8, 0.2),
    )

    # Run the data pipeline manually (normally the Trainer does this)
    datamodule.prepare_data()
    datamodule.setup("fit")

    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()

    print("Template DataModule created successfully!")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Batch size: {datamodule.batch_size}")

    # Sanity-check one batch
    x, y = next(iter(train_loader))
    print(f"Batch shape: {x.shape}, {y.shape}")
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Template for creating a PyTorch Lightning LightningModule.
|
||||
|
||||
This template includes all common hooks and patterns for building
|
||||
a Lightning model with best practices.
|
||||
"""
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam, SGD
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
||||
|
||||
|
||||
class TemplateLightningModule(L.LightningModule):
    """Template LightningModule with all common hooks and patterns.

    A simple two-hidden-layer MLP classifier with configurable optimizer
    and learning-rate scheduler, demonstrating the standard training,
    validation, test, and prediction hooks.
    """

    def __init__(
        self,
        # Model architecture parameters
        input_dim: int = 784,
        hidden_dim: int = 128,
        output_dim: int = 10,
        # Optimization parameters
        learning_rate: float = 1e-3,
        optimizer_type: str = "adam",
        scheduler_type: str = None,
        # Other hyperparameters
        dropout: float = 0.1,
    ):
        """Build the model, loss function, and bookkeeping containers.

        Args:
            input_dim: Size of each flattened input sample.
            hidden_dim: Width of the two hidden layers.
            output_dim: Number of output classes.
            learning_rate: Base learning rate for the optimizer.
            optimizer_type: 'adam' or 'sgd'.
            scheduler_type: None, 'reduce_on_plateau', or 'step'.
            dropout: Dropout probability after each hidden layer.
        """
        super().__init__()

        # Save hyperparameters for checkpointing and logging (self.hparams)
        self.save_hyperparameters()

        # Two-hidden-layer MLP classifier
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        self.criterion = nn.CrossEntropyLoss()

        # Per-batch validation metrics, cleared at each validation epoch end
        self.validation_step_outputs = []

    @staticmethod
    def _accuracy(logits, y):
        """Fraction of samples whose argmax prediction matches the target."""
        preds = torch.argmax(logits, dim=1)
        return (preds == y).float().mean()

    def forward(self, x):
        """Forward pass for inference."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/accuracy for one batch."""
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self._accuracy(logits, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy for one batch."""
        x, y = batch

        # Forward pass (model automatically in eval mode)
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self._accuracy(logits, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        # Store outputs for epoch-level processing in on_validation_epoch_end
        self.validation_step_outputs.append({"loss": loss, "acc": acc})

        return loss

    def on_validation_epoch_end(self):
        """Aggregate stored validation outputs and reset the buffer."""
        if self.validation_step_outputs:
            avg_loss = torch.stack([o["loss"] for o in self.validation_step_outputs]).mean()
            avg_acc = torch.stack([o["acc"] for o in self.validation_step_outputs]).mean()

            # Log epoch-level metrics if needed:
            # self.log("val_epoch_loss", avg_loss)
            # self.log("val_epoch_acc", avg_acc)

            self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        """Compute and log test loss/accuracy for one batch."""
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self._accuracy(logits, y)

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True)

        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return class predictions for one batch (targets unused at predict time)."""
        x, y = batch
        logits = self(x)
        return torch.argmax(logits, dim=1)

    def configure_optimizers(self):
        """Build the optimizer and, optionally, a learning-rate scheduler.

        Returns:
            The optimizer alone when no scheduler is configured, otherwise
            an ``{"optimizer": ..., "lr_scheduler": ...}`` config dict.

        Raises:
            ValueError: If ``optimizer_type`` or ``scheduler_type`` is unknown.
        """
        optimizer_type = self.hparams.optimizer_type.lower()
        if optimizer_type == "adam":
            optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
        elif optimizer_type == "sgd":
            optimizer = SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=0.9)
        else:
            raise ValueError(f"Unknown optimizer: {self.hparams.optimizer_type}")

        if not self.hparams.scheduler_type:
            return optimizer

        scheduler_type = self.hparams.scheduler_type.lower()
        if scheduler_type == "reduce_on_plateau":
            lr_scheduler = {
                "scheduler": ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5),
                # ReduceLROnPlateau needs a logged metric to watch
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            }
        elif scheduler_type == "step":
            lr_scheduler = {
                "scheduler": StepLR(optimizer, step_size=10, gamma=0.1),
                "interval": "epoch",
                "frequency": 1,
            }
        else:
            # An unknown scheduler was previously ignored silently; fail
            # loudly instead, mirroring the optimizer handling above.
            raise ValueError(f"Unknown scheduler: {self.hparams.scheduler_type}")

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    # Optional: Additional hooks for custom behavior

    def on_train_start(self):
        """Called at the beginning of training."""
        pass

    def on_train_epoch_start(self):
        """Called at the beginning of each training epoch."""
        pass

    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""
        pass

    def on_train_end(self):
        """Called at the end of training."""
        pass
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # Instantiate the template model with explicit hyperparameters
    model = TemplateLightningModule(
        input_dim=784,
        hidden_dim=128,
        output_dim=10,
        learning_rate=1e-3,
        optimizer_type="adam",
        scheduler_type="reduce_on_plateau",
    )

    # Minimal trainer configuration
    trainer = L.Trainer(
        max_epochs=10,
        accelerator="auto",
        devices=1,
        log_every_n_steps=50,
    )

    # Note: You would need to provide dataloaders
    # trainer.fit(model, train_dataloader, val_dataloader)

    print("Template LightningModule created successfully!")
    print(f"Model hyperparameters: {model.hparams}")
|
||||
Reference in New Issue
Block a user