Improve PyTorch Lightning skill

This commit is contained in:
Timothy Kassis
2025-10-21 10:19:15 -07:00
parent aacc29a778
commit 1a9149b089
15 changed files with 5049 additions and 1968 deletions

View File

@@ -1,190 +1,197 @@
"""
Template for creating a PyTorch Lightning LightningModule.
Template for creating a PyTorch Lightning Module.
This template includes all common hooks and patterns for building
a Lightning model with best practices.
This template provides a complete boilerplate for building a LightningModule
with all essential methods and best practices.
"""
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
class TemplateLightningModule(L.LightningModule):
    """
    Template LightningModule for building deep learning models.

    Uses a simple MLP over flattened 784-dim inputs (e.g. MNIST) with a
    10-class output head; swap out ``self.model`` for your architecture.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        hidden_dim: Hidden dimension size of the MLP.
        dropout: Dropout probability applied after the hidden layer.
    """

    def __init__(
        self,
        learning_rate: float = 0.001,
        hidden_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Save hyperparameters (accessible via self.hparams); also makes them
        # part of the checkpoint so the model can be restored from hparams alone.
        self.save_hyperparameters()

        # Define your model architecture.
        # NOTE: input (784) and output (10) sizes are fixed here; parameterize
        # them if you need other datasets.
        self.model = nn.Sequential(
            nn.Linear(784, self.hparams.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.hidden_dim, 10),
        )

        # Optional: Define metrics
        # from torchmetrics import Accuracy
        # self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
        # self.val_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x: Input tensor of shape ``(batch, 784)``.

        Returns:
            Raw logits of shape ``(batch, 10)``.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Training step (called for each training batch).

        Args:
            batch: Tuple ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the current batch.

        Returns:
            Loss tensor (Lightning backpropagates it automatically).
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy (optional)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics; on_step gives per-batch values, on_epoch aggregates.
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/acc", acc, on_step=True, on_epoch=True)
        # Track the current learning rate (useful with schedulers).
        self.log("learning_rate", self.optimizers().param_groups[0]["lr"])

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step (called for each validation batch).

        The model is automatically in eval mode and gradients are disabled.

        Args:
            batch: Tuple ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the current batch.
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics (automatically aggregated across batches);
        # sync_dist=True reduces correctly across devices in DDP.
        self.log("val/loss", loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """
        Test step (called for each test batch).

        Args:
            batch: Tuple ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the current batch.
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics
        self.log("test/loss", loss, on_epoch=True)
        self.log("test/acc", acc, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Prediction step (called for each prediction batch).

        Args:
            batch: Tuple ``(x, y)``; only ``x`` is used for prediction.
            batch_idx: Index of the current batch.
            dataloader_idx: Index of the dataloader (if multiple).

        Returns:
            Predicted class indices of shape ``(batch,)``.
        """
        x, _ = batch
        logits = self(x)
        return torch.argmax(logits, dim=1)

    def configure_optimizers(self):
        """
        Configure optimizers and learning rate schedulers.

        Returns:
            Dict with the optimizer and a ``lr_scheduler`` config that
            monitors ``val/loss`` (required for ReduceLROnPlateau).
        """
        # Define optimizer
        optimizer = Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-5,
        )

        # Define scheduler. NOTE: the `verbose` kwarg is deprecated and was
        # removed from ReduceLROnPlateau in recent PyTorch releases, so it is
        # intentionally not passed here.
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=5,
        )

        # Return configuration; "monitor" tells Lightning which logged metric
        # drives the plateau detection.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

    # Optional: Add custom hooks for model-specific logic

    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""
        # Example: Log custom epoch-level metrics here.
        pass

    def on_validation_epoch_end(self):
        """Called at the end of each validation epoch."""
        # Example: Compute epoch-level metrics here.
        pass
@@ -192,24 +199,21 @@ class TemplateLightningModule(L.LightningModule):
if __name__ == "__main__":
# Create model
model = TemplateLightningModule(
input_dim=784,
hidden_dim=128,
output_dim=10,
learning_rate=1e-3,
optimizer_type="adam",
scheduler_type="reduce_on_plateau"
learning_rate=0.001,
hidden_dim=256,
dropout=0.1,
)
# Create trainer
trainer = L.Trainer(
max_epochs=10,
accelerator="auto",
devices=1,
log_every_n_steps=50,
devices="auto",
logger=True,
)
# Note: You would need to provide dataloaders
# Train (you need to provide train_dataloader and val_dataloader)
# trainer.fit(model, train_dataloader, val_dataloader)
print("Template LightningModule created successfully!")
print(f"Model hyperparameters: {model.hparams}")
print(f"Model created with {model.num_parameters:,} parameters")
print(f"Hyperparameters: {model.hparams}")