Distributed and Model Parallel Training

Comprehensive guide for distributed training strategies in PyTorch Lightning.

Overview

PyTorch Lightning provides seamless distributed training across multiple GPUs, machines, and TPUs with minimal code changes. The framework automatically handles the complexity of distributed training while keeping code device-agnostic and readable.
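
As a minimal, hedged sketch (the LitRegressor module and the random dataset are illustrative, not part of Lightning), the same LightningModule scales from one GPU to several by changing only the Trainer arguments:

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L

class LitRegressor(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

train_loader = DataLoader(TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1)), batch_size=32)

# Single GPU
trainer = L.Trainer(accelerator='gpu', devices=1, max_epochs=1)

# Multi-GPU: only the Trainer arguments change, the module stays the same
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp', max_epochs=1)

trainer.fit(LitRegressor(), train_loader)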

Training Strategies

DDP (Distributed Data Parallel)

Best for: Most models (< 500M parameters) where the full model fits in GPU memory.

How it works: Each GPU holds a complete copy of the model and trains on a different shard of the data. Gradients are all-reduced across GPUs during the backward pass so every model copy stays in sync.

# Single-node, multi-GPU
trainer = Trainer(
    accelerator='gpu',
    devices=4,  # Use 4 GPUs
    strategy='ddp',
)

# Multi-node, multi-GPU
trainer = Trainer(
    accelerator='gpu',
    devices=4,  # GPUs per node
    num_nodes=2,  # Number of nodes
    strategy='ddp',
)

Advantages:

  • Most widely used and tested
  • Works with most PyTorch code
  • Good scaling efficiency
  • No code changes required in LightningModule

When to use: Default choice for most distributed training scenarios.

FSDP (Fully Sharded Data Parallel)

Best for: Large models (500M+ parameters) that don't fit in single GPU memory.

How it works: Shards model parameters, gradients, and optimizer states across GPUs. Each GPU only stores a subset of the model.

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='fsdp',
)

# With configuration
import torch
from torch.distributed.fsdp import MixedPrecision
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",  # Shard parameters, gradients, and optimizer states
    cpu_offload=False,  # Set True to offload parameters to CPU
    mixed_precision=MixedPrecision(param_dtype=torch.float16),  # FSDP mixed-precision policy
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy=strategy,
)

Sharding Strategies:

  • FULL_SHARD - Shard parameters, gradients, and optimizer states
  • SHARD_GRAD_OP - Shard only gradients and optimizer states
  • NO_SHARD - DDP-like (no sharding)
  • HYBRID_SHARD - Shard within node, DDP across nodes (see the sketch after this list)
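
For example, HYBRID_SHARD (sharding inside each node, DDP-style replication across nodes) can be selected by name; a small sketch reusing the FSDPStrategy shown above:

from lightning.pytorch.strategies import FSDPStrategy

# Shard within each node, replicate DDP-style across nodes
strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD")

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    num_nodes=2,
    strategy=strategy,
)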

Advanced FSDP Configuration:

from torch.distributed.fsdp import BackwardPrefetch
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    activation_checkpointing_policy={TransformerBlock},  # Recompute these layers in backward to save memory
    cpu_offload=True,  # Offload parameters to CPU
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # Prefetch the next parameter shard during backward
    forward_prefetch=True,
    limit_all_gathers=True,
)

When to use:

  • Models > 500M parameters
  • Limited GPU memory
  • Native PyTorch solution preferred
  • Migrating from standalone PyTorch FSDP

DeepSpeed

Best for: Cutting-edge features, massive models, or existing DeepSpeed users.

How it works: DeepSpeed is a comprehensive training-optimization library; its ZeRO stages progressively shard optimizer states, gradients, and parameters, with optional CPU/NVMe offloading.

# Basic DeepSpeed
trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='deepspeed',
    precision='16-mixed',
)

# With configuration
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=2,  # ZeRO stage (1, 2, or 3)
    offload_optimizer=True,  # Offload optimizer states to CPU
    # offload_parameters=True requires stage=3 (see the Stage 3 example under CPU Offloading)
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy=strategy,
)

ZeRO Stages:

  • Stage 1: Shard optimizer states
  • Stage 2: Shard optimizer states + gradients
  • Stage 3: Shard optimizer states + gradients + parameters (like FSDP)

With DeepSpeed Config File:

strategy = DeepSpeedStrategy(config="deepspeed_config.json")

Example deepspeed_config.json:

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_bucket_size": 2e8,
    "reduce_bucket_size": 2e8
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true
  },
  "fp16": {
    "enabled": true
  },
  "gradient_clipping": 1.0
}

When to use:

  • Need specific DeepSpeed features
  • Maximum memory efficiency required
  • Already familiar with DeepSpeed
  • Training extremely large models

DDP Spawn

Note: Generally avoid using ddp_spawn. Use ddp instead.

trainer = Trainer(strategy='ddp_spawn')  # Not recommended

Issues with ddp_spawn:

  • Training runs in spawned subprocesses, so the model in the main process is not updated in place
  • Pickling errors: everything passed to the subprocesses must be picklable
  • Slower than ddp
  • More memory overhead

When to use: Only for debugging or if ddp doesn't work on your system.

Multi-Node Training

Basic Multi-Node Setup

# On each node, run the same command
trainer = Trainer(
    accelerator='gpu',
    devices=4,  # GPUs per node
    num_nodes=8,  # Total number of nodes
    strategy='ddp',
)
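
When no cluster manager or launcher (SLURM, torchrun) is used, each node must be told where rank 0 lives and which node it is. A hedged sketch of the environment variables Lightning expects, with illustrative values (normally these are exported in the shell on every node before running train.py):

import os

# Illustrative values; export these on every node before launching
os.environ['MASTER_ADDR'] = '10.0.0.1'  # address of the rank-0 node
os.environ['MASTER_PORT'] = '29500'     # free port on the rank-0 node
os.environ['NODE_RANK'] = '0'           # unique per node: 0 .. num_nodes - 1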

SLURM Cluster

Lightning automatically detects the SLURM environment and configures ranks from the SLURM job variables:

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    num_nodes=8,
    strategy='ddp',
)

SLURM Submit Script:

#!/bin/bash
#SBATCH --nodes=8
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --job-name=lightning_training

srun python train.py

Manual Cluster Setup

from lightning.pytorch.plugins.environments import TorchElasticEnvironment  # or SLURMEnvironment, LSFEnvironment, KubeflowEnvironment
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    cluster_environment=TorchElasticEnvironment(),
)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    num_nodes=8,
    strategy=strategy,
)

Memory Optimization Techniques

Gradient Accumulation

Simulate larger batch sizes without increasing memory:

trainer = Trainer(
    accumulate_grad_batches=4,  # Accumulate 4 batches before optimizer step
)

# Variable accumulation schedule via the GradientAccumulationScheduler callback
from lightning.pytorch.callbacks import GradientAccumulationScheduler

accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,  # Epochs 0-4: accumulate 8 batches
        5: 4,  # Epochs 5+: accumulate 4 batches
    }
)

trainer = Trainer(callbacks=[accumulator])
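
Under data parallelism, the global batch that contributes to each optimizer step multiplies across GPUs and accumulation steps; a quick arithmetic sketch with illustrative numbers:

per_gpu_batch_size = 32
devices = 4                   # DDP processes (one per GPU)
accumulate_grad_batches = 4

effective_batch_size = per_gpu_batch_size * devices * accumulate_grad_batches
print(effective_batch_size)   # 512 samples contribute to each optimizer step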

Activation Checkpointing

Trade computation for memory by recomputing activations during backward pass:

# Manual wrapping of specific layers
import lightning as L
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

class MyModule(L.LightningModule):
    def configure_model(self):
        # Wrap specific layers for activation checkpointing
        self.model = MyTransformer()
        apply_activation_checkpointing(
            self.model,
            checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(m, CheckpointImpl.NO_REENTRANT),
            check_fn=lambda m: isinstance(m, TransformerBlock),
        )
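
When training with FSDP, the same effect can be obtained without manual wrapping by handing the layer classes to the strategy (as in the Advanced FSDP Configuration above); a minimal sketch assuming TransformerBlock is your transformer layer class:

from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    activation_checkpointing_policy={TransformerBlock},  # recompute these layers during backward
)

trainer = Trainer(accelerator='gpu', devices=4, strategy=strategy)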

Mixed Precision Training

Reduce memory usage and increase speed with mixed precision:

# 16-bit mixed precision
trainer = Trainer(precision='16-mixed')

# BFloat16 mixed precision (more numerically stable; requires Ampere or newer GPUs)
trainer = Trainer(precision='bf16-mixed')

CPU Offloading

Offload parameters or optimizer states to CPU:

# FSDP with CPU offload
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    cpu_offload=True,  # Offload parameters to CPU
)

# DeepSpeed with CPU offload
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
)

Performance Optimization

Synchronize Batch Normalization

Synchronize batch norm statistics across GPUs:

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='ddp',
    sync_batchnorm=True,  # Sync batch norm across GPUs
)

Find Optimal Batch Size

from lightning.pytorch.tuner import Tuner

# Run the batch-size finder on a single device before launching distributed training
trainer = Trainer(accelerator='gpu', devices=1)
tuner = Tuner(trainer)

# Auto-scale batch size; `model` is your LightningModule with a `batch_size` attribute
new_batch_size = tuner.scale_batch_size(model, mode="power")  # or "binsearch"

Gradient Clipping

Prevent gradient explosion in distributed training:

trainer = Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm='norm',  # or 'value'
)

Benchmark Mode

Enable cudnn.benchmark for consistent input sizes:

trainer = Trainer(
    benchmark=True,  # Optimize for consistent input sizes
)

Distributed Data Loading

Automatic Distributed Sampling

Lightning automatically handles distributed sampling:

# No changes needed - Lightning handles this automatically
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        shuffle=True,  # Lightning converts to DistributedSampler
    )

Manual Control

# Disable automatic distributed sampler
trainer = Trainer(
    use_distributed_sampler=False,
)

# Manual distributed sampler
from torch.utils.data.distributed import DistributedSampler

def train_dataloader(self):
    sampler = DistributedSampler(self.train_dataset)
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        sampler=sampler,
    )

Data Loading Best Practices

def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        num_workers=4,  # Use multiple workers
        pin_memory=True,  # Faster CPU-GPU transfer
        persistent_workers=True,  # Keep workers alive between epochs
    )
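
Keep in mind that num_workers is per process: under DDP each GPU runs its own DataLoader, so a node's total worker count is num_workers × GPUs. A hedged sketch of a sizing heuristic (the suggested_num_workers helper is illustrative, not a Lightning API):

import os
from torch.utils.data import DataLoader

def suggested_num_workers(num_gpus_per_node: int) -> int:
    # Split the node's CPU cores across the GPU processes, at least one worker each
    return max(1, (os.cpu_count() or 1) // num_gpus_per_node)

def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=32,
        num_workers=suggested_num_workers(4),  # 4 GPUs per node in this example
        pin_memory=True,
        persistent_workers=True,
    )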

Common Patterns

Logging in Distributed Training

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # Automatically syncs across processes
    self.log('train_loss', loss, sync_dist=True)

    return loss

Rank-Specific Operations

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # Run only on rank 0 (main process)
    if self.trainer.is_global_zero:
        print("This only prints once across all processes")

    # Get current rank and world size
    rank = self.trainer.global_rank
    world_size = self.trainer.world_size

    return loss

Barrier Synchronization

def on_train_epoch_end(self):
    # Wait for all processes
    self.trainer.strategy.barrier()

    # Now all processes are synchronized
    if self.trainer.is_global_zero:
        # Save something only once
        self.save_artifacts()

Troubleshooting

Common Issues

1. Out of Memory:

  • Reduce batch size
  • Enable gradient accumulation
  • Use FSDP or DeepSpeed
  • Enable activation checkpointing
  • Use mixed precision

2. Slow Training:

  • Check data loading (use num_workers > 0)
  • Enable pin_memory=True and persistent_workers=True
  • Use benchmark=True for consistent input sizes
  • Profile with profiler='simple'

3. Hanging:

  • Ensure all processes execute the same collective operations
  • Check for if statements that differ across ranks (see the sketch after this list)
  • Use barrier synchronization when needed
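
For example, a rank-dependent branch around a collective is a classic deadlock; the sketch below (compute_metric is an illustrative helper) shows the problem and a safe rewrite:

def validation_step(self, batch, batch_idx):
    metric = self.compute_metric(batch)  # illustrative helper

    # Deadlock: the reduction behind sync_dist=True would run only on rank 0,
    # so the other ranks wait on it forever:
    # if self.trainer.is_global_zero:
    #     self.log('val_metric', metric, sync_dist=True)

    # Safe: every rank executes the same collective
    self.log('val_metric', metric, sync_dist=True)
    if self.trainer.is_global_zero:
        print('rank-0-only side effects (printing, saving artifacts) go here')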

4. Inconsistent Results:

  • Set deterministic=True
  • Use seed_everything() (see the sketch after this list)
  • Ensure proper gradient synchronization
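
A short sketch combining both switches:

from lightning.pytorch import Trainer, seed_everything

seed_everything(42, workers=True)  # seed Python, NumPy, and torch (plus dataloader workers)

trainer = Trainer(
    accelerator='gpu',
    devices=4,
    strategy='ddp',
    deterministic=True,  # use deterministic algorithms where available
)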

Debugging Distributed Training

# Test with single GPU first
trainer = Trainer(accelerator='gpu', devices=1)

# Then test with 2 GPUs
trainer = Trainer(accelerator='gpu', devices=2, strategy='ddp')

# Use fast_dev_run for quick testing
trainer = Trainer(
    accelerator='gpu',
    devices=2,
    strategy='ddp',
    fast_dev_run=10,  # Run 10 batches only
)

Strategy Selection Guide

| Model Size / Setup | Available Memory | Recommended Strategy |
|---|---|---|
| < 500M params | Fits in 1 GPU | Single GPU |
| < 500M params | Fits across GPUs | DDP |
| 500M - 3B params | Limited memory | FSDP or DeepSpeed Stage 2 |
| 3B+ params | Very limited memory | FSDP or DeepSpeed Stage 3 |
| Any size | Maximum efficiency | DeepSpeed with offloading |
| Multiple nodes | Any | DDP (< 500M) or FSDP/DeepSpeed (> 500M) |