mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-28 07:33:45 +08:00
Improve Pytorch Lightning skill
This commit is contained in:
@@ -1,660 +1,158 @@
|
||||
---
|
||||
name: pytorch-lightning
|
||||
description: "PyTorch training framework. LightningModule, Trainer, distributed training (DDP/FSDP), callbacks, loggers (TensorBoard/WandB), mixed precision, for organized deep learning workflows."
|
||||
description: Work with PyTorch Lightning for deep learning model training and research. This skill should be used when building, training, or deploying neural networks using PyTorch Lightning, organizing PyTorch code into LightningModules, configuring Trainers for multi-GPU/TPU training, implementing data pipelines with LightningDataModules, or working with callbacks, logging, and distributed training strategies (DDP, FSDP, DeepSpeed).
|
||||
---
|
||||
|
||||
# PyTorch Lightning
|
||||
|
||||
## Overview
|
||||
|
||||
PyTorch Lightning is a deep learning framework that organizes PyTorch code to decouple research from engineering. It automates training loop complexity (multi-GPU, mixed precision, checkpointing, logging) while maintaining full flexibility over model architecture and training logic.
|
||||
|
||||
**Core Philosophy:** Separate concerns
|
||||
- **LightningModule** - Research code (model architecture, training logic)
|
||||
- **Trainer** - Engineering automation (hardware, optimization, logging)
|
||||
- **DataModule** - Data processing (downloading, loading, transforms)
|
||||
- **Callbacks** - Non-essential functionality (checkpointing, early stopping)
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Building or training deep learning models with PyTorch
|
||||
- Converting existing PyTorch code to Lightning structure
|
||||
- Setting up distributed training across multiple GPUs or nodes
|
||||
- Implementing custom training loops with validation and testing
|
||||
- Organizing data processing pipelines
|
||||
- Configuring experiment logging and model checkpointing
|
||||
- Optimizing training performance and memory usage
|
||||
- Working with large models requiring model parallelism
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Lightning Workflow
|
||||
|
||||
1. **Define a LightningModule** (organize your model)
|
||||
2. **Create a DataModule or DataLoaders** (organize your data)
|
||||
3. **Configure a Trainer** (automate training)
|
||||
4. **Train** with `trainer.fit()`
|
||||
|
||||
### Minimal Example
|
||||
|
||||
```python
|
||||
import lightning as L
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
# 1. Define LightningModule
|
||||
class SimpleModel(L.LightningModule):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = nn.functional.mse_loss(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# 2. Prepare data
|
||||
train_data = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
|
||||
train_loader = DataLoader(train_data, batch_size=32)
|
||||
|
||||
# 3. Create Trainer
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='auto')
|
||||
|
||||
# 4. Train
|
||||
model = SimpleModel(input_dim=10, output_dim=1)
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### 1. Creating a LightningModule
|
||||
|
||||
Structure model code by implementing essential hooks:
|
||||
|
||||
**Template:** Use `scripts/template_lightning_module.py` as a starting point.
|
||||
|
||||
```python
|
||||
class MyLightningModule(L.LightningModule):
|
||||
def __init__(self, hyperparameters):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Save for checkpointing
|
||||
self.model = YourModel()
|
||||
|
||||
def forward(self, x):
|
||||
"""Inference forward pass."""
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""Define training loop logic."""
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = self.compute_loss(y_hat, y)
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
"""Define validation logic."""
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = self.compute_loss(y_hat, y)
|
||||
self.log('val_loss', loss, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Return optimizer and optional scheduler."""
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": scheduler,
|
||||
"monitor": "val_loss",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Use `self.save_hyperparameters()` to automatically save init args
|
||||
- Use `self.log()` to track metrics across loggers
|
||||
- Return loss from training_step for automatic optimization
|
||||
- Keep model architecture separate from training logic
|
||||
|
||||
### 2. Creating a DataModule
|
||||
|
||||
Organize all data processing in a reusable module:
|
||||
|
||||
**Template:** Use `scripts/template_datamodule.py` as a starting point.
|
||||
|
||||
```python
|
||||
class MyDataModule(L.LightningDataModule):
|
||||
def __init__(self, data_dir, batch_size=32):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
def prepare_data(self):
|
||||
"""Download data (called once, single process)."""
|
||||
# Download datasets, tokenize, etc.
|
||||
pass
|
||||
|
||||
def setup(self, stage=None):
|
||||
"""Create datasets (called on every process)."""
|
||||
if stage == 'fit' or stage is None:
|
||||
# Create train/val datasets
|
||||
self.train_dataset = ...
|
||||
self.val_dataset = ...
|
||||
|
||||
if stage == 'test' or stage is None:
|
||||
# Create test dataset
|
||||
self.test_dataset = ...
|
||||
PyTorch Lightning is a deep learning framework that organizes PyTorch code to eliminate boilerplate while maintaining full flexibility. It automates training workflows, multi-device orchestration, and best practices from research labs. Use this skill when working with neural network training, scaling models across multiple GPUs/TPUs, or structuring deep learning projects professionally.
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)
|
||||
## Core Capabilities
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- `prepare_data()` for downloading (single process)
|
||||
- `setup()` for creating datasets (every process)
|
||||
- Use `stage` parameter to separate fit/test logic
|
||||
- Makes data code reusable across projects
|
||||
|
||||
### 3. Configuring the Trainer
|
||||
|
||||
The Trainer automates training complexity:
|
||||
|
||||
**Helper:** Use `scripts/quick_trainer_setup.py` for preset configurations.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
trainer = L.Trainer(
|
||||
# Training duration
|
||||
max_epochs=100,
|
||||
|
||||
# Hardware
|
||||
accelerator='auto', # 'cpu', 'gpu', 'tpu'
|
||||
devices=1, # Number of devices or specific IDs
|
||||
|
||||
# Optimization
|
||||
precision='16-mixed', # Mixed precision training
|
||||
gradient_clip_val=1.0,
|
||||
accumulate_grad_batches=4, # Gradient accumulation
|
||||
|
||||
# Validation
|
||||
check_val_every_n_epoch=1,
|
||||
val_check_interval=1.0, # Validate every epoch
|
||||
|
||||
# Logging
|
||||
log_every_n_steps=50,
|
||||
logger=TensorBoardLogger('logs/'),
|
||||
|
||||
# Callbacks
|
||||
callbacks=[
|
||||
ModelCheckpoint(monitor='val_loss', mode='min'),
|
||||
EarlyStopping(monitor='val_loss', patience=10),
|
||||
],
|
||||
|
||||
# Debugging
|
||||
fast_dev_run=False, # Quick test with few batches
|
||||
enable_progress_bar=True,
|
||||
)
|
||||
```
|
||||
|
||||
**Common Presets:**
|
||||
|
||||
```python
|
||||
from scripts.quick_trainer_setup import create_trainer
|
||||
|
||||
# Development preset (fast debugging)
|
||||
trainer = create_trainer(preset='fast_dev', max_epochs=3)
|
||||
|
||||
# Production preset (full features)
|
||||
trainer = create_trainer(preset='production', max_epochs=100)
|
||||
|
||||
# Distributed preset (multi-GPU)
|
||||
trainer = create_trainer(preset='distributed', devices=4)
|
||||
```
|
||||
|
||||
### 4. Training and Evaluation
|
||||
|
||||
```python
|
||||
# Training
|
||||
trainer.fit(model, datamodule=dm)
|
||||
# Or with dataloaders
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Resume from checkpoint
|
||||
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
|
||||
|
||||
# Testing
|
||||
trainer.test(model, datamodule=dm)
|
||||
# Or load best checkpoint
|
||||
trainer.test(ckpt_path='best', datamodule=dm)
|
||||
|
||||
# Prediction
|
||||
predictions = trainer.predict(model, predict_loader)
|
||||
|
||||
# Validation only
|
||||
trainer.validate(model, datamodule=dm)
|
||||
```
|
||||
|
||||
### 5. Distributed Training
|
||||
|
||||
Lightning handles distributed training automatically:
|
||||
|
||||
```python
|
||||
# Single machine, multiple GPUs (Data Parallel)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='ddp', # DistributedDataParallel
|
||||
)
|
||||
|
||||
# Multiple machines, multiple GPUs
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4, # GPUs per node
|
||||
num_nodes=8, # Number of machines
|
||||
strategy='ddp',
|
||||
)
|
||||
|
||||
# Large models (Model Parallel with FSDP)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='fsdp', # Fully Sharded Data Parallel
|
||||
)
|
||||
|
||||
# Large models (Model Parallel with DeepSpeed)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=4,
|
||||
strategy='deepspeed_stage_2',
|
||||
precision='16-mixed',
|
||||
)
|
||||
```
|
||||
|
||||
**For detailed distributed training guide, see:** `references/distributed_training.md`
|
||||
|
||||
**Strategy Selection:**
|
||||
- Models < 500M params → Use `ddp`
|
||||
- Models > 500M params → Use `fsdp` or `deepspeed`
|
||||
- Maximum memory efficiency → Use DeepSpeed Stage 3 with offloading
|
||||
- Native PyTorch → Use `fsdp`
|
||||
- Cutting-edge features → Use `deepspeed`
|
||||
|
||||
### 6. Callbacks
|
||||
|
||||
Extend training with modular functionality:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import (
|
||||
ModelCheckpoint,
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
RichProgressBar,
|
||||
)
|
||||
|
||||
callbacks = [
|
||||
# Save best models
|
||||
ModelCheckpoint(
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
filename='{epoch}-{val_loss:.2f}',
|
||||
),
|
||||
|
||||
# Stop when no improvement
|
||||
EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
mode='min',
|
||||
),
|
||||
|
||||
# Log learning rate
|
||||
LearningRateMonitor(logging_interval='epoch'),
|
||||
|
||||
# Rich progress bar
|
||||
RichProgressBar(),
|
||||
]
|
||||
|
||||
trainer = L.Trainer(callbacks=callbacks)
|
||||
```
|
||||
|
||||
**Custom Callbacks:**
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
class MyCustomCallback(Callback):
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
# Custom logic at end of each epoch
|
||||
print(f"Epoch {trainer.current_epoch} completed")
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
val_loss = trainer.callback_metrics.get('val_loss')
|
||||
# Custom validation logic
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. Logging
|
||||
|
||||
Track experiments with various loggers:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.loggers import (
|
||||
TensorBoardLogger,
|
||||
WandbLogger,
|
||||
CSVLogger,
|
||||
MLFlowLogger,
|
||||
)
|
||||
|
||||
# Single logger
|
||||
logger = TensorBoardLogger('logs/', name='my_experiment')
|
||||
|
||||
# Multiple loggers
|
||||
loggers = [
|
||||
TensorBoardLogger('logs/'),
|
||||
WandbLogger(project='my_project'),
|
||||
CSVLogger('logs/'),
|
||||
]
|
||||
|
||||
trainer = L.Trainer(logger=loggers)
|
||||
```
|
||||
|
||||
**Logging in LightningModule:**
|
||||
|
||||
```python
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
# Log single metric
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
|
||||
### 1. LightningModule - Model Definition
|
||||
|
||||
# Log multiple metrics
|
||||
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
|
||||
self.log_dict(metrics, on_step=True, on_epoch=True)
|
||||
Organize PyTorch models into six logical sections:
|
||||
|
||||
return loss
|
||||
```
|
||||
1. **Initialization** - `__init__()` and `setup()`
|
||||
2. **Training Loop** - `training_step(batch, batch_idx)`
|
||||
3. **Validation Loop** - `validation_step(batch, batch_idx)`
|
||||
4. **Test Loop** - `test_step(batch, batch_idx)`
|
||||
5. **Prediction** - `predict_step(batch, batch_idx)`
|
||||
6. **Optimizer Configuration** - `configure_optimizers()`
|
||||
|
||||
## Converting Existing PyTorch Code
|
||||
**Quick template reference:** See `scripts/template_lightning_module.py` for a complete boilerplate.
|
||||
|
||||
### Standard PyTorch → Lightning
|
||||
**Detailed documentation:** Read `references/lightning_module.md` for comprehensive method documentation, hooks, properties, and best practices.
|
||||
|
||||
**Before (PyTorch):**
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
### 2. Trainer - Training Automation
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
x, y = batch
|
||||
y_hat = model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
The Trainer automates the training loop, device management, gradient operations, and callbacks. Key features:
|
||||
|
||||
**After (Lightning):**
|
||||
```python
|
||||
class MyLightningModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
- Multi-GPU/TPU support with strategy selection (DDP, FSDP, DeepSpeed)
|
||||
- Automatic mixed precision training
|
||||
- Gradient accumulation and clipping
|
||||
- Checkpointing and early stopping
|
||||
- Progress bars and logging
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
**Quick setup reference:** See `scripts/quick_trainer_setup.py` for common Trainer configurations.
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters())
|
||||
**Detailed documentation:** Read `references/trainer.md` for all parameters, methods, and configuration options.
|
||||
|
||||
trainer = L.Trainer(max_epochs=num_epochs)
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
### 3. LightningDataModule - Data Pipeline Organization
|
||||
|
||||
**Key Changes:**
|
||||
1. Wrap model in LightningModule
|
||||
2. Move training loop logic to `training_step()`
|
||||
3. Move optimizer setup to `configure_optimizers()`
|
||||
4. Replace manual loop with `trainer.fit()`
|
||||
5. Lightning handles: `.zero_grad()`, `.backward()`, `.step()`, device placement
|
||||
Encapsulate all data processing steps in a reusable class:
|
||||
|
||||
## Common Patterns
|
||||
1. `prepare_data()` - Download and process data (single-process)
|
||||
2. `setup()` - Create datasets and apply transforms (per-GPU)
|
||||
3. `train_dataloader()` - Return training DataLoader
|
||||
4. `val_dataloader()` - Return validation DataLoader
|
||||
5. `test_dataloader()` - Return test DataLoader
|
||||
|
||||
### Reproducibility
|
||||
**Quick template reference:** See `scripts/template_datamodule.py` for a complete boilerplate.
|
||||
|
||||
```python
|
||||
from lightning.pytorch import seed_everything
|
||||
**Detailed documentation:** Read `references/data_module.md` for method details and usage patterns.
|
||||
|
||||
# Set seed for reproducibility
|
||||
seed_everything(42, workers=True)
|
||||
### 4. Callbacks - Extensible Training Logic
|
||||
|
||||
trainer = L.Trainer(deterministic=True)
|
||||
```
|
||||
Add custom functionality at specific training hooks without modifying your LightningModule. Built-in callbacks include:
|
||||
|
||||
### Mixed Precision Training
|
||||
- **ModelCheckpoint** - Save best/latest models
|
||||
- **EarlyStopping** - Stop when metrics plateau
|
||||
- **LearningRateMonitor** - Track LR scheduler changes
|
||||
- **BatchSizeFinder** - Auto-determine optimal batch size
|
||||
|
||||
```python
|
||||
# 16-bit mixed precision
|
||||
trainer = L.Trainer(precision='16-mixed')
|
||||
**Detailed documentation:** Read `references/callbacks.md` for built-in callbacks and custom callback creation.
|
||||
|
||||
# BFloat16 mixed precision (more stable)
|
||||
trainer = L.Trainer(precision='bf16-mixed')
|
||||
```
|
||||
### 5. Logging - Experiment Tracking
|
||||
|
||||
### Gradient Accumulation
|
||||
Integrate with multiple logging platforms:
|
||||
|
||||
```python
|
||||
# Effective batch size = 4x actual batch size
|
||||
trainer = L.Trainer(accumulate_grad_batches=4)
|
||||
```
|
||||
- TensorBoard (default)
|
||||
- Weights & Biases (WandbLogger)
|
||||
- MLflow (MLFlowLogger)
|
||||
- Neptune (NeptuneLogger)
|
||||
- Comet (CometLogger)
|
||||
- CSV (CSVLogger)
|
||||
|
||||
### Learning Rate Finding
|
||||
Log metrics using `self.log("metric_name", value)` in any LightningModule method.
|
||||
|
||||
```python
|
||||
from lightning.pytorch.tuner import Tuner
|
||||
**Detailed documentation:** Read `references/logging.md` for logger setup and configuration.
|
||||
|
||||
trainer = L.Trainer()
|
||||
tuner = Tuner(trainer)
|
||||
### 6. Distributed Training - Scale to Multiple Devices
|
||||
|
||||
# Find optimal learning rate
|
||||
lr_finder = tuner.lr_find(model, train_dataloader)
|
||||
model.hparams.learning_rate = lr_finder.suggestion()
|
||||
Choose the right strategy based on model size:
|
||||
|
||||
# Find optimal batch size
|
||||
tuner.scale_batch_size(model, mode="power")
|
||||
```
|
||||
- **DDP** - For models <500M parameters (ResNet, smaller transformers)
|
||||
- **FSDP** - For models 500M+ parameters (large transformers, recommended for Lightning users)
|
||||
- **DeepSpeed** - For cutting-edge features and fine-grained control
|
||||
|
||||
### Checkpointing and Loading
|
||||
Configure with: `Trainer(strategy="ddp", accelerator="gpu", devices=4)`
|
||||
|
||||
```python
|
||||
# Save checkpoint
|
||||
trainer.fit(model, datamodule=dm)
|
||||
# Checkpoint automatically saved to checkpoints/
|
||||
**Detailed documentation:** Read `references/distributed_training.md` for strategy comparison and configuration.
|
||||
|
||||
# Load from checkpoint
|
||||
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
|
||||
### 7. Best Practices
|
||||
|
||||
# Resume training
|
||||
trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt')
|
||||
- Device agnostic code - Use `self.device` instead of `.cuda()`
|
||||
- Hyperparameter saving - Use `self.save_hyperparameters()` in `__init__()`
|
||||
- Metric logging - Use `self.log()` for automatic aggregation across devices
|
||||
- Reproducibility - Use `seed_everything()` and `Trainer(deterministic=True)`
|
||||
- Debugging - Use `Trainer(fast_dev_run=True)` to test with 1 batch
|
||||
|
||||
# Test from checkpoint
|
||||
trainer.test(ckpt_path='best', datamodule=dm)
|
||||
```
|
||||
**Detailed documentation:** Read `references/best_practices.md` for common patterns and pitfalls.
|
||||
|
||||
### Debugging
|
||||
## Quick Workflow
|
||||
|
||||
```python
|
||||
# Quick test with few batches
|
||||
trainer = L.Trainer(fast_dev_run=10)
|
||||
|
||||
# Overfit on small data (debug model)
|
||||
trainer = L.Trainer(overfit_batches=100)
|
||||
|
||||
# Limit batches for quick iteration
|
||||
trainer = L.Trainer(
|
||||
limit_train_batches=100,
|
||||
limit_val_batches=50,
|
||||
)
|
||||
|
||||
# Profile training
|
||||
trainer = L.Trainer(profiler='simple') # or 'advanced'
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Code Organization
|
||||
|
||||
1. **Separate concerns:**
|
||||
- Model architecture in `__init__()`
|
||||
- Training logic in `training_step()`
|
||||
- Validation logic in `validation_step()`
|
||||
- Data processing in DataModule
|
||||
|
||||
2. **Use `save_hyperparameters()`:**
|
||||
1. **Define model:**
|
||||
```python
|
||||
def __init__(self, lr, hidden_dim, dropout):
|
||||
super().__init__()
|
||||
self.save_hyperparameters() # Automatically saves all args
|
||||
class MyModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = YourNetwork()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
loss = F.cross_entropy(self.model(x), y)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters())
|
||||
```
|
||||
|
||||
3. **Device-agnostic code:**
|
||||
2. **Prepare data:**
|
||||
```python
|
||||
# Avoid manual device placement
|
||||
# BAD: tensor.cuda()
|
||||
# GOOD: Lightning handles this automatically
|
||||
# Option 1: Direct DataLoaders
|
||||
train_loader = DataLoader(train_dataset, batch_size=32)
|
||||
|
||||
# Create tensors on model's device
|
||||
new_tensor = torch.zeros(10, device=self.device)
|
||||
# Option 2: LightningDataModule (recommended for reusability)
|
||||
dm = MyDataModule(batch_size=32)
|
||||
```
|
||||
|
||||
4. **Log comprehensively:**
|
||||
3. **Train:**
|
||||
```python
|
||||
self.log('metric', value, on_step=True, on_epoch=True, prog_bar=True)
|
||||
trainer = L.Trainer(max_epochs=10, accelerator="gpu", devices=2)
|
||||
trainer.fit(model, train_loader) # or trainer.fit(model, datamodule=dm)
|
||||
```
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
1. **Use DataLoader best practices:**
|
||||
```python
|
||||
DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Multiple workers
|
||||
pin_memory=True, # Faster GPU transfer
|
||||
persistent_workers=True, # Keep workers alive
|
||||
)
|
||||
```
|
||||
|
||||
2. **Enable benchmark mode for fixed input sizes:**
|
||||
```python
|
||||
trainer = L.Trainer(benchmark=True)
|
||||
```
|
||||
|
||||
3. **Use gradient clipping:**
|
||||
```python
|
||||
trainer = L.Trainer(gradient_clip_val=1.0)
|
||||
```
|
||||
|
||||
4. **Enable mixed precision:**
|
||||
```python
|
||||
trainer = L.Trainer(precision='16-mixed')
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
1. **Sync metrics across devices:**
|
||||
```python
|
||||
self.log('metric', value, sync_dist=True)
|
||||
```
|
||||
|
||||
2. **Rank-specific operations:**
|
||||
```python
|
||||
if self.trainer.is_global_zero:
|
||||
# Only run on main process
|
||||
self.save_artifacts()
|
||||
```
|
||||
|
||||
3. **Use appropriate strategy:**
|
||||
- Small models → `ddp`
|
||||
- Large models → `fsdp` or `deepspeed`
|
||||
|
||||
## Resources
|
||||
|
||||
### Scripts
|
||||
### scripts/
|
||||
Executable Python templates for common PyTorch Lightning patterns:
|
||||
|
||||
Executable templates for quick implementation:
|
||||
- `template_lightning_module.py` - Complete LightningModule boilerplate
|
||||
- `template_datamodule.py` - Complete LightningDataModule boilerplate
|
||||
- `quick_trainer_setup.py` - Common Trainer configuration examples
|
||||
|
||||
- **`template_lightning_module.py`** - Complete LightningModule template with all hooks, logging, and optimization patterns
|
||||
- **`template_datamodule.py`** - Complete DataModule template with data loading, splitting, and transformation patterns
|
||||
- **`quick_trainer_setup.py`** - Helper functions to create Trainers with preset configurations (development, production, distributed)
|
||||
### references/
|
||||
Detailed documentation for each PyTorch Lightning component:
|
||||
|
||||
### References
|
||||
|
||||
Comprehensive documentation for deep-dive learning:
|
||||
|
||||
- **`api_reference.md`** - Complete API reference covering LightningModule hooks, Trainer parameters, Callbacks, DataModules, Loggers, and common patterns
|
||||
- **`distributed_training.md`** - In-depth guide for distributed training strategies (DDP, FSDP, DeepSpeed), multi-node setup, memory optimization, and troubleshooting
|
||||
|
||||
Load references when needing detailed information:
|
||||
```python
|
||||
# Example: Load distributed training reference
|
||||
# See references/distributed_training.md for comprehensive distributed training guide
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Out of Memory:**
|
||||
- Reduce batch size
|
||||
- Use gradient accumulation
|
||||
- Enable mixed precision (`precision='16-mixed'`)
|
||||
- Use FSDP or DeepSpeed for large models
|
||||
- Enable activation checkpointing
|
||||
|
||||
**Slow Training:**
|
||||
- Use multiple DataLoader workers (`num_workers > 0`)
|
||||
- Enable `pin_memory=True` and `persistent_workers=True`
|
||||
- Enable `benchmark=True` for fixed input sizes
|
||||
- Profile with `profiler='simple'`
|
||||
|
||||
**Validation Not Running:**
|
||||
- Check `check_val_every_n_epoch` setting
|
||||
- Ensure validation data provided
|
||||
- Verify `validation_step()` implemented
|
||||
|
||||
**Checkpoints Not Saving:**
|
||||
- Ensure `enable_checkpointing=True`
|
||||
- Check `ModelCheckpoint` callback configuration
|
||||
- Verify `monitor` metric exists in logs
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Official Documentation: https://lightning.ai/docs/pytorch/stable/
|
||||
- GitHub: https://github.com/Lightning-AI/lightning
|
||||
- Community: https://lightning.ai/community
|
||||
|
||||
When unclear about specific functionality, refer to `references/api_reference.md` for detailed API documentation or `references/distributed_training.md` for distributed training specifics.
|
||||
- `lightning_module.md` - Comprehensive LightningModule guide (methods, hooks, properties)
|
||||
- `trainer.md` - Trainer configuration and parameters
|
||||
- `data_module.md` - LightningDataModule patterns and methods
|
||||
- `callbacks.md` - Built-in and custom callbacks
|
||||
- `logging.md` - Logger integrations and usage
|
||||
- `distributed_training.md` - DDP, FSDP, DeepSpeed comparison and setup
|
||||
- `best_practices.md` - Common patterns, tips, and pitfalls
|
||||
|
||||
Reference in New Issue
Block a user