| name | description |
|---|---|
| pytorch-lightning | PyTorch training framework. LightningModule, Trainer, distributed training (DDP/FSDP), callbacks, loggers (TensorBoard/WandB), mixed precision, for organized deep learning workflows. |
# PyTorch Lightning

## Overview

PyTorch Lightning is a deep learning framework that organizes PyTorch code to decouple research from engineering. It automates training-loop complexity (multi-GPU, mixed precision, checkpointing, logging) while maintaining full flexibility over model architecture and training logic.

**Core philosophy: separate concerns.**

- **LightningModule** - research code (model architecture, training logic)
- **Trainer** - engineering automation (hardware, optimization, logging)
- **DataModule** - data processing (downloading, loading, transforms)
- **Callbacks** - non-essential functionality (checkpointing, early stopping)
## When to Use This Skill

This skill should be used when:
- Building or training deep learning models with PyTorch
- Converting existing PyTorch code to Lightning structure
- Setting up distributed training across multiple GPUs or nodes
- Implementing custom training loops with validation and testing
- Organizing data processing pipelines
- Configuring experiment logging and model checkpointing
- Optimizing training performance and memory usage
- Working with large models requiring model parallelism
## Quick Start

### Basic Lightning Workflow

1. Define a `LightningModule` (organize your model)
2. Create a `DataModule` or DataLoaders (organize your data)
3. Configure a `Trainer` (automate training)
4. Train with `trainer.fit()`

### Minimal Example
```python
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


# 1. Define LightningModule
class SimpleModel(L.LightningModule):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# 2. Prepare data
train_data = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
train_loader = DataLoader(train_data, batch_size=32)

# 3. Create Trainer
trainer = L.Trainer(max_epochs=10, accelerator='auto')

# 4. Train
model = SimpleModel(input_dim=10, output_dim=1)
trainer.fit(model, train_loader)
```
## Core Workflows

### 1. Creating a LightningModule

Structure model code by implementing the essential hooks.

**Template:** Use `scripts/template_lightning_module.py` as a starting point.
```python
class MyLightningModule(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # Save init args for checkpointing
        self.model = YourModel()

    def forward(self, x):
        """Inference forward pass."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Define training loop logic."""
        x, y = batch
        y_hat = self(x)
        loss = self.compute_loss(y_hat, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Define validation logic."""
        x, y = batch
        y_hat = self(x)
        loss = self.compute_loss(y_hat, y)
        self.log('val_loss', loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return optimizer and optional scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
```
**Key Points:**

- Use `self.save_hyperparameters()` to automatically save init args
- Use `self.log()` to track metrics across loggers
- Return the loss from `training_step()` for automatic optimization
- Keep model architecture separate from training logic
### 2. Creating a DataModule

Organize all data processing in a reusable module.

**Template:** Use `scripts/template_datamodule.py` as a starting point.
```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=32):
        super().__init__()
        self.save_hyperparameters()

    def prepare_data(self):
        """Download data (called once, on a single process)."""
        # Download datasets, tokenize, etc.
        pass

    def setup(self, stage=None):
        """Create datasets (called on every process)."""
        if stage == 'fit' or stage is None:
            # Create train/val datasets
            self.train_dataset = ...
            self.val_dataset = ...
        if stage == 'test' or stage is None:
            # Create test dataset
            self.test_dataset = ...

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size)
```
**Key Points:**

- `prepare_data()` is for downloading (runs on a single process)
- `setup()` is for creating datasets (runs on every process)
- Use the `stage` parameter to separate fit/test logic
- A DataModule makes data code reusable across projects
### 3. Configuring the Trainer

The Trainer automates training complexity.

**Helper:** Use `scripts/quick_trainer_setup.py` for preset configurations.
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

trainer = L.Trainer(
    # Training duration
    max_epochs=100,

    # Hardware
    accelerator='auto',         # 'cpu', 'gpu', 'tpu'
    devices=1,                  # Number of devices or specific IDs

    # Optimization
    precision='16-mixed',       # Mixed precision training
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,  # Gradient accumulation

    # Validation
    check_val_every_n_epoch=1,
    val_check_interval=1.0,     # Validate once per epoch

    # Logging
    log_every_n_steps=50,
    logger=TensorBoardLogger('logs/'),

    # Callbacks
    callbacks=[
        ModelCheckpoint(monitor='val_loss', mode='min'),
        EarlyStopping(monitor='val_loss', patience=10),
    ],

    # Debugging
    fast_dev_run=False,         # Quick test with a few batches
    enable_progress_bar=True,
)
```
**Common Presets:**

```python
from scripts.quick_trainer_setup import create_trainer

# Development preset (fast debugging)
trainer = create_trainer(preset='fast_dev', max_epochs=3)

# Production preset (full features)
trainer = create_trainer(preset='production', max_epochs=100)

# Distributed preset (multi-GPU)
trainer = create_trainer(preset='distributed', devices=4)
```
### 4. Training and Evaluation

```python
# Training
trainer.fit(model, datamodule=dm)

# Or with dataloaders
trainer.fit(model, train_loader, val_loader)

# Resume from a checkpoint
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')

# Testing
trainer.test(model, datamodule=dm)

# Or load the best checkpoint
trainer.test(ckpt_path='best', datamodule=dm)

# Prediction
predictions = trainer.predict(model, predict_loader)

# Validation only
trainer.validate(model, datamodule=dm)
```
### 5. Distributed Training

Lightning handles distributed training automatically:

```python
# Single machine, multiple GPUs (data parallel)
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,
    strategy='ddp',  # DistributedDataParallel
)

# Multiple machines, multiple GPUs
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,    # GPUs per node
    num_nodes=8,  # Number of machines
    strategy='ddp',
)

# Large models (model parallel with FSDP)
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,
    strategy='fsdp',  # Fully Sharded Data Parallel
)

# Large models (model parallel with DeepSpeed)
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,
    strategy='deepspeed_stage_2',
    precision='16-mixed',
)
```
For a detailed distributed training guide, see `references/distributed_training.md`.

**Strategy Selection:**

- Models < 500M params → use `ddp`
- Models > 500M params → use `fsdp` or `deepspeed`
- Maximum memory efficiency → use DeepSpeed Stage 3 with offloading
- Native PyTorch → use `fsdp`
- Cutting-edge features → use `deepspeed`
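These rules of thumb can be sketched as a small helper. Note that `pick_strategy` and the 500M-parameter threshold are illustrative heuristics, not a Lightning API; the returned strings are standard Lightning strategy names.

```python
def pick_strategy(num_params: int, max_memory_savings: bool = False) -> str:
    """Map model size to a Lightning strategy string, following the
    rules of thumb above (the 500M-parameter cutoff is a heuristic)."""
    if max_memory_savings:
        # DeepSpeed Stage 3 shards params/grads/optimizer state and can offload
        return 'deepspeed_stage_3_offload'
    if num_params < 500_000_000:
        return 'ddp'
    return 'fsdp'  # or a 'deepspeed_stage_*' strategy for DeepSpeed features


# The result is meant to be passed to L.Trainer(strategy=...)
print(pick_strategy(120_000_000))    # small model → 'ddp'
print(pick_strategy(7_000_000_000))  # large model → 'fsdp'
```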
### 6. Callbacks

Extend training with modular functionality:

```python
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
)

callbacks = [
    # Save the best models
    ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=3,
        filename='{epoch}-{val_loss:.2f}',
    ),
    # Stop when there is no improvement
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        mode='min',
    ),
    # Log the learning rate
    LearningRateMonitor(logging_interval='epoch'),
    # Rich progress bar
    RichProgressBar(),
]

trainer = L.Trainer(callbacks=callbacks)
```
**Custom Callbacks:**

```python
from lightning.pytorch.callbacks import Callback


class MyCustomCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Custom logic at the end of each epoch
        print(f"Epoch {trainer.current_epoch} completed")

    def on_validation_end(self, trainer, pl_module):
        val_loss = trainer.callback_metrics.get('val_loss')
        # Custom validation logic
        pass
```
### 7. Logging

Track experiments with various loggers:

```python
from lightning.pytorch.loggers import (
    TensorBoardLogger,
    WandbLogger,
    CSVLogger,
    MLFlowLogger,
)

# Single logger
logger = TensorBoardLogger('logs/', name='my_experiment')

# Multiple loggers
loggers = [
    TensorBoardLogger('logs/'),
    WandbLogger(project='my_project'),
    CSVLogger('logs/'),
]

trainer = L.Trainer(logger=loggers)
```
**Logging in a LightningModule:**

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # Log a single metric
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

    # Log multiple metrics (acc and f1 computed earlier in the step)
    metrics = {'loss': loss, 'acc': acc, 'f1': f1}
    self.log_dict(metrics, on_step=True, on_epoch=True)
    return loss
```
## Converting Existing PyTorch Code

### Standard PyTorch → Lightning

**Before (PyTorch):**

```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        x, y = batch
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
```
**After (Lightning):**

```python
class MyLightningModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


model = MyLightningModel()
trainer = L.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader)
```
**Key Changes:**

- Wrap the model in a `LightningModule`
- Move training loop logic to `training_step()`
- Move optimizer setup to `configure_optimizers()`
- Replace the manual loop with `trainer.fit()`
- Lightning handles `.zero_grad()`, `.backward()`, `.step()`, and device placement
## Common Patterns

### Reproducibility

```python
from lightning.pytorch import seed_everything

# Set the seed for reproducibility
seed_everything(42, workers=True)

trainer = L.Trainer(deterministic=True)
```
### Mixed Precision Training

```python
# 16-bit mixed precision
trainer = L.Trainer(precision='16-mixed')

# BFloat16 mixed precision (more stable)
trainer = L.Trainer(precision='bf16-mixed')
```

### Gradient Accumulation

```python
# Effective batch size = 4x the actual batch size
trainer = L.Trainer(accumulate_grad_batches=4)
```
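Conceptually, `accumulate_grad_batches=4` takes one optimizer step on the average gradient of four batches. A pure-Python sketch of the idea, using a toy 1-D least-squares "model" (illustrative only, not Lightning's internals):

```python
# Toy model: loss_i(w) = (w * x_i - y_i)^2, so d(loss)/dw = 2 * (w * x_i - y_i) * x_i
data = [(1.0, 2.0), (2.0, 3.0), (0.5, 1.0), (1.5, 2.5)]


def grad(w, x, y):
    return 2 * (w * x - y) * x


w, lr, accumulate = 0.0, 0.1, 4

# Accumulate scaled gradients over 4 "batches", then take a single step
g = 0.0
for x, y in data:
    g += grad(w, x, y) / accumulate  # scale so the sum becomes an average
w -= lr * g

# This equals one step on the mean gradient over the 4 batches
mean_g = sum(grad(0.0, x, y) for x, y in data) / len(data)
assert abs((0.0 - lr * mean_g) - w) < 1e-12
```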
### Learning Rate Finding

```python
from lightning.pytorch.tuner import Tuner

trainer = L.Trainer()
tuner = Tuner(trainer)

# Find an optimal learning rate
lr_finder = tuner.lr_find(model, train_dataloaders=train_loader)
model.hparams.learning_rate = lr_finder.suggestion()

# Find an optimal batch size
tuner.scale_batch_size(model, mode="power")
```
### Checkpointing and Loading

```python
# Fit; checkpoints are saved automatically to the checkpoints/ directory
trainer.fit(model, datamodule=dm)

# Load from a checkpoint
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# Resume training
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')

# Test from the best checkpoint
trainer.test(ckpt_path='best', datamodule=dm)
```
### Debugging

```python
# Quick test with a few batches
trainer = L.Trainer(fast_dev_run=10)

# Overfit on a small amount of data (to debug the model)
trainer = L.Trainer(overfit_batches=100)

# Limit batches for quick iteration
trainer = L.Trainer(
    limit_train_batches=100,
    limit_val_batches=50,
)

# Profile training
trainer = L.Trainer(profiler='simple')  # or 'advanced'
```
## Best Practices

### Code Organization

- **Separate concerns:**
  - Model architecture in `__init__()`
  - Training logic in `training_step()`
  - Validation logic in `validation_step()`
  - Data processing in a DataModule

- **Use `save_hyperparameters()`:**

  ```python
  def __init__(self, lr, hidden_dim, dropout):
      super().__init__()
      self.save_hyperparameters()  # Automatically saves all args
  ```

- **Write device-agnostic code:**

  ```python
  # Avoid manual device placement
  # BAD:  tensor.cuda()
  # GOOD: Lightning handles this automatically

  # Create tensors on the model's device
  new_tensor = torch.zeros(10, device=self.device)
  ```

- **Log comprehensively:**

  ```python
  self.log('metric', value, on_step=True, on_epoch=True, prog_bar=True)
  ```
### Performance Optimization

- **Use DataLoader best practices:**

  ```python
  DataLoader(
      dataset,
      batch_size=32,
      num_workers=4,            # Multiple workers
      pin_memory=True,          # Faster GPU transfer
      persistent_workers=True,  # Keep workers alive
  )
  ```

- **Enable benchmark mode for fixed input sizes:**

  ```python
  trainer = L.Trainer(benchmark=True)
  ```

- **Use gradient clipping:**

  ```python
  trainer = L.Trainer(gradient_clip_val=1.0)
  ```

- **Enable mixed precision:**

  ```python
  trainer = L.Trainer(precision='16-mixed')
  ```
### Distributed Training

- **Sync metrics across devices:**

  ```python
  self.log('metric', value, sync_dist=True)
  ```

- **Rank-specific operations:**

  ```python
  if self.trainer.is_global_zero:
      # Only run on the main process
      self.save_artifacts()
  ```

- **Use the appropriate strategy:**
  - Small models → `ddp`
  - Large models → `fsdp` or `deepspeed`
## Resources

### Scripts

Executable templates for quick implementation:

- `template_lightning_module.py` - complete LightningModule template with all hooks, logging, and optimization patterns
- `template_datamodule.py` - complete DataModule template with data loading, splitting, and transformation patterns
- `quick_trainer_setup.py` - helper functions to create Trainers with preset configurations (development, production, distributed)
### References

Comprehensive documentation for deep-dive learning:

- `api_reference.md` - complete API reference covering LightningModule hooks, Trainer parameters, Callbacks, DataModules, Loggers, and common patterns
- `distributed_training.md` - in-depth guide to distributed training strategies (DDP, FSDP, DeepSpeed), multi-node setup, memory optimization, and troubleshooting

Load a reference when detailed information is needed; for example, see `references/distributed_training.md` for the comprehensive distributed training guide.
## Troubleshooting

### Common Issues

**Out of Memory:**

- Reduce the batch size
- Use gradient accumulation
- Enable mixed precision (`precision='16-mixed'`)
- Use FSDP or DeepSpeed for large models
- Enable activation checkpointing
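As a sketch, several of these remedies can be combined in one Trainer configuration; the values below are illustrative starting points, not recommendations for any particular model:

```python
import lightning as L

# Hypothetical memory-constrained setup: halve activation/weight memory with
# mixed precision, recover the effective batch size with accumulation, and
# shard a large model across GPUs with FSDP.
trainer = L.Trainer(
    accelerator='gpu',
    devices=4,
    precision='16-mixed',       # less memory per tensor
    accumulate_grad_batches=8,  # smaller per-step batch, same effective batch
    strategy='fsdp',            # shard params/grads/optimizer state
)
```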
**Slow Training:**

- Use multiple DataLoader workers (`num_workers > 0`)
- Enable `pin_memory=True` and `persistent_workers=True`
- Enable `benchmark=True` for fixed input sizes
- Profile with `profiler='simple'`
**Validation Not Running:**

- Check the `check_val_every_n_epoch` setting
- Ensure validation data is provided
- Verify `validation_step()` is implemented
**Checkpoints Not Saving:**

- Ensure `enable_checkpointing=True`
- Check the `ModelCheckpoint` callback configuration
- Verify the `monitor` metric exists in the logs
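A minimal setup covering these checks might look like the following; the directory path and monitored metric name are illustrative and must match what the LightningModule actually logs:

```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath='checkpoints/',  # where .ckpt files land
    monitor='val_loss',      # must match a name passed to self.log()
    mode='min',
)
trainer = L.Trainer(
    enable_checkpointing=True,  # on by default, but explicit here
    callbacks=[checkpoint_cb],
)
```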
## Additional Resources

- Official documentation: https://lightning.ai/docs/pytorch/stable/
- GitHub: https://github.com/Lightning-AI/lightning
- Community: https://lightning.ai/community

When unclear about specific functionality, refer to `references/api_reference.md` for detailed API documentation or `references/distributed_training.md` for distributed training specifics.