mirror of https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-01-26 16:58:56 +08:00

Add ESM3 and ESM C protein models

This commit is contained in:

300	scientific-packages/esm/SKILL.md	Normal file

@@ -0,0 +1,300 @@
---
name: esm
description: Comprehensive toolkit for protein language models including ESM3 (generative multimodal protein design across sequence, structure, and function) and ESM C (efficient protein embeddings and representations). Use this skill when working with protein sequences, structures, or function prediction; designing novel proteins; generating protein embeddings; performing inverse folding; or conducting protein engineering tasks. Supports both local model usage and cloud-based Forge API for scalable inference.
---

# ESM: Evolutionary Scale Modeling

## Overview

ESM provides state-of-the-art protein language models for understanding, generating, and designing proteins. This skill covers two model families: ESM3 for generative protein design across sequence, structure, and function, and ESM C for efficient protein representation learning and embeddings.

## Core Capabilities

### 1. Protein Sequence Generation with ESM3

Generate novel protein sequences with desired properties using multimodal generative modeling.

**When to use:**
- Designing proteins with specific functional properties
- Completing partial protein sequences
- Generating variants of existing proteins
- Creating proteins with desired structural characteristics

**Basic usage:**

```python
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# Load model locally
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")

# Create protein prompt
protein = ESMProtein(sequence="MPRT___KEND")  # '_' represents masked positions

# Generate completion
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
print(protein.sequence)
```

**For remote/cloud usage via Forge API:**

```python
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, GenerationConfig

# Connect to Forge
model = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", url="https://forge.evolutionaryscale.ai", token="<token>")

# Generate
protein = ESMProtein(sequence="MPRT___KEND")
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
```

See `references/esm3-api.md` for detailed ESM3 model specifications, advanced generation configurations, and multimodal prompting examples.

### 2. Structure Prediction and Inverse Folding

Use ESM3's structure track for structure prediction from sequence or inverse folding (sequence design from structure).

**Structure prediction:**

```python
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# Predict structure from sequence
protein = ESMProtein(sequence="MPRTKEINDAGLIVHSP...")
protein_with_structure = model.generate(
    protein,
    GenerationConfig(track="structure", num_steps=len(protein.sequence))
)

# Access predicted structure
coordinates = protein_with_structure.coordinates  # 3D coordinates
pdb_string = protein_with_structure.to_pdb()
```

**Inverse folding (sequence from structure):**

```python
# Design sequence for a target structure
protein_with_structure = ESMProtein.from_pdb("target_structure.pdb")
protein_with_structure.sequence = None  # Remove sequence

# Generate sequence that folds to this structure
designed_protein = model.generate(
    protein_with_structure,
    GenerationConfig(track="sequence", num_steps=50, temperature=0.7)
)
```

### 3. Protein Embeddings with ESM C

Generate high-quality embeddings for downstream tasks like function prediction, classification, or similarity analysis.

**When to use:**
- Extracting protein representations for machine learning
- Computing sequence similarities
- Feature extraction for protein classification
- Transfer learning for protein-related tasks

**Basic usage:**

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein

# Load ESM C model
model = ESMC.from_pretrained("esmc-300m").to("cuda")

# Get embeddings
protein = ESMProtein(sequence="MPRTKEINDAGLIVHSP...")
protein_tensor = model.encode(protein)

# Generate embeddings
embeddings = model.forward(protein_tensor)
```

**Batch processing:**

```python
# Encode multiple proteins
proteins = [
    ESMProtein(sequence="MPRTKEIND..."),
    ESMProtein(sequence="AGLIVHSPQ..."),
    ESMProtein(sequence="KTEFLNDGR...")
]

embeddings_list = [model.forward(model.encode(p)) for p in proteins]
```

See `references/esm-c-api.md` for ESM C model details, efficiency comparisons, and advanced embedding strategies.

### 4. Function Conditioning and Annotation

Use ESM3's function track to generate proteins with specific functional annotations or predict function from sequence.

**Function-conditioned generation:**

```python
from esm.sdk.api import ESMProtein, FunctionAnnotation, GenerationConfig

# Create protein with desired function
protein = ESMProtein(
    sequence="_" * 200,  # Generate a 200-residue protein
    function_annotations=[
        FunctionAnnotation(label="fluorescent_protein", start=50, end=150)
    ]
)

# Generate sequence with the specified function
functional_protein = model.generate(
    protein,
    GenerationConfig(track="sequence", num_steps=200)
)
```

### 5. Chain-of-Thought Generation

Iteratively refine protein designs using ESM3's chain-of-thought generation approach.

```python
from esm.sdk.api import ESMProtein, GenerationConfig

# Multi-step refinement
protein = ESMProtein(sequence="MPRT" + "_" * 100 + "KEND")

# Step 1: Generate initial structure
config = GenerationConfig(track="structure", num_steps=50)
protein = model.generate(protein, config)

# Step 2: Refine sequence based on structure
config = GenerationConfig(track="sequence", num_steps=50, temperature=0.5)
protein = model.generate(protein, config)

# Step 3: Predict function
config = GenerationConfig(track="function", num_steps=20)
protein = model.generate(protein, config)
```

### 6. Batch Processing with Forge API

Process multiple proteins efficiently using Forge's async executor.

```python
import asyncio

from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, GenerationConfig

client = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token="<token>")

# Async batch processing
async def batch_generate(proteins_list):
    tasks = [
        client.async_generate(protein, GenerationConfig(track="sequence"))
        for protein in proteins_list
    ]
    return await asyncio.gather(*tasks)

# Execute
proteins = [ESMProtein(sequence=f"MPRT{'_' * 50}KEND") for _ in range(10)]
results = asyncio.run(batch_generate(proteins))
```

See `references/forge-api.md` for detailed Forge API documentation, authentication, rate limits, and batch processing patterns.

## Model Selection Guide

**ESM3 Models (Generative):**
- `esm3-sm-open-v1` (1.4B) - Open weights, local usage, good for experimentation
- `esm3-medium-2024-08` (7B) - Best balance of quality and speed (Forge only)
- `esm3-large-2024-03` (98B) - Highest quality, slower (Forge only)

**ESM C Models (Embeddings):**
- `esmc-300m` (30 layers) - Lightweight, fast inference
- `esmc-600m` (36 layers) - Balanced performance
- `esmc-6b` (80 layers) - Maximum representation quality

**Selection criteria** (see the sketch after this list):
- **Local development/testing:** Use `esm3-sm-open-v1` or `esmc-300m`
- **Production quality:** Use `esm3-medium-2024-08` via Forge
- **Maximum accuracy:** Use `esm3-large-2024-03` or `esmc-6b`
- **High throughput:** Use Forge API with batch executor
- **Cost optimization:** Use smaller models, implement caching strategies
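As a minimal illustration of these criteria, the hypothetical helper below (not part of the ESM SDK) switches between the local open-weights model and a Forge-hosted model; the model IDs and constructors are exactly those used elsewhere in this skill.

```python
from esm.models.esm3 import ESM3
from esm.sdk.forge import ESM3ForgeInferenceClient

def make_client(local=True, token=None):
    """Illustrative helper (hypothetical): pick a client for the task at hand."""
    if local:
        # Open weights: good for development and experimentation
        return ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")
    # Forge-hosted: production quality, requires an API token
    return ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token=token)
```

Either return value is used the same way: both expose the `generate()` method shown in the examples above.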
## Installation

**Basic installation:**

```bash
pip install esm
```

**With Flash Attention (recommended for faster inference):**

```bash
pip install esm
pip install flash-attn --no-build-isolation
```

**For Forge API access:**

```bash
pip install esm  # SDK includes Forge client
```

No additional dependencies are needed. Obtain a Forge API token at https://forge.evolutionaryscale.ai. A quick installation check is sketched below.
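A quick sanity check, assuming only that the `esm` package and the PyTorch it depends on are installed: it confirms the package imports and reports whether the `"cuda"` device used in the examples above is actually available.

```python
import esm  # verifies the package is importable
import torch

# The examples in this skill move models to "cuda";
# fall back to "cpu" when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"esm imported OK; using device: {device}")
```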
## Common Workflows

For detailed examples and complete workflows, see `references/workflows.md`, which includes:
- Novel GFP design with chain-of-thought
- Protein variant generation and screening
- Structure-based sequence optimization
- Function prediction pipelines
- Embedding-based clustering and analysis

## References

This skill includes comprehensive reference documentation:

- `references/esm3-api.md` - ESM3 model architecture, API reference, generation parameters, and multimodal prompting
- `references/esm-c-api.md` - ESM C model details, embedding strategies, and performance optimization
- `references/forge-api.md` - Forge platform documentation, authentication, batch processing, and deployment
- `references/workflows.md` - Complete examples and common workflow patterns

These references contain detailed API specifications, parameter descriptions, and advanced usage patterns. Load them as needed for specific tasks.

## Best Practices

**For generation tasks:**
- Start with smaller models for prototyping (`esm3-sm-open-v1`)
- Use the temperature parameter to control diversity (0.0 = deterministic, 1.0 = diverse)
- Implement iterative refinement with chain-of-thought for complex designs
- Validate generated sequences with structure prediction or wet-lab experiments

**For embedding tasks** (see the normalization sketch after this list):
- Batch process sequences when possible for efficiency
- Cache embeddings for repeated analyses
- Normalize embeddings when computing similarities
- Use an appropriate model size for the downstream task's requirements
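To make the normalization point concrete, here is a minimal sketch in PyTorch. The random tensors stand in for mean-pooled ESM C embeddings (shape `(1, hidden_dim)`, 960 for `esmc-300m`); after L2 normalization, a plain dot product equals the cosine similarity, so scores stay comparable across sequences.

```python
import torch
import torch.nn.functional as F

# Placeholder tensors standing in for mean-pooled ESM C embeddings
emb_a = torch.randn(1, 960)  # hidden_dim of esmc-300m
emb_b = torch.randn(1, 960)

# L2-normalize so the dot product equals the cosine similarity
a = F.normalize(emb_a, dim=-1)
b = F.normalize(emb_b, dim=-1)
similarity = (a * b).sum(dim=-1)
print(f"cosine similarity: {similarity.item():.4f}")
```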
**For production deployment:**
- Use Forge API for scalability and the latest models
- Implement error handling and retry logic for API calls
- Monitor token usage and implement rate limiting
- Consider AWS SageMaker deployment for dedicated infrastructure

## Resources and Documentation

- **GitHub Repository:** https://github.com/evolutionaryscale/esm
- **Forge Platform:** https://forge.evolutionaryscale.ai
- **Scientific Paper:** Hayes et al., Science (2025) - https://www.science.org/doi/10.1126/science.ads0018
- **Blog Posts:**
  - ESM3 Release: https://www.evolutionaryscale.ai/blog/esm3-release
  - ESM C Launch: https://www.evolutionaryscale.ai/blog/esm-cambrian
- **Community:** Slack community at https://bit.ly/3FKwcWd
- **Model Weights:** HuggingFace EvolutionaryScale organization

## Responsible Use

ESM is designed for beneficial applications in protein engineering, drug discovery, and scientific research. Follow the Responsible Biodesign Framework (https://responsiblebiodesign.ai/) when designing novel proteins. Consider the biosafety and ethical implications of protein designs before experimental validation.
583	scientific-packages/esm/references/esm-c-api.md	Normal file

@@ -0,0 +1,583 @@
# ESM C API Reference

## Overview

ESM C (Cambrian) is a family of protein language models optimized for representation learning and efficient embedding generation. Designed as a drop-in replacement for ESM2, ESM C provides significant improvements in speed and quality across all model sizes.

## Model Architecture

**ESM C Family Models:**

| Model ID | Parameters | Layers | Best For |
|----------|-----------|--------|----------|
| `esmc-300m` | 300M | 30 | Fast inference, lightweight applications |
| `esmc-600m` | 600M | 36 | Balanced performance and quality |
| `esmc-6b` | 6B | 80 | Maximum representation quality |

**Key Features:**
- 3x faster inference than ESM2
- Improved perplexity and embedding quality
- Efficient architecture for production deployment
- Compatible with ESM2 workflows (drop-in replacement)
- Support for long sequences (up to 1024 residues efficiently)

**Architecture Improvements over ESM2:**
- Optimized attention mechanisms
- Better token representation
- Enhanced training procedures
- Reduced memory footprint

## Core API Components

### ESMC Class

Main interface for ESM C models.

**Model Loading:**

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein

# Load model with automatic device placement
model = ESMC.from_pretrained("esmc-300m").to("cuda")

# Or specify device explicitly
model = ESMC.from_pretrained("esmc-600m").to("cpu")

# For maximum quality
model = ESMC.from_pretrained("esmc-6b").to("cuda")
```

**Model Selection Criteria:**

- **esmc-300m**: Development, real-time applications, batch processing of many sequences
- **esmc-600m**: Production deployments, good quality/speed balance
- **esmc-6b**: Research, maximum accuracy for downstream tasks

### Basic Embedding Generation

**Single Sequence:**

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein

# Load model
model = ESMC.from_pretrained("esmc-600m").to("cuda")

# Create protein
protein = ESMProtein(sequence="MPRTKEINDAGLIVHSPQWFYK")

# Encode to tensor
protein_tensor = model.encode(protein)

# Generate embeddings
embeddings = model.forward(protein_tensor)

# Get logits (per-position predictions)
logits = model.logits(embeddings)

print(f"Embedding shape: {embeddings.shape}")
print(f"Logits shape: {logits.shape}")
```
**Output Shapes:**

For a sequence of length L:
- `embeddings.shape`: `(1, L, hidden_dim)`, where hidden_dim depends on the model
  - esmc-300m: hidden_dim = 960
  - esmc-600m: hidden_dim = 1152
  - esmc-6b: hidden_dim = 2560
- `logits.shape`: `(1, L, 64)` - per-position amino acid predictions

### Batch Processing

Process multiple sequences efficiently:

```python
import torch

# Multiple proteins
sequences = [
    "MPRTKEINDAGLIVHSP",
    "AGKWFYLTQSNHERVPM",
    "DEIFKRNAVWGSLTPQY"
]

proteins = [ESMProtein(sequence=seq) for seq in sequences]

# Encode all
protein_tensors = [model.encode(p) for p in proteins]

# Process batch (if same length)
# For variable lengths, process individually or pad
embeddings_list = []
for tensor in protein_tensors:
    embedding = model.forward(tensor)
    embeddings_list.append(embedding)

print(f"Processed {len(embeddings_list)} proteins")
```

**Efficient Batching for Variable Lengths:**

```python
def batch_encode_variable_length(model, sequences, max_batch_size=32):
    """
    Encode sequences of variable length, grouping consecutive
    similar-length sequences so each group is processed together.
    """
    # Sort by length so neighbors have similar lengths
    sorted_seqs = sorted(enumerate(sequences), key=lambda x: len(x[1]))

    results = [None] * len(sequences)
    batch = []
    batch_indices = []

    def flush():
        """Encode the current group and store results at the original indices."""
        proteins = [ESMProtein(sequence=s) for s in batch]
        embeddings = [model.forward(model.encode(p)) for p in proteins]
        for i, emb in zip(batch_indices, embeddings):
            results[i] = emb
        batch.clear()
        batch_indices.clear()

    for idx, seq in sorted_seqs:
        # Start a new group when the batch is full or the length jumps
        if batch and (len(batch) >= max_batch_size or
                      len(seq) - len(batch[0]) > 10):
            flush()
        batch.append(seq)
        batch_indices.append(idx)

    # Process remaining
    if batch:
        flush()

    return results
```
## Common Use Cases

### 1. Sequence Similarity Analysis

Compute similarity between proteins using embeddings:

```python
import torch
import torch.nn.functional as F

def get_sequence_embedding(model, sequence):
    """Get a mean-pooled sequence embedding."""
    protein = ESMProtein(sequence=sequence)
    tensor = model.encode(protein)
    embedding = model.forward(tensor)

    # Mean pooling over sequence length
    return embedding.mean(dim=1)

# Get embeddings
seq1_emb = get_sequence_embedding(model, "MPRTKEINDAGLIVHSP")
seq2_emb = get_sequence_embedding(model, "MPRTKEINDAGLIVHSQ")  # Similar
seq3_emb = get_sequence_embedding(model, "WWWWWWWWWWWWWWWWW")  # Different

# Compute cosine similarity
sim_1_2 = F.cosine_similarity(seq1_emb, seq2_emb)
sim_1_3 = F.cosine_similarity(seq1_emb, seq3_emb)

print(f"Similarity (1,2): {sim_1_2.item():.4f}")
print(f"Similarity (1,3): {sim_1_3.item():.4f}")
```

### 2. Protein Classification

Use embeddings as features for classification:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Generate embeddings for the training set
def embed_dataset(model, sequences):
    embeddings = []
    for seq in sequences:
        protein = ESMProtein(sequence=seq)
        tensor = model.encode(protein)
        emb = model.forward(tensor).mean(dim=1)  # Mean pooling
        embeddings.append(emb.detach().cpu().numpy().flatten())
    return np.array(embeddings)

# Example: Classify proteins by function
train_sequences = [...]  # Your sequences
train_labels = [...]  # Your labels

embeddings = embed_dataset(model, train_sequences)

# Train classifier
X_train, X_test, y_train, y_test = train_test_split(
    embeddings, train_labels, test_size=0.2
)

classifier = LogisticRegression(max_iter=1000)
classifier.fit(X_train, y_train)

# Evaluate
accuracy = classifier.score(X_test, y_test)
print(f"Classification accuracy: {accuracy:.4f}")
```

### 3. Protein Clustering

Cluster proteins based on sequence similarity:

```python
from sklearn.cluster import KMeans
import numpy as np

# Generate embeddings
sequences = [...]  # Your protein sequences
embeddings = embed_dataset(model, sequences)

# Cluster
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(embeddings)

# Analyze clusters
for i in range(n_clusters):
    cluster_seqs = [seq for seq, label in zip(sequences, cluster_labels) if label == i]
    print(f"Cluster {i}: {len(cluster_seqs)} sequences")
```

### 4. Sequence Search and Retrieval

Find similar sequences in a database:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def build_sequence_index(model, database_sequences):
    """Build a searchable index of sequence embeddings."""
    embeddings = []
    for seq in database_sequences:
        emb = get_sequence_embedding(model, seq)
        embeddings.append(emb.detach().cpu().numpy().flatten())
    return np.array(embeddings)

def search_similar_sequences(model, query_seq, database_embeddings,
                             database_sequences, top_k=10):
    """Find the top-k most similar sequences."""
    query_emb = get_sequence_embedding(model, query_seq)
    query_emb_np = query_emb.detach().cpu().numpy().flatten().reshape(1, -1)

    # Compute similarities
    similarities = cosine_similarity(query_emb_np, database_embeddings)[0]

    # Get top-k
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    results = [
        (database_sequences[idx], similarities[idx])
        for idx in top_indices
    ]
    return results

# Example usage
database_seqs = [...]  # Large sequence database
index = build_sequence_index(model, database_seqs)

query = "MPRTKEINDAGLIVHSP"
similar = search_similar_sequences(model, query, index, database_seqs, top_k=5)

for seq, score in similar:
    print(f"Score: {score:.4f} - {seq[:30]}...")
```

### 5. Feature Extraction for Downstream Models

Use ESM C embeddings as input to custom neural networks:

```python
import torch
import torch.nn as nn

class ProteinPropertyPredictor(nn.Module):
    """Example: Predict protein properties from ESM C embeddings."""

    def __init__(self, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, embeddings):
        # embeddings: (batch, seq_len, embedding_dim)
        # Mean pool over the sequence
        x = embeddings.mean(dim=1)

        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Use ESM C as a frozen feature extractor
esm_model = ESMC.from_pretrained("esmc-600m").to("cuda")
esm_model.eval()  # Freeze

# Create the task-specific model
predictor = ProteinPropertyPredictor(
    embedding_dim=1152,  # esmc-600m dimension
    hidden_dim=512,
    output_dim=1  # e.g., stability score
).to("cuda")

# Training loop (dataloader and criterion are assumed to be defined)
for sequence, target in dataloader:
    protein = ESMProtein(sequence=sequence)
    with torch.no_grad():
        embeddings = esm_model.forward(esm_model.encode(protein))

    prediction = predictor(embeddings)
    loss = criterion(prediction, target)
    # ... backprop through the predictor only
```
### 6. Per-Residue Analysis

Extract per-residue representations for detailed analysis:

```python
def get_per_residue_embeddings(model, sequence):
    """Get an embedding for each residue."""
    protein = ESMProtein(sequence=sequence)
    tensor = model.encode(protein)
    embeddings = model.forward(tensor)

    # embeddings shape: (1, seq_len, hidden_dim)
    return embeddings.squeeze(0)  # (seq_len, hidden_dim)

# Analyze specific positions
sequence = "MPRTKEINDAGLIVHSPQWFYK"
residue_embeddings = get_per_residue_embeddings(model, sequence)

# Extract features for position 10
position_10_features = residue_embeddings[10]
print(f"Features for residue {sequence[10]} at position 10:")
print(f"Shape: {position_10_features.shape}")

# Compare residue representations
pos_5 = residue_embeddings[5]
pos_15 = residue_embeddings[15]
similarity = F.cosine_similarity(pos_5, pos_15, dim=0)
print(f"Residue similarity: {similarity.item():.4f}")
```

## Performance Optimization

### Memory Management

```python
import torch

# Use half precision for memory efficiency
model = ESMC.from_pretrained("esmc-600m").to("cuda").half()

# Process with mixed precision
with torch.cuda.amp.autocast():
    embeddings = model.forward(model.encode(protein))

# Clear cache between batches
torch.cuda.empty_cache()
```

### Batch Processing Best Practices

```python
def efficient_batch_processing(model, sequences, batch_size=32):
    """Process sequences in optimized batches."""
    results = []

    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]

        # Process batch
        batch_embeddings = []
        for seq in batch:
            protein = ESMProtein(sequence=seq)
            emb = model.forward(model.encode(protein))
            batch_embeddings.append(emb)

        results.extend(batch_embeddings)

        # Periodically clear the CUDA cache
        if i % (batch_size * 10) == 0:
            torch.cuda.empty_cache()

    return results
```

### Caching Embeddings

```python
import pickle
import hashlib

def get_cache_key(sequence):
    """Generate a cache key for a sequence."""
    return hashlib.md5(sequence.encode()).hexdigest()

class EmbeddingCache:
    """Cache for protein embeddings."""

    def __init__(self, cache_file="embeddings_cache.pkl"):
        self.cache_file = cache_file
        try:
            with open(cache_file, 'rb') as f:
                self.cache = pickle.load(f)
        except FileNotFoundError:
            self.cache = {}

    def get(self, sequence):
        key = get_cache_key(sequence)
        return self.cache.get(key)

    def set(self, sequence, embedding):
        key = get_cache_key(sequence)
        self.cache[key] = embedding

    def save(self):
        with open(self.cache_file, 'wb') as f:
            pickle.dump(self.cache, f)

# Usage
cache = EmbeddingCache()

def get_embedding_cached(model, sequence):
    cached = cache.get(sequence)
    if cached is not None:
        return cached

    # Compute
    protein = ESMProtein(sequence=sequence)
    embedding = model.forward(model.encode(protein))
    cache.set(sequence, embedding)

    return embedding

# Don't forget to save the cache
cache.save()
```
## Comparison with ESM2

**Performance Improvements:**

| Metric | ESM2-650M | ESM C-600M | Improvement |
|--------|-----------|------------|-------------|
| Inference Speed | 1.0x | 3.0x | 3x faster |
| Perplexity | Higher | Lower | Better |
| Memory Usage | 1.0x | 0.8x | 20% less |
| Embedding Quality | Baseline | Improved | +5-10% |

**Migration from ESM2:**

ESM C is designed as a drop-in replacement:

```python
# Old ESM2 code
from esm import pretrained
model, alphabet = pretrained.esm2_t33_650M_UR50D()

# New ESM C code (similar API)
from esm.models.esmc import ESMC
model = ESMC.from_pretrained("esmc-600m")
```

Key differences:
- Faster inference with the same or better quality
- Simplified API through ESMProtein
- Better support for long sequences
- More efficient memory usage

## Advanced Topics

### Fine-tuning ESM C

ESM C can be fine-tuned for specific tasks:

```python
import torch.optim as optim

# Load model
model = ESMC.from_pretrained("esmc-300m").to("cuda")

# Unfreeze for fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Training loop (num_epochs, dataloader, and compute_loss are assumed defined)
for epoch in range(num_epochs):
    for sequences, labels in dataloader:
        optimizer.zero_grad()

        # Forward pass
        proteins = [ESMProtein(sequence=seq) for seq in sequences]
        embeddings = [model.forward(model.encode(p)) for p in proteins]

        # Your task-specific loss
        loss = compute_loss(embeddings, labels)

        loss.backward()
        optimizer.step()
```

### Attention Visualization

Extract attention weights for interpretability:

```python
def get_attention_weights(model, sequence):
    """Extract attention weights from the model."""
    protein = ESMProtein(sequence=sequence)
    tensor = model.encode(protein)

    # Forward with attention output
    output = model.forward(tensor, output_attentions=True)

    return output.attentions  # List of attention tensors per layer

# Visualize attention
attentions = get_attention_weights(model, "MPRTKEINDAGLIVHSP")
# Process and visualize attention patterns
```

## Citation

If using ESM C in research, cite:

```
ESM Cambrian: https://www.evolutionaryscale.ai/blog/esm-cambrian
EvolutionaryScale (2024)
```

## Additional Resources

- ESM C blog post: https://www.evolutionaryscale.ai/blog/esm-cambrian
- Model weights: HuggingFace EvolutionaryScale organization
- Comparison benchmarks: See the blog post for detailed performance comparisons
452	scientific-packages/esm/references/esm3-api.md	Normal file

@@ -0,0 +1,452 @@
# ESM3 API Reference

## Overview

ESM3 is a frontier multimodal generative language model that reasons over the sequence, structure, and function of proteins. It uses iterative masked language modeling to generate across these three modalities simultaneously.

## Model Architecture

**ESM3 Family Models:**

| Model ID | Parameters | Availability | Best For |
|----------|-----------|--------------|----------|
| `esm3-sm-open-v1` | 1.4B | Open weights (local) | Development, testing, learning |
| `esm3-medium-2024-08` | 7B | Forge API only | Production, balanced quality/speed |
| `esm3-large-2024-03` | 98B | Forge API only | Maximum quality, research |
| `esm3-medium-multimer-2024-09` | 7B | Forge API only | Protein complexes (experimental) |

**Key Features:**
- Simultaneous reasoning across sequence, structure, and function
- Iterative generation with a controllable number of steps
- Support for partial prompting across modalities
- Chain-of-thought generation for complex designs
- Temperature control for generation diversity

## Core API Components

### ESMProtein Class

The central data structure representing a protein with optional sequence, structure, and function information.

**Constructor:**

```python
from esm.sdk.api import ESMProtein

protein = ESMProtein(
    sequence="MPRTKEINDAGLIVHSP",  # Amino acid sequence (optional)
    coordinates=coordinates_array,  # 3D structure (optional)
    function_annotations=[...],  # Function labels (optional)
    secondary_structure="HHHEEEECCC",  # SS annotations (optional)
    sasa=sasa_array  # Solvent accessibility (optional)
)
```

**Key Methods:**

```python
# Load from a PDB file
protein = ESMProtein.from_pdb("protein.pdb")

# Export to PDB format
pdb_string = protein.to_pdb()

# Save to file
with open("output.pdb", "w") as f:
    f.write(protein.to_pdb())
```

**Masking Conventions:**

Use `_` (underscore) to represent masked positions for generation:

```python
# Mask positions 5-10 for generation
protein = ESMProtein(sequence="MPRT______AGLIVHSP")

# Fully masked sequence (generate from scratch)
protein = ESMProtein(sequence="_" * 200)

# Partial structure (some coordinates None)
protein = ESMProtein(
    sequence="MPRTKEIND",
    coordinates=partial_coords  # Some positions can be None
)
```
### GenerationConfig Class

Controls generation behavior and parameters.

**Basic Configuration:**

```python
from esm.sdk.api import GenerationConfig

config = GenerationConfig(
    track="sequence",  # Track to generate: "sequence", "structure", or "function"
    num_steps=8,  # Number of demasking steps
    temperature=0.7,  # Sampling temperature (0.0-1.0)
    top_p=None,  # Nucleus sampling threshold
    condition_on_coordinates_only=False  # For structure conditioning
)
```

**Parameter Details** (see the temperature sketch after this list):

- **track**: Which modality to generate
  - `"sequence"`: Generate amino acid sequence
  - `"structure"`: Generate 3D coordinates
  - `"function"`: Generate function annotations

- **num_steps**: Number of iterative demasking steps
  - Higher = slower but potentially better quality
  - Typical range: 8-100, depending on sequence length
  - For full sequence generation: approximately sequence_length / 2

- **temperature**: Controls randomness
  - 0.0: Fully deterministic (greedy decoding)
  - 0.5-0.7: Balanced exploration
  - 1.0: Maximum diversity
  - Higher values increase novelty but may reduce quality

- **top_p**: Nucleus sampling parameter
  - Limits sampling to the top probability mass
  - Values: 0.0-1.0 (e.g., 0.9 = sample from the top 90% probability mass)
  - Use for controlled diversity without extreme sampling

- **condition_on_coordinates_only**: Structure conditioning mode
  - `True`: Condition only on backbone coordinates (ignore sequence)
  - Useful for inverse folding tasks
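A minimal sketch of the temperature trade-off, assuming `model` is a loaded `ESM3InferenceClient` as in the other examples in this reference: the same masked prompt is completed once greedily and three times at a higher temperature, so the greedy run is repeatable while the sampled runs differ.

```python
from esm.sdk.api import ESMProtein, GenerationConfig

prompt = ESMProtein(sequence="MPRT" + "_" * 20 + "KEND")

# temperature=0.0: greedy decoding, repeatable output
greedy = model.generate(
    prompt, GenerationConfig(track="sequence", num_steps=10, temperature=0.0)
)

# temperature=0.8: diverse completions of the same prompt
samples = [
    model.generate(
        prompt, GenerationConfig(track="sequence", num_steps=10, temperature=0.8)
    )
    for _ in range(3)
]

print("greedy:", greedy.sequence)
for i, sample in enumerate(samples):
    print(f"sample {i}:", sample.sequence)
```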
### ESM3InferenceClient Interface

The unified interface for both local and remote inference.

**Local Model Loading:**

```python
from esm.models.esm3 import ESM3

# Load with automatic device placement
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")

# Or explicitly specify the device
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cpu")
```

**Generation Method:**

```python
# Basic generation
protein_output = model.generate(protein_input, config)

# With explicit track specification
protein_output = model.generate(
    protein_input,
    GenerationConfig(track="sequence", num_steps=16, temperature=0.6)
)
```

**Forward Pass (Advanced):**

```python
# Get raw model logits for custom sampling
protein_tensor = model.encode(protein)
output = model.forward(protein_tensor)
logits = model.decode(output)
```

## Common Usage Patterns

### 1. Sequence Completion

Fill in masked regions of a protein sequence:

```python
# Define a partial sequence
protein = ESMProtein(sequence="MPRTK____LIVHSP____END")

# Generate the missing positions
config = GenerationConfig(track="sequence", num_steps=12, temperature=0.5)
completed = model.generate(protein, config)

print(f"Original: {protein.sequence}")
print(f"Completed: {completed.sequence}")
```

### 2. Structure Prediction

Predict 3D structure from sequence:

```python
# Input: sequence only
protein = ESMProtein(sequence="MPRTKEINDAGLIVHSPQWFYK")

# Generate structure
config = GenerationConfig(track="structure", num_steps=len(protein.sequence))
protein_with_structure = model.generate(protein, config)

# Save as PDB
with open("predicted_structure.pdb", "w") as f:
    f.write(protein_with_structure.to_pdb())
```

### 3. Inverse Folding

Design a sequence for a target structure:

```python
# Load target structure
target = ESMProtein.from_pdb("target.pdb")

# Remove the sequence, keep the structure
target.sequence = None

# Generate a sequence that folds to this structure
config = GenerationConfig(
    track="sequence",
    num_steps=50,
    temperature=0.7,
    condition_on_coordinates_only=True
)
designed = model.generate(target, config)

print(f"Designed sequence: {designed.sequence}")
```

### 4. Function-Conditioned Generation

Generate a protein with a specific function:

```python
from esm.sdk.api import FunctionAnnotation

# Specify the desired function
protein = ESMProtein(
    sequence="_" * 150,
    function_annotations=[
        FunctionAnnotation(
            label="enzymatic_activity",
            start=30,
            end=90
        )
    ]
)

# Generate a sequence with this function
config = GenerationConfig(track="sequence", num_steps=75, temperature=0.6)
functional_protein = model.generate(protein, config)
```

### 5. Multi-Track Generation (Chain-of-Thought)

Iteratively generate across multiple tracks:

```python
# Start with a partial sequence
protein = ESMProtein(sequence="MPRT" + "_" * 100)

# Step 1: Complete the sequence
protein = model.generate(
    protein,
    GenerationConfig(track="sequence", num_steps=50, temperature=0.6)
)

# Step 2: Predict structure for the completed sequence
protein = model.generate(
    protein,
    GenerationConfig(track="structure", num_steps=50)
)

# Step 3: Predict function
protein = model.generate(
    protein,
    GenerationConfig(track="function", num_steps=20)
)

print(f"Final sequence: {protein.sequence}")
print(f"Functions: {protein.function_annotations}")
```

### 6. Variant Generation

Generate multiple variants of a protein:

```python
import numpy as np

base_sequence = "MPRTKEINDAGLIVHSPQWFYK"
variants = []

for i in range(10):
    # Mask random positions
    seq_list = list(base_sequence)
    mask_indices = np.random.choice(len(seq_list), size=5, replace=False)
    for idx in mask_indices:
        seq_list[idx] = '_'

    protein = ESMProtein(sequence=''.join(seq_list))

    # Generate a variant
    variant = model.generate(
        protein,
        GenerationConfig(track="sequence", num_steps=8, temperature=0.8)
    )
    variants.append(variant.sequence)

print(f"Generated {len(variants)} variants")
```
## Advanced Topics

### Temperature Scheduling

Vary the temperature during generation for better control:

```python
def generate_with_temperature_schedule(model, protein, temperatures):
    """Generate with decreasing temperature for annealing."""
    current = protein
    steps_per_temp = 10

    for temp in temperatures:
        config = GenerationConfig(
            track="sequence",
            num_steps=steps_per_temp,
            temperature=temp
        )
        current = model.generate(current, config)

    return current

# Example: Start diverse, end deterministic
result = generate_with_temperature_schedule(
    model,
    protein,
    temperatures=[1.0, 0.8, 0.6, 0.4, 0.2]
)
```

### Constrained Generation

Preserve specific regions during generation:

```python
# Keep active-site residues fixed
def mask_except_active_site(sequence, active_site_positions):
    """Mask everything except the specified positions."""
    seq_list = ['_'] * len(sequence)
    for pos in active_site_positions:
        seq_list[pos] = sequence[pos]
    return ''.join(seq_list)

# Define the active site (original_sequence is your starting sequence)
active_site = [23, 24, 25, 45, 46, 89]
constrained_seq = mask_except_active_site(original_sequence, active_site)

protein = ESMProtein(sequence=constrained_seq)
result = model.generate(protein, GenerationConfig(track="sequence", num_steps=50))
```

### Secondary Structure Conditioning

Use secondary structure information in generation:

```python
# Define secondary structure (H=helix, E=sheet, C=coil)
protein = ESMProtein(
    sequence="_" * 80,
    secondary_structure="CCHHHHHHHEEEEECCCHHHHHHCC" + "C" * 55
)

# Generate a sequence with this structure
result = model.generate(
    protein,
    GenerationConfig(track="sequence", num_steps=40, temperature=0.6)
)
```

## Performance Optimization

### Memory Management

For large proteins or batch processing:

```python
import torch

# Clear the CUDA cache between generations
torch.cuda.empty_cache()

# Use half precision for memory efficiency
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda").half()

# Process in chunks for very long sequences
def chunk_generate(model, long_sequence, chunk_size=500):
    chunks = [long_sequence[i:i+chunk_size]
              for i in range(0, len(long_sequence), chunk_size)]
    results = []

    for chunk in chunks:
        protein = ESMProtein(sequence=chunk)
        result = model.generate(protein, GenerationConfig(track="sequence"))
        results.append(result.sequence)

    return ''.join(results)
```

### Batch Processing Tips

When processing multiple proteins:

1. Sort by sequence length for efficient batching
2. Use padding for similar-length sequences
3. Process on GPU when available
4. Implement checkpointing for long-running jobs (see the sketch after this list)
5. Use the Forge API for large-scale processing (see `forge-api.md`)
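A minimal sketch of tip 4, assuming `model` and the input sequences are as in the examples above (the file name and JSONL layout are illustrative): completed results are appended to disk as they finish, so a restarted job skips inputs it has already processed.

```python
import json
import os

from esm.sdk.api import ESMProtein, GenerationConfig

def generate_with_checkpoints(model, sequences, path="generation_checkpoint.jsonl"):
    """Resume-safe generation: appends one JSON line per completed input."""
    done = set()
    if os.path.exists(path):
        with open(path) as f:
            done = {json.loads(line)["input"] for line in f}

    with open(path, "a") as f:
        for seq in sequences:
            if seq in done:
                continue  # already completed in a previous run
            result = model.generate(
                ESMProtein(sequence=seq),
                GenerationConfig(track="sequence", num_steps=8),
            )
            f.write(json.dumps({"input": seq, "output": result.sequence}) + "\n")
            f.flush()  # persist immediately so a crash loses at most one item
```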
## Error Handling

```python
try:
    protein = model.generate(protein_input, config)
except ValueError as e:
    print(f"Invalid input: {e}")
    # Handle invalid sequence or structure
except RuntimeError as e:
    print(f"Generation failed: {e}")
    # Handle model errors
except torch.cuda.OutOfMemoryError:
    print("GPU out of memory - try a smaller model or CPU")
    # Fall back to CPU or a smaller model
```

## Model-Specific Considerations

**esm3-sm-open-v1:**
- Suitable for development and testing
- Lower quality than the larger models
- Fast inference on consumer GPUs
- Open weights allow fine-tuning

**esm3-medium-2024-08:**
- Production quality
- Good balance of speed and accuracy
- Requires Forge API access
- Recommended for most applications

**esm3-large-2024-03:**
- State-of-the-art quality
- Slowest inference
- Use for critical applications
- Best for novel protein design

## Citation

If using ESM3 in research, cite:

```
Hayes, T. et al. (2025). Simulating 500 million years of evolution with a language model.
Science. DOI: 10.1126/science.ads0018
```
657	scientific-packages/esm/references/forge-api.md	Normal file

@@ -0,0 +1,657 @@
# Forge API Reference

## Overview

Forge is EvolutionaryScale's cloud platform for scalable protein design and inference. It provides API access to the full ESM3 model family, including large models not available for local execution.

**Key Benefits:**
- Access to all ESM3 models, including the 98B-parameter version
- No local GPU requirements
- Scalable batch processing
- Automatic updates to the latest models
- Production-ready infrastructure
- Async/concurrent request support

## Getting Started

### 1. Obtain an API Token

Sign up and get your API token at: https://forge.evolutionaryscale.ai

### 2. Install the ESM SDK

```bash
pip install esm
```

The Forge client is included in the standard ESM package.

### 3. Basic Connection

```python
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, GenerationConfig

# Initialize the client
client = ESM3ForgeInferenceClient(
    model="esm3-medium-2024-08",
    url="https://forge.evolutionaryscale.ai",
    token="<your-token-here>"
)

# Test the connection
protein = ESMProtein(sequence="MPRT___KEND")
result = client.generate(protein, GenerationConfig(track="sequence", num_steps=8))
print(result.sequence)
```

## Available Models

| Model ID | Parameters | Speed | Quality | Use Case |
|----------|-----------|-------|---------|----------|
| `esm3-small-2024-08` | 1.4B | Fastest | Good | Rapid prototyping, testing |
| `esm3-medium-2024-08` | 7B | Fast | Excellent | Production, most applications |
| `esm3-large-2024-03` | 98B | Slower | Best | Research, critical designs |
| `esm3-medium-multimer-2024-09` | 7B | Fast | Experimental | Protein complexes |

**Model Selection Guidelines:**

- **Development/Testing**: Use `esm3-small-2024-08` for quick iteration
- **Production**: Use `esm3-medium-2024-08` for the best balance
- **Research/Critical**: Use `esm3-large-2024-03` for the highest quality
- **Complexes**: Use `esm3-medium-multimer-2024-09` (experimental)

## ESM3ForgeInferenceClient API

### Initialization

```python
from esm.sdk.forge import ESM3ForgeInferenceClient

# Basic initialization
client = ESM3ForgeInferenceClient(
    model="esm3-medium-2024-08",
    token="<your-token>"
)

# With a custom URL (for enterprise deployments)
client = ESM3ForgeInferenceClient(
    model="esm3-medium-2024-08",
    url="https://custom.forge.instance.com",
    token="<your-token>"
)

# With timeout configuration
client = ESM3ForgeInferenceClient(
    model="esm3-medium-2024-08",
    token="<your-token>",
    timeout=300  # 5 minutes
)
```
### Synchronous Generation
|
||||||
|
|
||||||
|
Standard blocking generation calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from esm.sdk.api import ESMProtein, GenerationConfig
|
||||||
|
|
||||||
|
# Basic generation
|
||||||
|
protein = ESMProtein(sequence="MPRT___KEND")
|
||||||
|
config = GenerationConfig(track="sequence", num_steps=8)
|
||||||
|
|
||||||
|
result = client.generate(protein, config)
|
||||||
|
print(f"Generated: {result.sequence}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Asynchronous Generation
|
||||||
|
|
||||||
|
For concurrent processing of multiple proteins:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
from esm.sdk.api import ESMProtein, GenerationConfig
|
||||||
|
|
||||||
|
async def generate_many(client, proteins):
|
||||||
|
"""Generate multiple proteins concurrently."""
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for protein in proteins:
|
||||||
|
task = client.async_generate(
|
||||||
|
protein,
|
||||||
|
GenerationConfig(track="sequence", num_steps=8)
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
proteins = [
|
||||||
|
ESMProtein(sequence=f"MPRT{'_' * 10}KEND"),
|
||||||
|
ESMProtein(sequence=f"AGLV{'_' * 10}HSPQ"),
|
||||||
|
ESMProtein(sequence=f"KEIT{'_' * 10}NDFL")
|
||||||
|
]
|
||||||
|
|
||||||
|
results = asyncio.run(generate_many(client, proteins))
|
||||||
|
print(f"Generated {len(results)} proteins")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Batch Processing with BatchExecutor

For large-scale processing with automatic concurrency management:

```python
from esm.sdk.forge import BatchExecutor
from esm.sdk.api import ESMProtein, GenerationConfig

# Create batch executor
executor = BatchExecutor(
    client=client,
    max_concurrent=10  # Process 10 requests concurrently
)

# Prepare batch of proteins
proteins = [ESMProtein(sequence=f"MPRT{'_' * 50}KEND") for _ in range(100)]
config = GenerationConfig(track="sequence", num_steps=25)

# Submit batch
batch_results = executor.submit_batch(
    proteins=proteins,
    config=config,
    progress_callback=lambda i, total: print(f"Processed {i}/{total}")
)

print(f"Completed {len(batch_results)} generations")
```

## Rate Limiting and Quotas

### Understanding Limits

Forge implements rate limiting based on:

- Requests per minute (RPM)
- Tokens per minute (TPM)
- Concurrent requests

**Typical Limits (subject to change):**

- Free tier: 60 RPM, 5 concurrent
- Pro tier: 300 RPM, 20 concurrent
- Enterprise: Custom limits

The concurrency cap can also be enforced client-side, as sketched below.

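The following is a minimal sketch of client-side concurrency limiting, wrapping the `client.async_generate` pattern shown earlier in an `asyncio.Semaphore`; the limit of 5 matches the free tier figure quoted above and is otherwise an arbitrary choice.

```python
import asyncio

from esm.sdk.api import ESMProtein, GenerationConfig

# Cap in-flight requests at the tier's concurrency limit (5 assumed here)
semaphore = asyncio.Semaphore(5)


async def bounded_generate(client, protein, config):
    """Run one generation while holding a semaphore slot."""
    async with semaphore:
        return await client.async_generate(protein, config)


async def generate_all(client, proteins):
    config = GenerationConfig(track="sequence", num_steps=8)
    return await asyncio.gather(
        *(bounded_generate(client, p, config) for p in proteins)
    )
```
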
### Handling Rate Limits

```python
import time

from requests.exceptions import HTTPError


def generate_with_retry(client, protein, config, max_retries=3):
    """Generate with automatic retry on rate limit."""
    for attempt in range(max_retries):
        try:
            return client.generate(protein, config)
        except HTTPError as e:
            if e.response.status_code == 429:  # Rate limit
                wait_time = 2 ** attempt  # Exponential backoff
                print(f"Rate limited, waiting {wait_time}s...")
                time.sleep(wait_time)
            else:
                raise
    raise Exception("Max retries exceeded")


# Usage
result = generate_with_retry(client, protein, config)
```

### Implementing a Custom Rate Limiter

```python
import time
from collections import deque


class RateLimiter:
    """Simple sliding-window rate limiter for API calls."""

    def __init__(self, max_per_minute=60):
        self.max_per_minute = max_per_minute
        self.calls = deque()

    def wait_if_needed(self):
        """Block until a call can be made without exceeding the limit."""
        now = time.time()

        # Drop calls older than the 60-second window
        while self.calls and self.calls[0] < now - 60:
            self.calls.popleft()

        # If at the limit, wait for the oldest call to age out
        if len(self.calls) >= self.max_per_minute:
            sleep_time = 60 - (now - self.calls[0])
            if sleep_time > 0:
                time.sleep(sleep_time)
            self.calls.popleft()

        self.calls.append(now)


# Usage
limiter = RateLimiter(max_per_minute=60)

for protein in proteins:
    limiter.wait_if_needed()
    result = client.generate(protein, config)
```

## Advanced Patterns

### Streaming Results

Process results as they complete:

```python
import asyncio


async def stream_generate(client, proteins, config):
    """Yield (index, result) pairs as generations complete."""
    # Map each task back to the index of its input protein
    task_to_idx = {
        asyncio.create_task(client.async_generate(p, config)): i
        for i, p in enumerate(proteins)
    }
    pending = set(task_to_idx)

    results = [None] * len(proteins)

    while pending:
        done, pending = await asyncio.wait(
            pending,
            return_when=asyncio.FIRST_COMPLETED
        )

        for task in done:
            idx = task_to_idx[task]
            result = await task
            results[idx] = result
            yield idx, result


# Usage
async def process_stream():
    async for idx, result in stream_generate(client, proteins, config):
        print(f"Completed protein {idx}: {result.sequence[:20]}...")

asyncio.run(process_stream())
```

### Batch with Progress Tracking

```python
import asyncio

from tqdm import tqdm


async def batch_with_progress(client, proteins, config):
    """Process a batch sequentially with a progress bar."""
    results = []

    with tqdm(total=len(proteins)) as pbar:
        for protein in proteins:
            result = await client.async_generate(protein, config)
            results.append(result)
            pbar.update(1)

    return results


# Usage
results = asyncio.run(batch_with_progress(client, proteins, config))
```

### Checkpoint and Resume

For long-running batch jobs:

```python
import os
import pickle


class CheckpointedBatchProcessor:
    """Batch processor with checkpoint/resume capability."""

    def __init__(self, client, checkpoint_file="checkpoint.pkl"):
        self.client = client
        self.checkpoint_file = checkpoint_file
        self.completed = self.load_checkpoint()

    def load_checkpoint(self):
        if os.path.exists(self.checkpoint_file):
            with open(self.checkpoint_file, 'rb') as f:
                return pickle.load(f)
        return {}

    def save_checkpoint(self):
        with open(self.checkpoint_file, 'wb') as f:
            pickle.dump(self.completed, f)

    def process_batch(self, proteins, config):
        """Process a batch with checkpointing."""
        results = {}

        for i, protein in enumerate(proteins):
            # Skip if already completed
            if i in self.completed:
                results[i] = self.completed[i]
                continue

            try:
                result = self.client.generate(protein, config)
                results[i] = result
                self.completed[i] = result

                # Save a checkpoint every 10 items
                if i % 10 == 0:
                    self.save_checkpoint()

            except Exception as e:
                print(f"Error processing {i}: {e}")
                self.save_checkpoint()
                raise

        self.save_checkpoint()
        return results


# Usage
processor = CheckpointedBatchProcessor(client)
results = processor.process_batch(proteins, config)
```

## Error Handling

### Common Errors and Solutions

```python
import time

from requests.exceptions import HTTPError, ConnectionError, Timeout


def robust_generate(client, protein, config):
    """Generate with comprehensive error handling."""
    try:
        return client.generate(protein, config)

    except HTTPError as e:
        if e.response.status_code == 401:
            raise ValueError("Invalid API token")
        elif e.response.status_code == 429:
            raise ValueError("Rate limit exceeded - slow down requests")
        elif e.response.status_code == 500:
            raise ValueError("Server error - try again later")
        else:
            raise

    except ConnectionError:
        raise ValueError("Network error - check internet connection")

    except Timeout:
        raise ValueError("Request timeout - try a smaller protein or increase the timeout")

    except Exception as e:
        raise ValueError(f"Unexpected error: {str(e)}")


# Usage with retry logic
def generate_with_full_retry(client, protein, config, max_retries=3):
    """Combine error handling with retry logic."""
    for attempt in range(max_retries):
        try:
            return robust_generate(client, protein, config)
        except ValueError as e:
            if "rate limit" in str(e).lower() and attempt < max_retries - 1:
                time.sleep(2 ** attempt)
                continue
            raise
```

## Cost Optimization

### Strategies to Reduce Costs

**1. Use Appropriate Model Size:**

```python
# Use a smaller model for testing
dev_client = ESM3ForgeInferenceClient(
    model="esm3-small-2024-08",
    token=token
)

# Use a larger model only for final generation
prod_client = ESM3ForgeInferenceClient(
    model="esm3-large-2024-03",
    token=token
)
```

**2. Cache Results:**

```python
import hashlib
import json
import os
import pickle


class ForgeCache:
    """Cache Forge API results locally."""

    def __init__(self, cache_dir="forge_cache"):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def get_cache_key(self, protein, config):
        """Generate a cache key from the inputs."""
        data = {
            'sequence': protein.sequence,
            'config': str(config)
        }
        return hashlib.md5(json.dumps(data, sort_keys=True).encode()).hexdigest()

    def get(self, protein, config):
        """Return a cached result, or None on a cache miss."""
        key = self.get_cache_key(protein, config)
        path = os.path.join(self.cache_dir, f"{key}.pkl")

        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def set(self, protein, config, result):
        """Cache a result."""
        key = self.get_cache_key(protein, config)
        path = os.path.join(self.cache_dir, f"{key}.pkl")

        with open(path, 'wb') as f:
            pickle.dump(result, f)


# Usage
cache = ForgeCache()


def cached_generate(client, protein, config):
    """Generate with caching."""
    cached = cache.get(protein, config)
    if cached:
        return cached

    result = client.generate(protein, config)
    cache.set(protein, config, result)
    return result
```

**3. Batch Similar Requests:**

Group similar generation tasks to reduce overhead:

```python
def batch_similar_tasks(proteins, max_batch_size=50):
    """Group proteins into batches of similar length."""
    # Sort by length for efficient processing
    sorted_proteins = sorted(proteins, key=lambda p: len(p.sequence))

    batches = []
    current_batch = []

    for protein in sorted_proteins:
        current_batch.append(protein)

        if len(current_batch) >= max_batch_size:
            batches.append(current_batch)
            current_batch = []

    if current_batch:
        batches.append(current_batch)

    return batches
```

## Monitoring and Logging

### Track API Usage

```python
import time
from datetime import datetime


class ForgeMonitor:
    """Monitor Forge API usage."""

    def __init__(self):
        self.calls = []
        self.errors = []

    def log_call(self, model, protein_length, duration, success=True, error=None):
        """Log a single API call."""
        entry = {
            'timestamp': datetime.now(),
            'model': model,
            'protein_length': protein_length,
            'duration': duration,
            'success': success,
            'error': str(error) if error else None
        }

        if success:
            self.calls.append(entry)
        else:
            self.errors.append(entry)

    def get_stats(self):
        """Get usage statistics."""
        total_calls = len(self.calls) + len(self.errors)
        success_rate = len(self.calls) / total_calls if total_calls > 0 else 0
        avg_duration = sum(c['duration'] for c in self.calls) / len(self.calls) if self.calls else 0

        return {
            'total_calls': total_calls,
            'successful': len(self.calls),
            'failed': len(self.errors),
            'success_rate': success_rate,
            'avg_duration': avg_duration
        }


# Usage
monitor = ForgeMonitor()


def monitored_generate(client, protein, config):
    """Generate with monitoring."""
    start = time.time()

    try:
        result = client.generate(protein, config)
        duration = time.time() - start
        monitor.log_call(
            model=client.model,
            protein_length=len(protein.sequence),
            duration=duration,
            success=True
        )
        return result

    except Exception as e:
        duration = time.time() - start
        monitor.log_call(
            model=client.model,
            protein_length=len(protein.sequence),
            duration=duration,
            success=False,
            error=e
        )
        raise


# Check stats
print(monitor.get_stats())
```

## AWS SageMaker Deployment

For dedicated infrastructure and enterprise use:

### Deployment Options

1. **AWS Marketplace Listing**: Deploy ESM3 via the AWS SageMaker Marketplace
2. **Custom Endpoint**: Configure a dedicated inference endpoint (see the invocation sketch below)
3. **Batch Transform**: Use SageMaker Batch Transform for large-scale processing

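Once an endpoint is deployed, it can be invoked with standard boto3 calls. The sketch below is illustrative only: the endpoint name `esm3-endpoint` and the JSON payload shape are assumptions, not a documented ESM3 contract; consult the marketplace listing for the actual request schema.

```python
import json

import boto3

# Hypothetical endpoint name; use the name chosen at deployment time
ENDPOINT_NAME = "esm3-endpoint"

runtime = boto3.client("sagemaker-runtime")

# Assumed payload shape -- check the marketplace listing for the real schema
payload = {"sequence": "MPRT___KEND", "track": "sequence", "num_steps": 8}

response = runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType="application/json",
    Body=json.dumps(payload),
)

result = json.loads(response["Body"].read())
print(result)
```
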
### Benefits

- Dedicated compute resources
- No rate limiting beyond your own infrastructure
- Data stays in your AWS environment
- Integration with other AWS services
- Custom instance types and scaling

**More Information:**

- AWS Marketplace: https://aws.amazon.com/marketplace/seller-profile?id=seller-iw2nbscescndm
- Contact EvolutionaryScale for enterprise licensing

## Best Practices Summary

1. **Authentication**: Store tokens securely (environment variables, secrets manager); a sketch follows below
2. **Rate Limiting**: Implement exponential backoff and respect limits
3. **Error Handling**: Always handle network errors and retries
4. **Caching**: Cache results for repeated queries
5. **Model Selection**: Use an appropriate model size for the task
6. **Batch Processing**: Use async/batch processing for multiple proteins
7. **Monitoring**: Track usage and costs
8. **Checkpointing**: Save progress for long-running jobs

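A minimal sketch of point 1, reading the token from an environment variable so it never appears in source code; the variable name `ESM_FORGE_TOKEN` is our own convention, not a name the SDK looks for.

```python
import os

from esm.sdk.forge import ESM3ForgeInferenceClient

# Read the token from the environment; ESM_FORGE_TOKEN is an arbitrary name
token = os.environ.get("ESM_FORGE_TOKEN")
if token is None:
    raise RuntimeError("Set the ESM_FORGE_TOKEN environment variable first")

client = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token=token)
```
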
## Troubleshooting

### Connection Issues

```python
# Test the connection
try:
    client = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token=token)
    test_protein = ESMProtein(sequence="MPRTK")
    result = client.generate(test_protein, GenerationConfig(track="sequence", num_steps=1))
    print("Connection successful!")
except Exception as e:
    print(f"Connection failed: {e}")
```

### Token Validation

```python
from requests.exceptions import HTTPError


def validate_token(token):
    """Validate an API token with a minimal test call."""
    try:
        client = ESM3ForgeInferenceClient(
            model="esm3-small-2024-08",
            token=token
        )
        test = ESMProtein(sequence="MPR")
        client.generate(test, GenerationConfig(track="sequence", num_steps=1))
        return True
    except HTTPError as e:
        if e.response.status_code == 401:
            return False
        raise
```

## Additional Resources

- **Forge Platform**: https://forge.evolutionaryscale.ai
- **API Documentation**: Check the Forge dashboard for the latest API specs
- **Community Support**: Slack community at https://bit.ly/3FKwcWd
- **Enterprise Contact**: Contact EvolutionaryScale for custom deployments

685
scientific-packages/esm/references/workflows.md
Normal file

@@ -0,0 +1,685 @@
# ESM Workflows and Examples

## Overview

This document provides complete, end-to-end examples of common workflows using ESM3 and ESM C. Each workflow includes setup, execution, and analysis code.

## Workflow 1: Novel GFP Design with Chain-of-Thought

Design a novel fluorescent protein using ESM3's multimodal generation capabilities.

### Objective

Generate a green fluorescent protein (GFP) with specific properties using chain-of-thought reasoning across sequence, structure, and function.

### Complete Implementation

```python
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig, FunctionAnnotation

# Setup
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")

# Step 1: Define target properties
print("Step 1: Defining target GFP properties...")

# Create a protein with the desired function
target_length = 238  # Typical GFP length
protein = ESMProtein(
    sequence="_" * target_length,
    function_annotations=[
        FunctionAnnotation(
            label="green_fluorescent_protein",
            start=65,
            end=75  # Chromophore region
        )
    ]
)

# Step 2: Generate an initial sequence with function conditioning
print("Step 2: Generating initial sequence...")

config = GenerationConfig(
    track="sequence",
    num_steps=target_length // 3,  # Gradual generation
    temperature=0.7  # Moderate diversity
)
protein = model.generate(protein, config)
print(f"Generated sequence: {protein.sequence[:50]}...")

# Step 3: Predict structure
print("Step 3: Predicting structure...")

config = GenerationConfig(
    track="structure",
    num_steps=target_length // 2
)
protein = model.generate(protein, config)
print(f"Structure predicted, coordinates shape: {protein.coordinates.shape}")

# Step 4: Refine the sequence based on the structure
print("Step 4: Refining sequence based on structure...")

# Mask regions for refinement (e.g., surface residues)
sequence_list = list(protein.sequence)
# Keep the chromophore region, refine the rest
for i in range(0, 65):
    if i % 3 == 0:  # Refine every third position
        sequence_list[i] = '_'
for i in range(75, target_length):
    if i % 3 == 0:
        sequence_list[i] = '_'

protein.sequence = ''.join(sequence_list)

config = GenerationConfig(
    track="sequence",
    num_steps=50,
    temperature=0.5  # Lower temperature for refinement
)
protein = model.generate(protein, config)

# Step 5: Final validation
print("Step 5: Final validation...")

# Predict the final structure
config = GenerationConfig(track="structure", num_steps=30)
protein = model.generate(protein, config)

# Save results
with open("novel_gfp.pdb", "w") as f:
    f.write(protein.to_pdb())

with open("novel_gfp_sequence.txt", "w") as f:
    f.write(f">Novel_GFP\n{protein.sequence}\n")

print(f"\nFinal GFP sequence:\n{protein.sequence}")
print(f"\nFunction annotations: {protein.function_annotations}")
print(f"Structure saved to: novel_gfp.pdb")
```

### Validation Steps

```python
# Analyze the designed GFP
def analyze_gfp(protein):
    """Analyze generated GFP properties."""

    # Check the chromophore region (should be around Ser65-Tyr66-Gly67)
    chromophore_region = protein.sequence[64:68]
    print(f"Chromophore region: {chromophore_region}")

    # Check the barrel structure (GFPs have a beta-barrel fold)
    # Analyze secondary structure if available
    if protein.secondary_structure:
        beta_content = protein.secondary_structure.count('E') / len(protein.sequence)
        print(f"Beta sheet content: {beta_content:.2%}")

    # Checking sequence similarity to known GFPs
    # would require BLAST or an alignment tool in practice

    return {
        'length': len(protein.sequence),
        'chromophore': chromophore_region,
        'coordinates_available': protein.coordinates is not None
    }


analysis = analyze_gfp(protein)
print(f"\nAnalysis results: {analysis}")
```

## Workflow 2: Protein Variant Library Generation

Generate and analyze a library of protein variants for directed evolution.

### Objective

Create variants of a parent protein by targeted mutagenesis while maintaining structural integrity.

### Complete Implementation

```python
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, GenerationConfig
import numpy as np
from sklearn.cluster import KMeans

# Setup
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")

# Parent protein
parent_sequence = "MPRTKEINDAGLIVHSPQWFYKARNDTESLGKIVHEFPM"
parent_protein = ESMProtein(sequence=parent_sequence)

# Define mutation parameters
num_variants = 50
positions_to_mutate = 5  # Number of positions per variant

# Step 1: Generate the variant library
print("Generating variant library...")

variants = []
for i in range(num_variants):
    # Create a masked sequence with random positions
    seq_list = list(parent_sequence)

    # Select random positions to mutate
    mutation_positions = np.random.choice(
        len(seq_list),
        size=positions_to_mutate,
        replace=False
    )

    for pos in mutation_positions:
        seq_list[pos] = '_'

    # Generate the variant
    variant_protein = ESMProtein(sequence=''.join(seq_list))

    config = GenerationConfig(
        track="sequence",
        num_steps=positions_to_mutate * 2,
        temperature=0.8  # Higher diversity
    )

    variant = model.generate(variant_protein, config)
    variants.append(variant.sequence)

    if (i + 1) % 10 == 0:
        print(f"Generated {i + 1}/{num_variants} variants")

print(f"\nGenerated {len(variants)} variants")

# Step 2: Predict structures for the variants
print("\nPredicting structures...")

variant_proteins_with_structure = []
for i, seq in enumerate(variants):
    protein = ESMProtein(sequence=seq)

    config = GenerationConfig(
        track="structure",
        num_steps=len(seq) // 2
    )

    protein_with_structure = model.generate(protein, config)
    variant_proteins_with_structure.append(protein_with_structure)

    if (i + 1) % 10 == 0:
        print(f"Predicted structures for {i + 1}/{len(variants)} variants")

# Step 3: Analyze variant diversity
print("\nAnalyzing variant diversity...")

# Calculate Hamming distances from the parent
def hamming_distance(seq1, seq2):
    """Calculate the Hamming distance between equal-length sequences."""
    return sum(c1 != c2 for c1, c2 in zip(seq1, seq2))

distances = [hamming_distance(parent_sequence, var) for var in variants]
print(f"Average mutations per variant: {np.mean(distances):.1f}")
print(f"Mutation range: {min(distances)}-{max(distances)}")

# Step 4: Get embeddings for clustering
print("\nGenerating embeddings for clustering...")

embedding_model = ESMC.from_pretrained("esmc-300m").to("cuda")

def get_embedding(sequence):
    """Get a mean-pooled embedding for a sequence."""
    protein = ESMProtein(sequence=sequence)
    tensor = embedding_model.encode(protein)
    emb = embedding_model.forward(tensor)
    return emb.mean(dim=1).cpu().detach().numpy().flatten()

variant_embeddings = np.array([get_embedding(seq) for seq in variants])

# Step 5: Cluster variants
print("Clustering variants...")

n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(variant_embeddings)

# Analyze clusters
print("\nCluster analysis:")
for i in range(n_clusters):
    cluster_variants = [var for var, label in zip(variants, cluster_labels) if label == i]
    cluster_distances = [hamming_distance(parent_sequence, var) for var in cluster_variants]

    print(f"\nCluster {i}:")
    print(f"  Size: {len(cluster_variants)}")
    print(f"  Avg distance from parent: {np.mean(cluster_distances):.1f}")
    print(f"  Representative: {cluster_variants[0][:40]}...")

# Step 6: Select diverse representatives
print("\nSelecting diverse representatives...")

representatives = []
for i in range(n_clusters):
    # Get the cluster members
    cluster_indices = np.where(cluster_labels == i)[0]
    cluster_embs = variant_embeddings[cluster_indices]

    # Find the member closest to the centroid
    centroid = cluster_embs.mean(axis=0)
    distances_to_centroid = np.linalg.norm(cluster_embs - centroid, axis=1)
    rep_idx = cluster_indices[np.argmin(distances_to_centroid)]

    representatives.append(variants[rep_idx])

# Save results
print("\nSaving results...")

with open("variant_library.fasta", "w") as f:
    f.write(f">Parent\n{parent_sequence}\n\n")
    for i, var in enumerate(variants):
        f.write(f">Variant_{i+1}_Cluster_{cluster_labels[i]}\n{var}\n")

with open("representative_variants.fasta", "w") as f:
    for i, rep in enumerate(representatives):
        f.write(f">Representative_Cluster_{i}\n{rep}\n")

print("Variant library saved to: variant_library.fasta")
print("Representatives saved to: representative_variants.fasta")
```

## Workflow 3: Structure-Based Sequence Optimization

Optimize a protein sequence to improve stability while maintaining function.

### Objective

Given a protein structure, design sequences that maintain the fold but have improved properties.

### Complete Implementation

```python
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig
import numpy as np

# Setup
model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")

# Load the target structure (e.g., from the PDB)
target_protein = ESMProtein.from_pdb("target_structure.pdb")
original_sequence = target_protein.sequence

print(f"Original sequence: {original_sequence}")
print(f"Structure loaded: {target_protein.coordinates.shape}")

# Step 1: Generate multiple sequence designs
print("\nGenerating optimized sequences...")

num_designs = 20
optimized_sequences = []

for i in range(num_designs):
    # Start from the structure alone, with no sequence
    design_protein = ESMProtein(
        coordinates=target_protein.coordinates.copy(),
        secondary_structure=target_protein.secondary_structure
    )

    # Generate a sequence for this structure
    config = GenerationConfig(
        track="sequence",
        num_steps=len(original_sequence),
        temperature=0.7,
        condition_on_coordinates_only=True
    )

    designed = model.generate(design_protein, config)
    optimized_sequences.append(designed.sequence)

    if (i + 1) % 5 == 0:
        print(f"Generated {i + 1}/{num_designs} designs")

# Step 2: Validate structural compatibility
print("\nValidating structural compatibility...")

validated_designs = []

for seq in optimized_sequences:
    # Predict the structure of the designed sequence
    test_protein = ESMProtein(sequence=seq)

    config = GenerationConfig(
        track="structure",
        num_steps=len(seq) // 2
    )

    predicted = model.generate(test_protein, config)

    # A full validation would compute RMSD after proper alignment;
    # here we simply check that structure prediction succeeds
    if predicted.coordinates is not None:
        validated_designs.append(seq)

print(f"Validated {len(validated_designs)}/{num_designs} designs")

# Step 3: Analyze sequence properties
print("\nAnalyzing sequence properties...")

def calculate_properties(sequence):
    """Calculate basic sequence properties."""
    # Hydrophobicity (simplified)
    hydrophobic = "AILMFWYV"
    hydrophobic_fraction = sum(1 for aa in sequence if aa in hydrophobic) / len(sequence)

    # Net charge
    positive = "KR"
    negative = "DE"
    net_charge = sum(1 for aa in sequence if aa in positive) - sum(1 for aa in sequence if aa in negative)

    # Aromatic content
    aromatic = "FWY"
    aromatic_fraction = sum(1 for aa in sequence if aa in aromatic) / len(sequence)

    return {
        'hydrophobic_fraction': hydrophobic_fraction,
        'net_charge': net_charge,
        'aromatic_fraction': aromatic_fraction
    }

# Compare to the original
original_props = calculate_properties(original_sequence)
print(f"\nOriginal properties:")
print(f"  Hydrophobic: {original_props['hydrophobic_fraction']:.2%}")
print(f"  Net charge: {original_props['net_charge']:+d}")
print(f"  Aromatic: {original_props['aromatic_fraction']:.2%}")

# Analyze the designs
design_properties = [calculate_properties(seq) for seq in validated_designs]

avg_hydrophobic = np.mean([p['hydrophobic_fraction'] for p in design_properties])
avg_charge = np.mean([p['net_charge'] for p in design_properties])
avg_aromatic = np.mean([p['aromatic_fraction'] for p in design_properties])

print(f"\nDesigned sequences (average):")
print(f"  Hydrophobic: {avg_hydrophobic:.2%}")
print(f"  Net charge: {avg_charge:+.1f}")
print(f"  Aromatic: {avg_aromatic:.2%}")

# Step 4: Rank designs
print("\nRanking designs...")

def score_design(sequence, original_props):
    """Score a design based on the desired properties."""
    props = calculate_properties(sequence)

    # Prefer higher hydrophobic content (for stability)
    hydrophobic_score = props['hydrophobic_fraction']

    # Prefer a net charge similar to the original
    charge_score = 1.0 / (1.0 + abs(props['net_charge'] - original_props['net_charge']))

    # Combined score
    return hydrophobic_score * 0.6 + charge_score * 0.4

scores = [(seq, score_design(seq, original_props)) for seq in validated_designs]
scores.sort(key=lambda x: x[1], reverse=True)

print("\nTop 5 designs:")
for i, (seq, score) in enumerate(scores[:5]):
    print(f"\n{i+1}. Score: {score:.3f}")
    print(f"   Sequence: {seq[:40]}...")

# Step 5: Save results
print("\nSaving results...")

with open("optimized_sequences.fasta", "w") as f:
    f.write(f">Original\n{original_sequence}\n\n")

    for i, (seq, score) in enumerate(scores):
        props = calculate_properties(seq)
        f.write(f">Design_{i+1}_Score_{score:.3f}\n")
        f.write(f"# Hydrophobic: {props['hydrophobic_fraction']:.2%}, ")
        f.write(f"Charge: {props['net_charge']:+d}, ")
        f.write(f"Aromatic: {props['aromatic_fraction']:.2%}\n")
        f.write(f"{seq}\n\n")

print("Results saved to: optimized_sequences.fasta")
```

## Workflow 4: Function Prediction Pipeline

Predict protein function from sequence using ESM3 and ESM C.

### Objective

Build a pipeline that predicts protein function using both generative (ESM3) and embedding (ESM C) approaches.

### Complete Implementation

```python
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, GenerationConfig
from sklearn.ensemble import RandomForestClassifier  # for the embedding-based classifier

# Setup models
esm3_model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda")
esmc_model = ESMC.from_pretrained("esmc-600m").to("cuda")

# Example: predict whether a protein is an enzyme
# (in practice, you'd have a labeled training set)

def predict_function_generative(sequence):
    """Predict function using the ESM3 generative approach."""

    protein = ESMProtein(sequence=sequence)

    # Generate function annotations
    config = GenerationConfig(
        track="function",
        num_steps=20,
        temperature=0.3  # Low temperature for confident predictions
    )

    protein_with_function = esm3_model.generate(protein, config)

    return protein_with_function.function_annotations

def predict_function_embedding(sequence, function_classifier):
    """Predict function using ESM C embeddings plus a trained classifier."""

    # Get the embedding
    protein = ESMProtein(sequence=sequence)
    tensor = esmc_model.encode(protein)
    embedding = esmc_model.forward(tensor)

    # Mean pool
    embedding_pooled = embedding.mean(dim=1).cpu().detach().numpy()

    # Predict with the classifier (e.g., a RandomForestClassifier)
    prediction = function_classifier.predict(embedding_pooled)
    probability = function_classifier.predict_proba(embedding_pooled)

    return prediction[0], probability[0]

# Example workflow with test sequences
test_sequences = {
    "kinase": "MPRTKEINDAGLIVHSPQWFYKARNDTESLGKIVHEF",
    "protease": "AGLIVHSPQWFYKARNDTESLGKIVHEFPMCDEGH",
    "transporter": "KTEFLNDGRPMLIVHSPQWFYKARNDTESLGKIVH"
}

print("Predicting functions...\n")

for name, sequence in test_sequences.items():
    print(f"{name.upper()}:")
    print(f"Sequence: {sequence[:30]}...")

    # Method 1: Generative
    functions = predict_function_generative(sequence)
    print(f"  Generative predictions: {functions}")

    # Method 2: The embedding-based path would require a trained classifier
    # (skipped in this example as it needs training data)

    print()
```

## Workflow 5: Embedding-Based Clustering and Analysis

Cluster and analyze a large protein dataset using ESM C embeddings.

### Complete Implementation

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Setup
model = ESMC.from_pretrained("esmc-600m").to("cuda")

# Load the protein dataset (example)
sequences = [
    # In practice, load from a FASTA file or database
    "MPRTKEINDAGLIVHSPQWFYK",
    "AGLIVHSPQWFYKARNDTESL",
    # ... more sequences
]

print(f"Loaded {len(sequences)} sequences")

# Step 1: Generate embeddings
print("Generating embeddings...")

embeddings = []
for i, seq in enumerate(sequences):
    protein = ESMProtein(sequence=seq)
    tensor = model.encode(protein)
    emb = model.forward(tensor)

    # Mean pooling
    emb_pooled = emb.mean(dim=1).cpu().detach().numpy().flatten()
    embeddings.append(emb_pooled)

    if (i + 1) % 100 == 0:
        print(f"Processed {i + 1}/{len(sequences)}")

embeddings = np.array(embeddings)
print(f"Embeddings shape: {embeddings.shape}")

# Step 2: Dimensionality reduction for visualization
print("\nReducing dimensionality...")

# PCA for the initial reduction
pca = PCA(n_components=50)
embeddings_pca = pca.fit_transform(embeddings)
print(f"PCA explained variance: {pca.explained_variance_ratio_[:10].sum():.2%}")

# t-SNE for visualization
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings_pca)

# Step 3: Clustering
print("\nClustering...")

# DBSCAN for density-based clustering
clustering = DBSCAN(eps=0.5, min_samples=5)
cluster_labels = clustering.fit_predict(embeddings)

n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
n_noise = list(cluster_labels).count(-1)

print(f"Number of clusters: {n_clusters}")
print(f"Number of noise points: {n_noise}")

# Step 4: Visualize
print("\nGenerating visualization...")

plt.figure(figsize=(12, 8))
scatter = plt.scatter(
    embeddings_2d[:, 0],
    embeddings_2d[:, 1],
    c=cluster_labels,
    cmap='viridis',
    alpha=0.6
)
plt.colorbar(scatter)
plt.title("Protein Sequence Clustering (ESM C Embeddings)")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.savefig("protein_clusters.png", dpi=300, bbox_inches='tight')
print("Visualization saved to: protein_clusters.png")

# Step 5: Analyze clusters
print("\nCluster analysis:")

for cluster_id in range(n_clusters):
    cluster_indices = np.where(cluster_labels == cluster_id)[0]
    cluster_seqs = [sequences[i] for i in cluster_indices]

    print(f"\nCluster {cluster_id}:")
    print(f"  Size: {len(cluster_seqs)}")
    print(f"  Avg length: {np.mean([len(s) for s in cluster_seqs]):.1f}")
    print(f"  Example: {cluster_seqs[0][:40]}...")

# Save cluster assignments
with open("cluster_assignments.txt", "w") as f:
    for i, (seq, label) in enumerate(zip(sequences, cluster_labels)):
        f.write(f"Sequence_{i}\tCluster_{label}\t{seq}\n")

print("\nCluster assignments saved to: cluster_assignments.txt")
```

## Additional Workflow Tips

### Memory Management for Large Datasets

```python
def process_large_dataset(sequences, batch_size=32):
    """Process a large dataset with memory management.

    `process_sequence` is a placeholder for any per-sequence step
    from the workflows above.
    """
    import gc
    import torch

    results = []

    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i + batch_size]

        # Process the batch
        batch_results = [process_sequence(seq) for seq in batch]
        results.extend(batch_results)

        # Clear GPU memory between batches
        torch.cuda.empty_cache()
        gc.collect()

        if (i + batch_size) % 100 == 0:
            print(f"Processed {min(i + batch_size, len(sequences))}/{len(sequences)}")

    return results
```

### Parallel Processing

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_workflow(sequences, n_workers=4):
    """Process sequences in parallel threads."""

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        results = list(executor.map(process_sequence, sequences))

    return results
```

These workflows provide comprehensive examples for common ESM use cases. Adapt them to your specific needs and always validate results with appropriate biological experiments.

@@ -0,0 +1,438 @@
# Theoretical Foundations of scvi-tools

This document explains the mathematical and statistical principles underlying scvi-tools.

## Core Concepts

### Variational Inference

**What is it?**
Variational inference is a technique for approximating complex probability distributions. In single-cell analysis, we want to understand the posterior distribution p(z|x) - the probability of latent variables z given observed data x.

**Why use it?**
- Exact inference is computationally intractable for complex models
- Scales to large datasets (millions of cells)
- Provides uncertainty quantification
- Enables Bayesian reasoning about cell states

**How does it work?**
1. Define a simpler approximate distribution q(z|x) with learnable parameters
2. Minimize the KL divergence between q(z|x) and the true posterior p(z|x)
3. Equivalent to maximizing the Evidence Lower Bound (ELBO)

**ELBO Objective**:
```
ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
            ↑                    ↑
      Reconstruction       Regularization
```

- **Reconstruction term**: Model should generate data similar to observed
- **Regularization term**: Latent representation should match the prior

### Variational Autoencoders (VAEs)

**Architecture**:
```
x (observed data)
        ↓
[Encoder Neural Network]
        ↓
z (latent representation)
        ↓
[Decoder Neural Network]
        ↓
x̂ (reconstructed data)
```

**Encoder**: Maps cells (x) to latent space (z)
- Learns q(z|x), the approximate posterior
- Parameterized by a neural network with learnable weights
- Outputs the mean and variance of the latent distribution

**Decoder**: Maps latent space (z) back to gene space
- Learns p(x|z), the likelihood
- Generates gene expression from the latent representation
- Models count distributions (Negative Binomial, Zero-Inflated NB)

**Reparameterization Trick** (sketched below):
- Allows backpropagation through stochastic sampling
- Sample z = μ + σ ⊙ ε, where ε ~ N(0,1)
- Enables end-to-end training with gradient descent

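As a concrete illustration of the trick, here is a minimal PyTorch sketch (not scvi-tools' actual code): the sample is a deterministic function of μ, σ, and parameter-free noise, so gradients flow back to μ and σ.

```python
import torch


def reparameterize(mu, log_var):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), differentiably."""
    sigma = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
    eps = torch.randn_like(sigma)     # the noise carries no parameters
    return mu + sigma * eps


# Usage: gradients reach mu and log_var through the sample
mu = torch.zeros(8, 10, requires_grad=True)
log_var = torch.zeros(8, 10, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
```
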
### Amortized Inference

**Concept**: Share encoder parameters across all cells.

**Traditional inference**: Learn separate latent variables for each cell
- n_cells × n_latent parameters
- Doesn't scale to large datasets

**Amortized inference**: Learn a single encoder for all cells
- Fixed number of parameters regardless of cell count
- Enables fast inference on new cells
- Transfers learned patterns across the dataset

**Benefits**:
- Scalable to millions of cells
- Fast inference on query data
- Leverages shared structure across cells
- Enables few-shot learning

|
||||||
|
|
||||||
|
### Count Data Distributions
|
||||||
|
|
||||||
|
Single-cell data are counts (integer-valued), requiring appropriate distributions.
|
||||||
|
|
||||||
|
#### Negative Binomial (NB)
|
||||||
|
```
|
||||||
|
x ~ NB(μ, θ)
|
||||||
|
```
|
||||||
|
- **μ (mean)**: Expected expression level
|
||||||
|
- **θ (dispersion)**: Controls variance
|
||||||
|
- **Variance**: Var(x) = μ + μ²/θ
|
||||||
|
|
||||||
|
**When to use**: Gene expression without zero-inflation
|
||||||
|
- More flexible than Poisson (allows overdispersion)
|
||||||
|
- Models technical and biological variation
|
||||||
|
|
||||||
|
#### Zero-Inflated Negative Binomial (ZINB)
|
||||||
|
```
|
||||||
|
x ~ π·δ₀ + (1-π)·NB(μ, θ)
|
||||||
|
```
|
||||||
|
- **π (dropout rate)**: Probability of technical zero
|
||||||
|
- **δ₀**: Point mass at zero
|
||||||
|
- **NB(μ, θ)**: Expression when not dropped out
|
||||||
|
|
||||||
|
**When to use**: Sparse scRNA-seq data
|
||||||
|
- Models technical dropout separately from biological zeros
|
||||||
|
- Better fit for highly sparse data (e.g., 10x data)
|
||||||
|
|
||||||
|
#### Poisson
|
||||||
|
```
|
||||||
|
x ~ Poisson(μ)
|
||||||
|
```
|
||||||
|
- Simplest count distribution
|
||||||
|
- Mean equals variance: Var(x) = μ
|
||||||
|
|
||||||
|
**When to use**: Less common; ATAC-seq fragment counts
|
||||||
|
- More restrictive than NB
|
||||||
|
- Faster computation
|
||||||
|
|
||||||
|
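To make these distributions concrete, here is a small numpy sketch that samples NB and ZINB counts using the mean/dispersion parameterization above; it is illustrative, not scvi-tools code. With n = θ and p = θ/(θ+μ), numpy's negative binomial has mean μ and variance μ + μ²/θ.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_nb(mu, theta, size):
    """Sample NB(mu, theta) via numpy's (n, p) parameterization."""
    p = theta / (theta + mu)
    return rng.negative_binomial(theta, p, size=size)


def sample_zinb(mu, theta, pi, size):
    """Sample ZINB: a technical zero with probability pi, else NB."""
    counts = sample_nb(mu, theta, size)
    dropout = rng.random(size) < pi
    counts[dropout] = 0
    return counts


x = sample_zinb(mu=5.0, theta=2.0, pi=0.3, size=10_000)
print(x.mean(), (x == 0).mean())  # mean shrinks by (1 - pi); many zeros
```
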
### Batch Correction Framework

**Problem**: Technical variation confounds biological signal
- Different sequencing runs, protocols, labs
- Must remove technical effects while preserving biology

**scvi-tools approach**:
1. Encode batch as a categorical variable s
2. Include s in the generative model
3. The latent space z is batch-invariant
4. The decoder conditions on s for batch-specific effects

**Mathematical formulation**:
```
Encoder:  q(z|x, s)  - batch-aware encoding
Latent:   z          - batch-corrected representation
Decoder:  p(x|z, s)  - batch-specific decoding
```

**Key insight**: Batch information flows through the decoder, not the latent space
- z captures biological variation
- s explains technical variation
- Biology and batch effects become separable

### Deep Generative Modeling

**Generative model**: Learns p(x), the data distribution

**Process**:
1. Sample a latent variable: z ~ p(z) = N(0, I)
2. Generate expression: x ~ p(x|z)
3. Joint distribution: p(x, z) = p(x|z)p(z)

**Benefits**:
- Generate synthetic cells
- Impute missing values
- Quantify uncertainty
- Perform counterfactual predictions

**Inference network**: Inverts the generative process
- Given x, infer z
- q(z|x) approximates the true posterior p(z|x)

## Model Architecture Details

### scVI Architecture

**Input**: Gene expression counts x ∈ ℕ^G (G genes)

**Encoder**:
```
h = ReLU(W₁·x + b₁)
μ_z = W₂·h + b₂
log σ²_z = W₃·h + b₃
z ~ N(μ_z, σ²_z)
```

**Latent space**: z ∈ ℝ^d (typically d = 10-30)

**Decoder**:
```
h = ReLU(W₄·z + b₄)
μ = softmax(W₅·h + b₅) · library_size
θ = exp(W₆·h + b₆)
π = sigmoid(W₇·h + b₇)  # for ZINB
x ~ ZINB(μ, θ, π)
```

**Loss function (ELBO)**:
```
L = E_q[log p(x|z)] - KL(q(z|x) || N(0,I))
```

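The equations above translate almost line-for-line into PyTorch. The sketch below is a simplified illustration of the encoder and the three decoder heads, not scvi-tools' actual implementation (which adds batch normalization, multiple layers, and several dispersion options).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SCVISketch(nn.Module):
    """Minimal encoder/decoder following the equations above."""

    def __init__(self, n_genes, n_latent=10, n_hidden=128):
        super().__init__()
        self.enc = nn.Linear(n_genes, n_hidden)        # W1, b1
        self.mu_z = nn.Linear(n_hidden, n_latent)      # W2, b2
        self.logvar_z = nn.Linear(n_hidden, n_latent)  # W3, b3
        self.dec = nn.Linear(n_latent, n_hidden)       # W4, b4
        self.scale = nn.Linear(n_hidden, n_genes)      # W5, b5 (softmax head)
        self.theta = nn.Linear(n_hidden, n_genes)      # W6, b6 (dispersion)
        self.pi = nn.Linear(n_hidden, n_genes)         # W7, b7 (dropout logit)

    def forward(self, x, library_size):
        h = F.relu(self.enc(x))
        mu, logvar = self.mu_z(h), self.logvar_z(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterize
        h_dec = F.relu(self.dec(z))
        mu_x = F.softmax(self.scale(h_dec), dim=-1) * library_size
        theta = torch.exp(self.theta(h_dec))
        pi = torch.sigmoid(self.pi(h_dec))
        return mu_x, theta, pi, mu, logvar
```
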
### Handling Covariates

**Categorical covariates** (batch, donor, etc.):
- One-hot encoded: s ∈ {0,1}^K
- Concatenated with the latent: [z, s]
- Or injected via conditional layers

**Continuous covariates** (library size, percent_mito):
- Standardize to zero mean, unit variance
- Include in the encoder and/or decoder

**Covariate injection strategies** (concatenation is sketched below):
- **Concatenation**: [z, s] fed to the decoder
- **Deep injection**: s added at multiple layers
- **Conditional batch norm**: Batch-specific normalization

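A minimal sketch of the concatenation strategy, assuming integer batch labels; illustrative only, not scvi-tools' internal code.

```python
import torch
import torch.nn.functional as F

n_batches = 3
z = torch.randn(8, 10)                           # latent representation per cell
batch = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])   # integer batch label per cell

# One-hot encode s and concatenate with z before the decoder
s = F.one_hot(batch, num_classes=n_batches).float()
decoder_input = torch.cat([z, s], dim=-1)        # shape: (8, 10 + 3)
print(decoder_input.shape)
```
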
## Advanced Theoretical Concepts

### Transfer Learning (scArches)

**Concept**: Use a pretrained model as initialization for new data

**Process** (step 2 is sketched below):
1. Train a reference model on a large dataset
2. Freeze the encoder parameters
3. Fine-tune the decoder on the query data
4. Or fine-tune everything with a lower learning rate

**Why it works**:
- The encoder learns general cellular representations
- The decoder adapts to query-specific characteristics
- Prevents catastrophic forgetting

**Applications**:
- Query-to-reference mapping
- Few-shot learning for rare cell types
- Rapid analysis of new datasets

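In PyTorch terms, freezing the encoder (step 2) amounts to turning off gradients for its parameters. A sketch, assuming a model that exposes `encoder` and `decoder` submodules (hypothetical attribute names):

```python
import torch

# `model` is assumed to expose .encoder and .decoder submodules
for param in model.encoder.parameters():
    param.requires_grad = False  # the encoder stays fixed

# Only decoder parameters reach the optimizer
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-4)
```
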
### Multi-Resolution Modeling (MrVI)

**Idea**: Separate shared and sample-specific variation

**Latent space decomposition**:
```
z = z_shared + z_sample
```
- **z_shared**: Common across samples
- **z_sample**: Sample-specific effects

**Hierarchical structure**:
```
Sample level:  ρ_s ~ N(0, I)
Cell level:    z_i ~ N(ρ_{s(i)}, σ²)
```

**Benefits**:
- Disentangle biological sources of variation
- Compare samples at different resolutions
- Identify sample-specific cell states

### Counterfactual Prediction

**Goal**: Predict the outcome under different conditions

**Example**: "What would this cell look like if it came from a different batch?"

**Method** (sketched below):
1. Encode the cell to the latent space: z = Encoder(x, s_original)
2. Decode with the new condition: x_new = Decoder(z, s_new)
3. x_new is the counterfactual prediction

**Applications**:
- Batch effect assessment
- Predicting treatment response
- In silico perturbation studies

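Schematically, in code, with `encoder` and `decoder` standing in for the trained networks and the one-hot covariate handling following the concatenation sketch earlier:

```python
import torch
import torch.nn.functional as F


def counterfactual(encoder, decoder, x, s_original, s_new, n_batches):
    """Re-decode a cell under a different batch label."""
    s_orig = F.one_hot(s_original, num_classes=n_batches).float()
    s_cf = F.one_hot(s_new, num_classes=n_batches).float()

    z = encoder(torch.cat([x, s_orig], dim=-1))    # batch-aware encoding
    x_new = decoder(torch.cat([z, s_cf], dim=-1))  # decode under the new batch
    return x_new
```
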
### Posterior Predictive Distribution

**Definition**: Distribution of new data given observed data

```
p(x_new | x_observed) = ∫ p(x_new|z) q(z|x_observed) dz
```

**Estimation**: Sample z from q(z|x), then generate x_new from p(x_new|z)

**Uses**:
- Uncertainty quantification
- Robust predictions
- Outlier detection

## Differential Expression Framework

### Bayesian Approach

**Traditional methods**: Compare point estimates
- Wilcoxon, t-test, etc.
- Ignore uncertainty
- Require pseudocounts

**scvi-tools approach**: Compare distributions (see the sketch after this list)
- Sample from the posterior: μ_A ~ p(μ|x_A), μ_B ~ p(μ|x_B)
- Compute the log fold-change: LFC = log(μ_B) - log(μ_A)
- The posterior distribution of the LFC quantifies uncertainty

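A toy numpy sketch of the idea: given posterior samples of the mean expression in two groups, the LFC posterior is the elementwise difference of logs, and its spread is the uncertainty. The Gamma draws here are stand-ins for samples you would take from a trained model.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in posterior samples of mean expression for one gene in two groups
mu_a = rng.gamma(shape=20.0, scale=0.5, size=5_000)  # group A
mu_b = rng.gamma(shape=30.0, scale=0.5, size=5_000)  # group B

# Posterior of the log fold-change
lfc = np.log2(mu_b) - np.log2(mu_a)

print(f"median LFC: {np.median(lfc):.2f}")
print(f"94% interval: [{np.quantile(lfc, 0.03):.2f}, {np.quantile(lfc, 0.97):.2f}]")
print(f"P(|LFC| > 0.5): {(np.abs(lfc) > 0.5).mean():.2f}")
```
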
### Bayes Factor

**Definition**: Ratio of posterior odds to prior odds

```
BF = P(H₁|data) / P(H₀|data)
     ───────────────────────
         P(H₁) / P(H₀)
```

**Interpretation**:
- BF > 3: Moderate evidence for H₁
- BF > 10: Strong evidence
- BF > 100: Decisive evidence

**In scvi-tools**: Used to rank genes by evidence for DE

### False Discovery Proportion (FDP)

**Goal**: Control the expected false discovery rate

**Procedure** (sketched below):
1. For each gene, compute the posterior probability of DE
2. Rank genes by evidence (Bayes factor)
3. Select the top k genes such that E[FDP] ≤ α

**Advantage over p-values**:
- Fully Bayesian
- Natural for posterior inference
- No arbitrary thresholds

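The selection step fits in a few lines: if gene g has posterior DE probability p_g, the expected FDP among the top-k ranked genes is the mean of (1 − p_g), so k is the largest cutoff keeping that mean below α. A numpy sketch with made-up probabilities:

```python
import numpy as np

# Made-up posterior probabilities that each gene is DE
p_de = np.array([0.99, 0.97, 0.95, 0.90, 0.80, 0.60, 0.40, 0.20])

# Rank genes by evidence (highest posterior probability first)
order = np.argsort(-p_de)

# Expected FDP among the top-k genes is the running mean of (1 - p)
expected_fdp = np.cumsum(1 - p_de[order]) / np.arange(1, len(p_de) + 1)

alpha = 0.05
k = int(np.sum(expected_fdp <= alpha))  # largest prefix meeting the target
selected = order[:k]
print(f"Select top {k} genes: {selected}, E[FDP] = {expected_fdp[k-1]:.3f}")
```
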
## Implementation Details

### Optimization

**Optimizer**: Adam (adaptive learning rates)
- Default lr = 0.001
- Momentum parameters: β₁ = 0.9, β₂ = 0.999

**Training loop**:
1. Sample a mini-batch of cells
2. Compute the ELBO loss
3. Backpropagate gradients
4. Update parameters with Adam
5. Repeat until convergence

**Convergence criteria**:
- ELBO plateaus on the validation set
- Early stopping prevents overfitting
- Typically 200-500 epochs

### Regularization

**KL annealing**: Gradually increase the KL weight (see the sketch below)
- Prevents posterior collapse
- Starts at 0, increases to 1 over epochs

**Dropout**: Random neuron dropping during training
- Default: 0.1 dropout rate
- Prevents overfitting
- Improves generalization

**Weight decay**: L2 regularization on weights
- Prevents large weights
- Improves stability

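A linear annealing schedule illustrates the idea in one line; the warm-up length here is an arbitrary choice for the sketch.

```python
def kl_weight(epoch, n_warmup_epochs=400):
    """Linearly anneal the KL weight from 0 to 1 over the warm-up period."""
    return min(1.0, epoch / n_warmup_epochs)

# The annealed ELBO loss for one mini-batch would then be:
# loss = reconstruction_loss + kl_weight(epoch) * kl_divergence
```
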
### Scalability

**Mini-batch training**:
- Process a subset of cells per iteration
- Batch size: 64-256 cells
- Enables scaling to millions of cells

**Stochastic optimization**:
- Estimates the ELBO on mini-batches
- Unbiased gradient estimates
- Converges to an optimal solution

**GPU acceleration**:
- Neural networks parallelize naturally
- Order-of-magnitude speedup
- Essential for large datasets

## Connections to Other Methods

### vs. PCA
- **PCA**: Linear, deterministic
- **scVI**: Nonlinear, probabilistic
- **Advantage**: scVI captures complex structure and handles counts natively

### vs. t-SNE/UMAP
- **t-SNE/UMAP**: Visualization-focused
- **scVI**: Full generative model
- **Advantage**: scVI enables downstream tasks (DE, imputation)

### vs. Seurat Integration
- **Seurat**: Anchor-based alignment
- **scVI**: Probabilistic modeling
- **Advantage**: scVI provides uncertainty and scales to many batches

### vs. Harmony
- **Harmony**: PCA + batch correction
- **scVI**: VAE-based
- **Advantage**: scVI handles counts natively and is more flexible

## Mathematical Notation

**Common symbols**:
- x: Observed gene expression (counts)
- z: Latent representation
- θ: Model parameters
- q(z|x): Approximate posterior (encoder)
- p(x|z): Likelihood (decoder)
- p(z): Prior on latent variables
- μ, σ²: Mean and variance
- π: Dropout probability (ZINB)
- θ (in NB): Dispersion parameter
- s: Batch/covariate indicator

## Further Reading

**Key Papers**:
1. Lopez et al. (2018): "Deep generative modeling for single-cell transcriptomics"
2. Xu et al. (2021): "Probabilistic harmonization and annotation of single-cell transcriptomics"
3. Boyeau et al. (2019): "Deep generative models for detecting differential expression in single cells"

**Concepts to explore**:
- Variational inference in machine learning
- Bayesian deep learning
- Information theory (KL divergence, mutual information)
- Generative models (GANs, normalizing flows, diffusion models)
- Probabilistic programming (Pyro, PyTorch)

**Mathematical background**:
- Probability theory and statistics
- Linear algebra and calculus
- Optimization theory
- Information theory