# Core Concepts and Technical Details

## Overview

This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.

## Architecture Philosophy

### Modular Design

TorchDrug separates concerns into distinct modules:

1. **Representation Models** (`torchdrug.models`): Encode graphs into embeddings
2. **Task Definitions** (`torchdrug.tasks`): Define learning objectives and evaluation
3. **Data Handling** (`torchdrug.data`, `torchdrug.datasets`): Graph structures and datasets
4. **Core Components** (`torchdrug.core`): Base classes and utilities

**Benefits:**

- Reuse representations across tasks
- Mix and match components
- Easy experimentation and prototyping
- Clear separation of concerns

### Configurable System

All components inherit from `core.Configurable`:

- Serialize to configuration dictionaries
- Reconstruct from configurations
- Save and load complete pipelines
- Reproducible experiments

## Core Components

### core.Configurable

Base class for all TorchDrug components.

**Key Methods:**

- `config_dict()`: Serialize to dictionary
- `load_config_dict(config)`: Load from dictionary
- `save(file)`: Save to file
- `load(file)`: Load from file

**Example:**

```python
from torchdrug import core, models

model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}

# Reconstruct the model from its configuration
model2 = core.Configurable.load_config_dict(config)
```

### core.Registry

Decorator for registering models, tasks, and datasets.

**Usage:**

```python
import torch.nn as nn

from torchdrug import core
from torchdrug.core import Registry as R


@R.register("models.CustomModel")
class CustomModel(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Model implementation goes here
        pass
```

**Benefits:**

- Registered components are automatically serializable
- String-based model specification in configuration files
- Easy model lookup and instantiation

## Data Structures

### Graph

Core data structure representing molecular or protein graphs.

**Attributes:**

- `num_node`: Number of nodes
- `num_edge`: Number of edges
- `node_feature`: Node feature tensor `[num_node, feature_dim]`
- `edge_feature`: Edge feature tensor `[num_edge, feature_dim]`
- `edge_list`: Edge connectivity `[num_edge, 2 or 3]`
- `num_relation`: Number of edge types (for multi-relational graphs)

**Methods:**

- `node_mask(mask)`: Select a subset of nodes
- `edge_mask(mask)`: Select a subset of edges
- `undirected()`: Make the graph undirected
- `directed()`: Make the graph directed

**Batching:**

- Graphs are batched into a single disconnected graph
- Batching is automatic in the DataLoader
- Node/edge indices are preserved per graph

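A minimal sketch of constructing and masking a `Graph` directly (the toy edge list below is illustrative; molecules and proteins are usually built via the classes that follow):

```python
import torch
from torchdrug import data

# A 3-node graph with directed edges 0->1, 1->2, 2->0
graph = data.Graph(edge_list=[[0, 1], [1, 2], [2, 0]], num_node=3)
graph = graph.undirected()  # add reverse edges

# Mask out node 2; edges incident to it are dropped
subgraph = graph.node_mask(torch.tensor([True, True, False]))
print(subgraph.num_node, subgraph.num_edge)
```
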
### Molecule (extends Graph)

Specialized graph for molecules.

**Additional Attributes:**

- `atom_type`: Atomic numbers
- `bond_type`: Bond types (single, double, triple, aromatic)
- `formal_charge`: Atomic formal charges
- `explicit_hs`: Explicit hydrogen counts

**Methods:**

- `from_smiles(smiles)`: Create from a SMILES string
- `from_molecule(mol)`: Create from an RDKit molecule
- `to_smiles()`: Convert to SMILES
- `to_molecule()`: Convert to an RDKit molecule
- `ion_to_molecule()`: Neutralize charges

**Example:**

```python
from torchdrug import data

# From SMILES
mol = data.Molecule.from_smiles("CCO")

# Atom and bond features
print(mol.atom_type)  # atomic numbers [6, 6, 8] (C, C, O)
print(mol.bond_type)  # bond type codes, one entry per directed edge
```

### Protein (extends Graph)

Specialized graph for proteins.

**Additional Attributes:**

- `residue_type`: Amino acid types
- `atom_name`: Atom names (CA, CB, etc.)
- `atom_type`: Atomic numbers
- `residue_number`: Residue numbering
- `chain_id`: Chain identifiers

**Methods:**

- `from_pdb(pdb_file)`: Load from a PDB file
- `from_sequence(sequence)`: Create from an amino acid sequence
- `to_pdb(pdb_file)`: Save to a PDB file

**Graph Construction:**

- Nodes typically represent residues (not atoms)
- Edges can be sequential, spatial (radius/KNN), or contact-based
- Edge construction strategies are configurable

**Example** (residue-level graph construction via `layers.geometry`, applied to a packed protein):

```python
from torchdrug import data, layers
from torchdrug.layers import geometry

# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")

# Build a residue-level graph with sequential and spatial edges
construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],              # one node per residue (CA)
    edge_layers=[geometry.SequentialEdge(max_distance=2),  # sequence neighbors
                 geometry.SpatialEdge(radius=10.0, min_distance=5)]  # spatial neighbors
)
graph = construction(data.Protein.pack([protein]))
```

### PackedGraph

Efficient batching structure for graphs of different sizes.

**Purpose:**

- Batch graphs of different sizes
- Single GPU memory allocation
- Efficient parallel processing

**Attributes:**

- `num_nodes`: Node count per graph
- `num_edges`: Edge count per graph
- `node2graph`: Graph index for each node

**Use Cases:**

- Automatic batching in the DataLoader
- Custom batching strategies
- Multi-graph operations

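A minimal sketch of packing graphs by hand (the molecules below are illustrative; the DataLoader normally does this for you):

```python
from torchdrug import data

mols = [data.Molecule.from_smiles("CCO"), data.Molecule.from_smiles("c1ccccc1")]

# Pack into a single disconnected graph
packed = data.Molecule.pack(mols)
print(packed.batch_size)   # 2
print(packed.num_nodes)    # node count per graph
print(packed.node2graph)   # graph index for each node

# Recover the individual graphs
mols_again = packed.unpack()
```
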
## Model Interface

### Forward Function Signature

All TorchDrug models follow a standardized interface:

```python
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation (self.layers and graph_pooling are placeholders)
    output = self.layers(graph, input)

    return {
        "node_feature": output,
        "graph_feature": graph_pooling(output)
    }
```

**Key Points:**

- `graph`: Batched graph structure
- `input`: Node features `[num_node, input_dim]`
- `all_loss`: Accumulated loss (for multi-task training)
- `metric`: Shared metric dictionary
- Returns a dict keyed by representation type

### Essential Attributes

**All models must define:**

- `input_dim`: Expected input feature dimension
- `output_dim`: Output representation dimension

**Purpose:**

- Automatic dimension checking
- Composing models into pipelines
- Error checking and validation

**Example:**

```python
import torch.nn as nn


class CustomModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...
```

## Task Interface

### Core Task Methods

All tasks implement these methods:

```python
from torchdrug import tasks, metrics


class CustomTask(tasks.Task):
    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature.float())
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground-truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute the training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        result = {}
        result["auroc"] = metrics.area_under_roc(pred, target)
        result["auprc"] = metrics.area_under_prc(pred, target)
        return result
```

### Task Components

**Typical Task Structure:**

1. **Representation Model**: Encodes the graph into embeddings
2. **Readout/Prediction Head**: Maps embeddings to predictions
3. **Loss Function**: Training objective
4. **Metrics**: Evaluation measures

**Example:**

```python
from torchdrug import tasks, models

# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])

# Task wraps the model with a prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
```

## Training Workflow

### Standard Training Loop

```python
import torch
from torch.utils.data import random_split
from torchdrug import data, datasets, models, tasks

# 1. Load dataset and split it
dataset = datasets.BBBP("~/datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = random_split(dataset, lengths)

# 2. Create data loaders (torchdrug's DataLoader batches graphs into packed graphs)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])
# PropertyPrediction builds its prediction head during preprocess()
task.preprocess(train_set, valid_set, test_set)

# 4. Set up optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss, metric = task(batch)  # tasks return (loss, metric dict)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))

    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
```

### PyTorch Lightning Integration

TorchDrug tasks can be wrapped for use with PyTorch Lightning:

```python
import torch
import pytorch_lightning as pl


class LightningWrapper(pl.LightningModule):
    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss, metric = self.task(batch)  # tasks return (loss, metric dict)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

## Loss Functions

### Built-in Criteria

**Classification:**

- `"bce"`: Binary cross-entropy
- `"ce"`: Cross-entropy (multi-class)

**Regression:**

- `"mse"`: Mean squared error
- `"mae"`: Mean absolute error

**Knowledge Graph:**

- `"bce"`: Binary classification of triples
- `"ce"`: Cross-entropy ranking loss
- `"margin"`: Margin-based ranking

### Custom Loss

```python
from torchdrug import tasks


class CustomTask(tasks.Task):
    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)

        # Custom loss computation (custom_loss_function is a placeholder)
        loss = custom_loss_function(pred, target)

        return loss
```

## Metrics

### Common Metrics

**Classification:**

- **AUROC**: Area under the ROC curve
- **AUPRC**: Area under the precision-recall curve
- **Accuracy**: Overall accuracy
- **F1**: Harmonic mean of precision and recall

**Regression:**

- **MAE**: Mean absolute error
- **RMSE**: Root mean squared error
- **R²**: Coefficient of determination
- **Pearson**: Pearson correlation

**Ranking (Knowledge Graph):**

- **MR**: Mean rank
- **MRR**: Mean reciprocal rank
- **Hits@K**: Fraction of correct entities ranked in the top K

### Multi-Task Metrics

For multi-label or multi-task settings:

- Metrics are computed per task
- Macro-averaged across tasks
- Tasks can be weighted by importance (see the sketch below)

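A minimal sketch of macro-averaging a per-task metric, assuming `pred` and `target` are `[num_sample, num_task]` tensors; `macro_auroc` and `task_weight` are illustrative names, not TorchDrug API:

```python
import torch
from torchdrug import metrics


def macro_auroc(pred, target, task_weight=None):
    # Compute AUROC independently for each task, then average
    num_task = pred.shape[1]
    scores = torch.stack([metrics.area_under_roc(pred[:, i], target[:, i])
                          for i in range(num_task)])
    if task_weight is None:
        return scores.mean()
    task_weight = torch.as_tensor(task_weight, dtype=scores.dtype)
    return (scores * task_weight).sum() / task_weight.sum()
```
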
## Data Transforms

### Molecule Transforms

```python
from torchdrug import datasets, transforms

# Add a virtual node connected to all atoms
transform1 = transforms.VirtualNode()

# Add virtual edges
transform2 = transforms.VirtualEdge()

# Compose transforms
transform = transforms.Compose([transform1, transform2])

dataset = datasets.BBBP("~/datasets/", transform=transform)
```

### Protein Transforms

```python
from torchdrug import datasets, transforms

# Truncate long proteins to at most 500 residues
transform = transforms.TruncateProtein(max_length=500)

dataset = datasets.Fold("~/datasets/", transform=transform)
```

## Best Practices

### Memory Efficiency

1. **Gradient Accumulation**: For large models (see the sketch below)
2. **Mixed Precision**: FP16 training
3. **Batch Size Tuning**: Balance speed and memory
4. **Data Loading**: Multiple workers for I/O

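A minimal sketch of gradient accumulation, assuming `task`, `optimizer`, and `train_loader` are set up as in the training loop above (the accumulation factor is illustrative):

```python
accumulation = 4  # effective batch size = batch_size * accumulation

optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    loss, metric = task(batch)
    (loss / accumulation).backward()  # scale so the accumulated gradient is an average
    if (step + 1) % accumulation == 0:
        optimizer.step()
        optimizer.zero_grad()
```
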
### Reproducibility

1. **Set Seeds**: PyTorch, NumPy, Python random (see the sketch below)
2. **Deterministic Operations**: `torch.use_deterministic_algorithms(True)`
3. **Save Configurations**: Use `core.Configurable`
4. **Version Control**: Track the TorchDrug version

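A minimal seed-setting helper (the function name is illustrative):

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 0):
    # Seed Python, NumPy and PyTorch (CPU and all GPUs)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```
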
### Debugging

1. **Check Dimensions**: Verify `input_dim` and `output_dim`
2. **Validate Batching**: Print batch statistics
3. **Monitor Gradients**: Watch for vanishing/exploding gradients (see the sketch below)
4. **Overfit a Small Batch**: Confirm the model has enough capacity

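A minimal sketch for monitoring the global gradient norm right after `loss.backward()` (assumes `task` is the module being trained):

```python
import torch

grads = [p.grad.norm() for p in task.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack(grads))
print(f"grad norm: {total_norm:.3e}")  # watch for values collapsing to 0 or blowing up
```
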
### Performance Optimization

1. **GPU Utilization**: Monitor with `nvidia-smi`
2. **Profile Code**: Use the PyTorch profiler
3. **Optimize Data Loading**: Prefetch and pin memory (see the sketch below)
4. **Compile Models**: Use TorchScript if possible

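A minimal data-loading sketch, assuming `train_set` from the training loop above; `data.DataLoader` forwards standard PyTorch loader options (the worker count is illustrative, and pinning may be a no-op for packed graph batches):

```python
from torchdrug import data

train_loader = data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    num_workers=4,     # parallel workers for featurization and collation
    pin_memory=True,   # page-locked host memory for faster GPU transfers
)
```
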
## Advanced Topics

### Multi-Task Learning

Train a single model on multiple related tasks:

```python
from torchdrug import tasks

# Task names and weights are illustrative; the task argument accepts a
# dict mapping task names to loss weights.
task = tasks.PropertyPrediction(
    model,
    task={"task1": 1.0, "task2": 1.0, "task3": 2.0},  # weight task3 twice as much
    criterion="bce",
    metric=["auroc"]
)
```

### Transfer Learning

1. Pre-train on a large dataset
2. Fine-tune on the target dataset
3. Optionally freeze early layers (see the sketch below)

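A minimal fine-tuning sketch, assuming a GIN encoder and a pre-trained checkpoint; the file name `pretrained.pth` is a placeholder:

```python
import torch
from torchdrug import models

model = models.GIN(input_dim=10, hidden_dims=[256, 256, 256])
model.load_state_dict(torch.load("pretrained.pth"))

# Freeze the first message-passing layer and fine-tune the rest
for param in model.layers[0].parameters():
    param.requires_grad = False
```
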
### Self-Supervised Pre-training

Use pre-training tasks:

- `AttributeMasking`: Mask node features
- `EdgePrediction`: Predict edge existence
- `ContextPrediction`: Contrastive learning

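A minimal sketch of attribute-masking pre-training (the mask rate is illustrative):

```python
from torchdrug import models, tasks

model = models.GIN(input_dim=10, hidden_dims=[256, 256, 256])

# Randomly mask atom attributes and train the model to recover them
pretrain_task = tasks.AttributeMasking(model, mask_rate=0.15)
```
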
### Custom Layers

Extend TorchDrug with custom GNN layers:

```python
from torchdrug import layers


class CustomConv(layers.MessagePassingBase):
    def message(self, graph, input):
        # Custom message function
        pass

    def aggregate(self, graph, message):
        # Custom aggregation
        pass

    def combine(self, input, update):
        # Custom combination
        pass
```

## Common Pitfalls

1. **Forgetting `input_dim` and `output_dim`**: Models won't compose
2. **Not Batching Properly**: Use PackedGraph for variable-sized graphs
3. **Data Leakage**: Be careful with scaffold splits and pre-training
4. **Ignoring Edge Features**: Bond and spatial information can be critical
5. **Wrong Evaluation Metrics**: Match the metric to the task (e.g., AUPRC/AUROC for imbalanced classification)
6. **Insufficient Regularization**: Use dropout, weight decay, and early stopping
7. **Not Validating Chemistry**: Generated molecules must be chemically valid
8. **Overfitting Small Datasets**: Use pre-training or simpler models