mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-28 07:33:45 +08:00
# PyTorch Lightning API Reference

Comprehensive reference for PyTorch Lightning core APIs, hooks, and components.

## LightningModule

The LightningModule is the core abstraction for organizing PyTorch code in Lightning.

### Essential Hooks

#### `__init__(self, *args, **kwargs)`
Initialize the model, define layers, and save hyperparameters.

```python
def __init__(self, learning_rate=1e-3, hidden_dim=128):
    super().__init__()
    self.save_hyperparameters()  # Saves all args to self.hparams
    self.model = nn.Sequential(...)
```

#### `forward(self, x)`
Define the forward pass for inference. Called by `predict_step` by default.

```python
def forward(self, x):
    return self.model(x)
```

#### `training_step(self, batch, batch_idx)`
Define the training loop logic. Return the loss to use automatic optimization.

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log('train_loss', loss)
    return loss
```

#### `validation_step(self, batch, batch_idx)`
Define the validation loop logic. The model is automatically put in eval mode and gradients are disabled.

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log('val_loss', loss)
    return loss
```

#### `test_step(self, batch, batch_idx)`
Define the test loop logic. Runs only when `trainer.test()` is called.

```python
def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log('test_loss', loss)
    return loss
```

#### `predict_step(self, batch, batch_idx, dataloader_idx=0)`
Define prediction logic for inference. Defaults to calling `forward()`.

```python
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    x, y = batch
    return self(x)
```

#### `configure_optimizers(self)`
Return optimizer(s) and optional learning rate scheduler(s).

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='min')
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
        },
    }
```

### Lifecycle Hooks

#### Epoch-Level Hooks
- `on_train_epoch_start()` - Called at the start of each training epoch
- `on_train_epoch_end()` - Called at the end of each training epoch
- `on_validation_epoch_start()` - Called at the start of the validation epoch
- `on_validation_epoch_end()` - Called at the end of the validation epoch
- `on_test_epoch_start()` - Called at the start of the test epoch
- `on_test_epoch_end()` - Called at the end of the test epoch

#### Batch-Level Hooks
- `on_train_batch_start(batch, batch_idx)` - Called before each training batch
- `on_train_batch_end(outputs, batch, batch_idx)` - Called after each training batch
- `on_validation_batch_start(batch, batch_idx)` - Called before each validation batch
- `on_validation_batch_end(outputs, batch, batch_idx)` - Called after each validation batch

#### Training Lifecycle
- `on_fit_start()` - Called at the start of fit
- `on_fit_end()` - Called at the end of fit
- `on_train_start()` - Called at the start of training
- `on_train_end()` - Called at the end of training

### Logging

#### `self.log(name, value, **kwargs)`
Log a metric to all configured loggers.

**Common Parameters:**
- `on_step` (bool) - Log at each batch step
- `on_epoch` (bool) - Log at the end of the epoch (automatically aggregated)
- `prog_bar` (bool) - Display in the progress bar
- `logger` (bool) - Send to the logger
- `sync_dist` (bool) - Synchronize across all distributed processes
- `reduce_fx` (str) - Reduction function for distributed logging ("mean", "sum", etc.)

```python
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
```

#### `self.log_dict(dictionary, **kwargs)`
Log multiple metrics at once.

```python
metrics = {'loss': loss, 'acc': acc, 'f1': f1}
self.log_dict(metrics, on_step=True, on_epoch=True)
```

### Device Management

- `self.device` - Current device (automatically managed)
- `self.to(device)` - Move the model to a device (usually handled automatically)

**Best Practice:** Create new tensors on the model's device:
```python
new_tensor = torch.zeros(10, device=self.device)
```

### Hyperparameter Management

#### `self.save_hyperparameters(*args, **kwargs)`
Automatically save init arguments to `self.hparams` and to checkpoints.

```python
def __init__(self, learning_rate, hidden_dim):
    super().__init__()
    self.save_hyperparameters()  # Saves all args
    # Access via self.hparams.learning_rate, self.hparams.hidden_dim
```

---

## Trainer

The Trainer automates the training loop and engineering complexity.

### Core Parameters

#### Training Duration
- `max_epochs` (int) - Maximum number of epochs (default: 1000)
- `min_epochs` (int) - Minimum number of epochs
- `max_steps` (int) - Maximum number of optimizer steps
- `min_steps` (int) - Minimum number of optimizer steps
- `max_time` (str/dict) - Maximum training time ("DD:HH:MM:SS" or dict)

#### Hardware Configuration
- `accelerator` (str) - Hardware to use: "cpu", "gpu", "tpu", "auto"
- `devices` (int/list) - Number of devices or specific device IDs: 1, 4, [0, 2], "auto"
- `num_nodes` (int) - Number of GPU nodes for distributed training
- `strategy` (str) - Training strategy: "ddp", "fsdp", "deepspeed", etc.

#### Data Management
- `limit_train_batches` (int/float) - Limit training batches (0.0-1.0 for a fraction, int for a count)
- `limit_val_batches` (int/float) - Limit validation batches
- `limit_test_batches` (int/float) - Limit test batches
- `limit_predict_batches` (int/float) - Limit prediction batches

#### Validation
- `check_val_every_n_epoch` (int) - Run validation every N epochs
- `val_check_interval` (int/float) - Validate every N batches (int) or fraction of an epoch (float)
- `num_sanity_val_steps` (int) - Validation steps to run before training starts (default: 2)

#### Optimization
- `gradient_clip_val` (float) - Clip gradients at this value
- `gradient_clip_algorithm` (str) - "value" or "norm"
- `accumulate_grad_batches` (int) - Accumulate gradients over K batches
- `precision` (str) - Training precision: "32-true", "16-mixed", "bf16-mixed", "64-true"

#### Logging and Checkpointing
- `logger` (Logger/list) - Logger instance(s) or True/False
- `log_every_n_steps` (int) - Logging frequency
- `enable_checkpointing` (bool) - Enable automatic checkpointing
- `callbacks` (list) - List of callback instances
- `default_root_dir` (str) - Default path for logs and checkpoints

#### Debugging
- `fast_dev_run` (bool/int) - Run N batches for quick testing
- `overfit_batches` (int/float) - Overfit on limited data for debugging
- `detect_anomaly` (bool) - Enable PyTorch anomaly detection
- `profiler` (str/Profiler) - Profile training: "simple", "advanced", or a custom profiler

#### Performance
- `benchmark` (bool) - Enable cudnn.benchmark for performance
- `deterministic` (bool) - Enable deterministic training for reproducibility
- `sync_batchnorm` (bool) - Synchronize batch norm across GPUs

### Training Methods

#### `trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)`
Run the full training routine.

```python
trainer.fit(model, train_loader, val_loader)
# Or with a DataModule
trainer.fit(model, datamodule=dm)
# Resume from a checkpoint
trainer.fit(model, train_loader, val_loader, ckpt_path="path/to/checkpoint.ckpt")
```

#### `trainer.validate(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run validation independently.

```python
trainer.validate(model, val_loader)
```

#### `trainer.test(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run test evaluation.

```python
trainer.test(model, test_loader)
# Or load from a checkpoint
trainer.test(ckpt_path="best_model.ckpt", datamodule=dm)
```

#### `trainer.predict(model, dataloaders=None, datamodule=None, ckpt_path=None)`
Run inference predictions.

```python
predictions = trainer.predict(model, predict_loader)
```

---

## LightningDataModule

Encapsulates all data processing logic in a reusable class.

### Core Methods

#### `prepare_data(self)`
Download and prepare data (called once, on a single process).
Do NOT set state here (no `self.x = y`).

```python
def prepare_data(self):
    # Download datasets
    datasets.MNIST(self.data_dir, train=True, download=True)
    datasets.MNIST(self.data_dir, train=False, download=True)
```

#### `setup(self, stage=None)`
Load data and create splits (called on every process/GPU).
Setting state is OK here.

**stage parameter:** "fit", "validate", "test", or "predict"

```python
def setup(self, stage=None):
    if stage == "fit" or stage is None:
        full_dataset = datasets.MNIST(self.data_dir, train=True)
        self.train_dataset, self.val_dataset = random_split(full_dataset, [55000, 5000])

    if stage == "test" or stage is None:
        self.test_dataset = datasets.MNIST(self.data_dir, train=False)
```

#### DataLoader Methods
- `train_dataloader(self)` - Return the training DataLoader
- `val_dataloader(self)` - Return the validation DataLoader
- `test_dataloader(self)` - Return the test DataLoader
- `predict_dataloader(self)` - Return the prediction DataLoader

```python
def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size=32, shuffle=True)
```

### Optional Methods
- `teardown(stage=None)` - Clean up after training/testing
- `state_dict()` - Save state for checkpointing
- `load_state_dict(state_dict)` - Load state from a checkpoint

---

## Callbacks

Extend training with modular, reusable functionality.

### Built-in Callbacks

#### ModelCheckpoint
Save model checkpoints based on monitored metrics.

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    verbose=True,
)
```

**Key Parameters:**
- `monitor` - Metric to monitor
- `mode` - "min" or "max"
- `save_top_k` - Save the top K models
- `save_last` - Always save the last checkpoint
- `every_n_epochs` - Save every N epochs

#### EarlyStopping
Stop training when a metric stops improving.

```python
from lightning.pytorch.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    verbose=True,
)
```

#### LearningRateMonitor
Log learning rate values.

```python
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval='epoch')
```

#### RichProgressBar
Display a rich progress bar with metrics.

```python
from lightning.pytorch.callbacks import RichProgressBar

progress_bar = RichProgressBar()
```

### Custom Callbacks

Create custom callbacks by inheriting from `Callback`.

```python
from lightning.pytorch.callbacks import Callback

class MyCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training starting!")

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} ended")

    def on_validation_end(self, trainer, pl_module):
        val_loss = trainer.callback_metrics.get('val_loss')
        print(f"Validation loss: {val_loss}")
```

**Common Hooks:**
- `on_train_start/end`
- `on_train_epoch_start/end`
- `on_validation_epoch_start/end`
- `on_test_epoch_start/end`
- `on_before_backward/on_after_backward`
- `on_before_optimizer_step`

---

## Loggers

Track experiments with various logging frameworks.

### TensorBoardLogger
```python
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir='logs/', name='my_experiment')
trainer = Trainer(logger=logger)
```

### WandbLogger
```python
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project='my_project', name='experiment_1')
trainer = Trainer(logger=logger)
```

### MLFlowLogger
```python
from lightning.pytorch.loggers import MLFlowLogger

logger = MLFlowLogger(experiment_name='my_exp', tracking_uri='file:./ml-runs')
trainer = Trainer(logger=logger)
```

### CSVLogger
```python
from lightning.pytorch.loggers import CSVLogger

logger = CSVLogger(save_dir='logs/', name='my_experiment')
trainer = Trainer(logger=logger)
```

### Multiple Loggers
```python
loggers = [
    TensorBoardLogger('logs/'),
    CSVLogger('logs/'),
]
trainer = Trainer(logger=loggers)
```

---

## Common Patterns

### Reproducibility
```python
from lightning.pytorch import seed_everything

seed_everything(42, workers=True)
trainer = Trainer(deterministic=True)
```

### Mixed Precision Training
```python
trainer = Trainer(precision='16-mixed')  # or 'bf16-mixed'
```

### Multi-GPU Training
```python
# Data parallel (DDP)
trainer = Trainer(accelerator='gpu', devices=4, strategy='ddp')

# Model parallel (FSDP)
trainer = Trainer(accelerator='gpu', devices=4, strategy='fsdp')
```

### Gradient Accumulation
```python
trainer = Trainer(accumulate_grad_batches=4)  # Effective batch size = 4x
```

### Learning Rate Finding
```python
from lightning.pytorch.tuner import Tuner

trainer = Trainer()
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_dataloader)
model.hparams.learning_rate = lr_finder.suggestion()
```

### Loading from Checkpoint
```python
# Load model
model = MyLightningModule.load_from_checkpoint('checkpoint.ckpt')

# Resume training
trainer.fit(model, ckpt_path='checkpoint.ckpt')
```

---

# Distributed and Model Parallel Training

Comprehensive guide for distributed training strategies in PyTorch Lightning.

## Overview

PyTorch Lightning provides seamless distributed training across multiple GPUs, machines, and TPUs with minimal code changes. The framework automatically handles the complexity of distributed training while keeping code device-agnostic and readable.

## Training Strategies

### Data Parallel (DDP - DistributedDataParallel)

**Best for:** Most models (< 500M parameters) where the full model fits in GPU memory.

**How it works:** Each GPU holds a complete copy of the model and trains on a different batch subset. Gradients are synchronized across GPUs during the backward pass.

```python
# Single-node, multi-GPU
trainer = Trainer(
    accelerator='gpu',
    devices=4,       # Use 4 GPUs
    strategy='ddp',
)

# Multi-node, multi-GPU
trainer = Trainer(
    accelerator='gpu',
    devices=4,       # GPUs per node
    num_nodes=2,     # Number of nodes
    strategy='ddp',
)
```

**Advantages:**
- Most widely used and tested
- Works with most PyTorch code
- Good scaling efficiency
- No code changes required in the LightningModule

**When to use:** Default choice for most distributed training scenarios.

### FSDP (Fully Sharded Data Parallel)

**Best for:** Large models (500M+ parameters) that don't fit in a single GPU's memory.

**How it works:** Shards model parameters, gradients, and optimizer states across GPUs. Each GPU stores only a subset of the model.

```python
trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='fsdp',
)

# With configuration
import torch
from torch.distributed.fsdp import MixedPrecision
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",  # Full sharding
    cpu_offload=False,               # Keep parameters on GPU (True offloads to CPU)
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy=strategy,
)
```

**Sharding Strategies:**
- `FULL_SHARD` - Shard parameters, gradients, and optimizer states
- `SHARD_GRAD_OP` - Shard only gradients and optimizer states
- `NO_SHARD` - DDP-like (no sharding)
- `HYBRID_SHARD` - Shard within a node, DDP across nodes

**Advanced FSDP Configuration:**
```python
from torch.distributed.fsdp import BackwardPrefetch
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    activation_checkpointing_policy={TransformerBlock},  # your layer classes to checkpoint (saves memory)
    cpu_offload=True,                                    # Offload parameters to CPU
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,     # Prefetch strategy
    forward_prefetch=True,
    limit_all_gathers=True,
)
```

**When to use:**
- Models > 500M parameters
- Limited GPU memory
- Native PyTorch solution preferred
- Migrating from standalone PyTorch FSDP

### DeepSpeed

**Best for:** Cutting-edge features, massive models, or existing DeepSpeed users.

**How it works:** A comprehensive optimization library with multiple stages of memory and compute optimization.

```python
# Basic DeepSpeed
trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='deepspeed',
    precision='16-mixed',
)

# With configuration
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,                  # ZeRO stage; parameter offload requires stage 3
    offload_optimizer=True,
    offload_parameters=True,
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy=strategy,
)
```

**ZeRO Stages:**
- **Stage 1:** Shard optimizer states
- **Stage 2:** Shard optimizer states + gradients
- **Stage 3:** Shard optimizer states + gradients + parameters (like FSDP)

**With a DeepSpeed Config File:**
```python
strategy = DeepSpeedStrategy(config="deepspeed_config.json")
```

Example `deepspeed_config.json`:
```json
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_bucket_size": 2e8,
    "reduce_bucket_size": 2e8
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true
  },
  "fp16": {
    "enabled": true
  },
  "gradient_clipping": 1.0
}
```

**When to use:**
- Need specific DeepSpeed features
- Maximum memory efficiency required
- Already familiar with DeepSpeed
- Training extremely large models

### DDP Spawn

**Note:** Generally avoid `ddp_spawn`; use `ddp` instead.

```python
trainer = Trainer(strategy='ddp_spawn')  # Not recommended
```

**Issues with ddp_spawn:**
- Cannot return values from `.fit()`
- Pickling issues with unpicklable objects
- Slower than `ddp`
- More memory overhead

**When to use:** Only for debugging, or if `ddp` doesn't work on your system.

## Multi-Node Training

### Basic Multi-Node Setup

```python
# On each node, run the same command
trainer = Trainer(
    accelerator='gpu',
    devices=4,       # GPUs per node
    num_nodes=8,     # Total number of nodes
    strategy='ddp',
)
```

### SLURM Cluster

Lightning automatically detects the SLURM environment:

```python
trainer = Trainer(
    accelerator='gpu',
    devices=4,
    num_nodes=8,
    strategy='ddp',
)
```

**SLURM Submit Script:**
```bash
#!/bin/bash
#SBATCH --nodes=8
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --job-name=lightning_training

srun python train.py
```

### Manual Cluster Setup

```python
from lightning.pytorch.plugins.environments import TorchElasticEnvironment
from lightning.pytorch.strategies import DDPStrategy

# cluster_environment takes a ClusterEnvironment instance
# (e.g. TorchElasticEnvironment, SLURMEnvironment, LSFEnvironment, KubeflowEnvironment)
strategy = DDPStrategy(
    cluster_environment=TorchElasticEnvironment(),
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    num_nodes=8,
    strategy=strategy,
)
```

## Memory Optimization Techniques

### Gradient Accumulation

Simulate larger batch sizes without increasing memory:

```python
trainer = Trainer(
    accumulate_grad_batches=4,  # Accumulate 4 batches before each optimizer step
)

# Variable accumulation by epoch (Lightning 2.x uses a callback for this)
from lightning.pytorch.callbacks import GradientAccumulationScheduler

trainer = Trainer(
    callbacks=[
        GradientAccumulationScheduler(scheduling={
            0: 8,  # Epochs 0-4: accumulate 8 batches
            5: 4,  # Epochs 5+: accumulate 4 batches
        })
    ],
)
```

### Activation Checkpointing

Trade computation for memory by recomputing activations during the backward pass:

```python
import lightning as L
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

class MyModule(L.LightningModule):
    def configure_model(self):
        # Wrap specific layers for activation checkpointing
        self.model = MyTransformer()
        apply_activation_checkpointing(
            self.model,
            checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(
                m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
            ),
            check_fn=lambda m: isinstance(m, TransformerBlock),
        )
```

### Mixed Precision Training

Reduce memory usage and increase speed with mixed precision:

```python
# 16-bit mixed precision
trainer = Trainer(precision='16-mixed')

# BFloat16 mixed precision (more stable, requires newer GPUs)
trainer = Trainer(precision='bf16-mixed')
```

### CPU Offloading

Offload parameters or optimizer states to the CPU:

```python
# FSDP with CPU offload
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    cpu_offload=True,  # Offload parameters to CPU
)

# DeepSpeed with CPU offload
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
)
```

## Performance Optimization

### Synchronize Batch Normalization

Synchronize batch norm statistics across GPUs:

```python
trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='ddp',
    sync_batchnorm=True,  # Sync batch norm across GPUs
)
```

### Find Optimal Batch Size

```python
from lightning.pytorch.tuner import Tuner

trainer = Trainer()
tuner = Tuner(trainer)

# Auto-scale batch size
tuner.scale_batch_size(model, mode="power")  # or "binsearch"
```

### Gradient Clipping

Prevent gradient explosion in distributed training:

```python
trainer = Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm='norm',  # or 'value'
)
```

### Benchmark Mode

Enable cudnn.benchmark when input sizes are consistent:

```python
trainer = Trainer(
    benchmark=True,  # Optimize for consistent input sizes
)
```

## Distributed Data Loading

### Automatic Distributed Sampling

Lightning automatically handles distributed sampling:

```python
# No changes needed - Lightning handles this automatically
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        shuffle=True,  # Lightning replaces this with a DistributedSampler
    )
```

### Manual Control

```python
# Disable the automatic distributed sampler
trainer = Trainer(
    use_distributed_sampler=False,
)

# Manual distributed sampler
from torch.utils.data.distributed import DistributedSampler

def train_dataloader(self):
    sampler = DistributedSampler(self.train_dataset)
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        sampler=sampler,
    )
```

### Data Loading Best Practices

```python
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        num_workers=4,            # Use multiple workers
        pin_memory=True,          # Faster CPU-to-GPU transfer
        persistent_workers=True,  # Keep workers alive between epochs
    )
```

## Common Patterns

### Logging in Distributed Training

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # sync_dist=True reduces the metric across processes before logging
    self.log('train_loss', loss, sync_dist=True)

    return loss
```

### Rank-Specific Operations

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # Run only on rank 0 (the main process)
    if self.trainer.is_global_zero:
        print("This only prints once across all processes")

    # Query the current rank and world size
    rank = self.trainer.global_rank
    world_size = self.trainer.world_size

    return loss
```

### Barrier Synchronization

```python
def on_train_epoch_end(self):
    # Wait for all processes
    self.trainer.strategy.barrier()

    # Now all processes are synchronized
    if self.trainer.is_global_zero:
        # Save something only once
        self.save_artifacts()
```

## Troubleshooting

### Common Issues

**1. Out of Memory:**
- Reduce the batch size
- Enable gradient accumulation
- Use FSDP or DeepSpeed
- Enable activation checkpointing
- Use mixed precision

**2. Slow Training:**
- Check data loading (use `num_workers > 0`)
- Enable `pin_memory=True` and `persistent_workers=True`
- Use `benchmark=True` for consistent input sizes
- Profile with `profiler='simple'`

**3. Hanging:**
- Ensure all processes execute the same collectives
- Check for `if` statements that differ across ranks
- Use barrier synchronization when needed

**4. Inconsistent Results:**
- Set `deterministic=True`
- Use `seed_everything()`
- Ensure proper gradient synchronization

### Debugging Distributed Training

```python
# Test with a single GPU first
trainer = Trainer(accelerator='gpu', devices=1)

# Then test with 2 GPUs
trainer = Trainer(accelerator='gpu', devices=2, strategy='ddp')

# Use fast_dev_run for quick testing
trainer = Trainer(
    accelerator='gpu',
    devices=2,
    strategy='ddp',
    fast_dev_run=10,  # Run 10 batches only
)
```

## Strategy Selection Guide

| Model Size | Available Memory | Recommended Strategy |
|-----------|------------------|---------------------|
| < 500M params | Fits on 1 GPU | Single GPU |
| < 500M params | Fits across GPUs | DDP |
| 500M - 3B params | Limited memory | FSDP or DeepSpeed Stage 2 |
| 3B+ params | Very limited memory | FSDP or DeepSpeed Stage 3 |
| Any size | Maximum efficiency | DeepSpeed with offloading |
| Multiple nodes | Any | DDP (< 500M) or FSDP/DeepSpeed (> 500M) |