mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-28 07:20:27 +08:00
Consolidate skills
@@ -0,0 +1,219 @@
"""
Template for creating a PyTorch Lightning Module.

This template provides a complete boilerplate for building a LightningModule
with all essential methods and best practices.
"""

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau


class TemplateLightningModule(L.LightningModule):
    """
    Template LightningModule for building deep learning models.

    Args:
        learning_rate: Learning rate for the optimizer
        hidden_dim: Hidden dimension size
        dropout: Dropout probability
    """

    def __init__(
        self,
        learning_rate: float = 0.001,
        hidden_dim: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Save hyperparameters (accessible via self.hparams)
        self.save_hyperparameters()

        # Define your model architecture
        self.model = nn.Sequential(
            nn.Linear(784, self.hparams.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.hidden_dim, 10),
        )

        # Optional: Define metrics
        # from torchmetrics import Accuracy
        # self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
        # self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
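        # If the torchmetrics objects above are enabled, a training or
        # validation step can update and log them directly; self.log
        # accepts Metric objects and handles epoch aggregation, e.g.:
        #   self.train_accuracy(preds, y)
        #   self.log("train/acc_metric", self.train_accuracy, on_epoch=True)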

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x: Input tensor

        Returns:
            Model output
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Training step (called for each training batch).

        Args:
            batch: Current batch of data
            batch_idx: Index of the current batch

        Returns:
            Loss tensor
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy (optional)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/acc", acc, on_step=True, on_epoch=True)
        self.log("learning_rate", self.optimizers().param_groups[0]["lr"])

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step (called for each validation batch).

        Args:
            batch: Current batch of data
            batch_idx: Index of the current batch
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics (automatically aggregated across batches)
        self.log("val/loss", loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """
        Test step (called for each test batch).

        Args:
            batch: Current batch of data
            batch_idx: Index of the current batch
        """
        x, y = batch

        # Forward pass
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log metrics
        self.log("test/loss", loss, on_epoch=True)
        self.log("test/acc", acc, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Prediction step (called for each prediction batch).

        Args:
            batch: Current batch of data
            batch_idx: Index of the current batch
            dataloader_idx: Index of the dataloader (if multiple)

        Returns:
            Predictions
        """
        x, _ = batch  # labels, if present, are ignored at prediction time
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        return preds

    def configure_optimizers(self):
        """
        Configure optimizers and learning rate schedulers.

        Returns:
            Optimizer and scheduler configuration
        """
        # Define optimizer
        optimizer = Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-5,
        )

        # Define scheduler (the deprecated `verbose` argument is omitted
        # for compatibility with recent PyTorch releases)
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=5,
        )

        # Return configuration
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }
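
    # Note: configure_optimizers also accepts simpler return forms; when no
    # scheduler is needed, returning just the optimizer is valid, e.g.:
    #   return Adam(self.parameters(), lr=self.hparams.learning_rate)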

    # Optional: Add custom methods for model-specific logic

    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""
        # Example: Log custom metrics
        pass

    def on_validation_epoch_end(self):
        """Called at the end of each validation epoch."""
        # Example: Compute epoch-level metrics
        pass


# Example usage
if __name__ == "__main__":
    # Create model
    model = TemplateLightningModule(
        learning_rate=0.001,
        hidden_dim=256,
        dropout=0.1,
    )

    # Create trainer
    trainer = L.Trainer(
        max_epochs=10,
        accelerator="auto",
        devices="auto",
        logger=True,
    )

    # Train (you need to provide train_dataloader and val_dataloader)
    # trainer.fit(model, train_dataloader, val_dataloader)
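
    # A minimal sketch of the dataloaders the fit call above expects: random
    # tensors shaped like flattened 28x28 images (784 features, 10 classes)
    # serve purely as placeholders; swap in your real Dataset.
    from torch.utils.data import DataLoader, TensorDataset

    train_dataloader = DataLoader(
        TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,))),
        batch_size=32,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
        batch_size=32,
    )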

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model created with {num_params:,} parameters")
    print(f"Hyperparameters: {model.hparams}")