From 62ddd1296c37586dd7712e9230a2d367b95c84c3 Mon Sep 17 00:00:00 2001 From: Timothy Kassis Date: Thu, 23 Oct 2025 09:10:34 -0700 Subject: [PATCH] Add ESM3 and ESM C models protein models --- scientific-packages/esm/SKILL.md | 300 ++++++++ .../esm/references/esm-c-api.md | 583 +++++++++++++++ .../esm/references/esm3-api.md | 452 ++++++++++++ .../esm/references/forge-api.md | 657 +++++++++++++++++ .../esm/references/workflows.md | 685 ++++++++++++++++++ .../references/theoretical-foundations.md | 438 +++++++++++ 6 files changed, 3115 insertions(+) create mode 100644 scientific-packages/esm/SKILL.md create mode 100644 scientific-packages/esm/references/esm-c-api.md create mode 100644 scientific-packages/esm/references/esm3-api.md create mode 100644 scientific-packages/esm/references/forge-api.md create mode 100644 scientific-packages/esm/references/workflows.md create mode 100644 scientific-packages/scvi-tools/references/theoretical-foundations.md diff --git a/scientific-packages/esm/SKILL.md b/scientific-packages/esm/SKILL.md new file mode 100644 index 0000000..3fbc606 --- /dev/null +++ b/scientific-packages/esm/SKILL.md @@ -0,0 +1,300 @@ +--- +name: esm +description: Comprehensive toolkit for protein language models including ESM3 (generative multimodal protein design across sequence, structure, and function) and ESM C (efficient protein embeddings and representations). Use this skill when working with protein sequences, structures, or function prediction; designing novel proteins; generating protein embeddings; performing inverse folding; or conducting protein engineering tasks. Supports both local model usage and cloud-based Forge API for scalable inference. +--- + +# ESM: Evolutionary Scale Modeling + +## Overview + +ESM provides state-of-the-art protein language models for understanding, generating, and designing proteins. This skill enables working with two model families: ESM3 for generative protein design across sequence, structure, and function, and ESM C for efficient protein representation learning and embeddings. + +## Core Capabilities + +### 1. Protein Sequence Generation with ESM3 + +Generate novel protein sequences with desired properties using multimodal generative modeling. + +**When to use:** +- Designing proteins with specific functional properties +- Completing partial protein sequences +- Generating variants of existing proteins +- Creating proteins with desired structural characteristics + +**Basic usage:** + +```python +from esm.models.esm3 import ESM3 +from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig + +# Load model locally +model: ESM3InferenceClient = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") + +# Create protein prompt +protein = ESMProtein(sequence="MPRT___KEND") # '_' represents masked positions + +# Generate completion +protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8)) +print(protein.sequence) +``` + +**For remote/cloud usage via Forge API:** + +```python +from esm.sdk.forge import ESM3ForgeInferenceClient +from esm.sdk.api import ESMProtein, GenerationConfig + +# Connect to Forge +model = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", url="https://forge.evolutionaryscale.ai", token="") + +# Generate +protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8)) +``` + +See `references/esm3-api.md` for detailed ESM3 model specifications, advanced generation configurations, and multimodal prompting examples. + +### 2. 
Structure Prediction and Inverse Folding + +Use ESM3's structure track for structure prediction from sequence or inverse folding (sequence design from structure). + +**Structure prediction:** + +```python +from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig + +# Predict structure from sequence +protein = ESMProtein(sequence="MPRTKEINDAGLIVHSP...") +protein_with_structure = model.generate( + protein, + GenerationConfig(track="structure", num_steps=protein.sequence.count("_")) +) + +# Access predicted structure +coordinates = protein_with_structure.coordinates # 3D coordinates +pdb_string = protein_with_structure.to_pdb() +``` + +**Inverse folding (sequence from structure):** + +```python +# Design sequence for a target structure +protein_with_structure = ESMProtein.from_pdb("target_structure.pdb") +protein_with_structure.sequence = None # Remove sequence + +# Generate sequence that folds to this structure +designed_protein = model.generate( + protein_with_structure, + GenerationConfig(track="sequence", num_steps=50, temperature=0.7) +) +``` + +### 3. Protein Embeddings with ESM C + +Generate high-quality embeddings for downstream tasks like function prediction, classification, or similarity analysis. + +**When to use:** +- Extracting protein representations for machine learning +- Computing sequence similarities +- Feature extraction for protein classification +- Transfer learning for protein-related tasks + +**Basic usage:** + +```python +from esm.models.esmc import ESMC +from esm.sdk.api import ESMProtein + +# Load ESM C model +model = ESMC.from_pretrained("esmc-300m").to("cuda") + +# Get embeddings +protein = ESMProtein(sequence="MPRTKEINDAGLIVHSP...") +protein_tensor = model.encode(protein) + +# Generate embeddings +embeddings = model.forward(protein_tensor) +``` + +**Batch processing:** + +```python +# Encode multiple proteins +proteins = [ + ESMProtein(sequence="MPRTKEIND..."), + ESMProtein(sequence="AGLIVHSPQ..."), + ESMProtein(sequence="KTEFLNDGR...") +] + +embeddings_list = [model.logits(model.forward(model.encode(p))) for p in proteins] +``` + +See `references/esm-c-api.md` for ESM C model details, efficiency comparisons, and advanced embedding strategies. + +### 4. Function Conditioning and Annotation + +Use ESM3's function track to generate proteins with specific functional annotations or predict function from sequence. + +**Function-conditioned generation:** + +```python +from esm.sdk.api import ESMProtein, FunctionAnnotation, GenerationConfig + +# Create protein with desired function +protein = ESMProtein( + sequence="_" * 200, # Generate 200 residue protein + function_annotations=[ + FunctionAnnotation(label="fluorescent_protein", start=50, end=150) + ] +) + +# Generate sequence with specified function +functional_protein = model.generate( + protein, + GenerationConfig(track="sequence", num_steps=200) +) +``` + +### 5. Chain-of-Thought Generation + +Iteratively refine protein designs using ESM3's chain-of-thought generation approach. 
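+
+Each `generate` call below completes a single track; because the returned protein is passed back in as the next prompt, every later step is conditioned on everything generated so far.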
+ +```python +from esm.sdk.api import GenerationConfig + +# Multi-step refinement +protein = ESMProtein(sequence="MPRT" + "_" * 100 + "KEND") + +# Step 1: Generate initial structure +config = GenerationConfig(track="structure", num_steps=50) +protein = model.generate(protein, config) + +# Step 2: Refine sequence based on structure +config = GenerationConfig(track="sequence", num_steps=50, temperature=0.5) +protein = model.generate(protein, config) + +# Step 3: Predict function +config = GenerationConfig(track="function", num_steps=20) +protein = model.generate(protein, config) +``` + +### 6. Batch Processing with Forge API + +Process multiple proteins efficiently using Forge's async executor. + +```python +from esm.sdk.forge import ESM3ForgeInferenceClient +import asyncio + +client = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token="") + +# Async batch processing +async def batch_generate(proteins_list): + tasks = [ + client.async_generate(protein, GenerationConfig(track="sequence")) + for protein in proteins_list + ] + return await asyncio.gather(*tasks) + +# Execute +proteins = [ESMProtein(sequence=f"MPRT{'_' * 50}KEND") for _ in range(10)] +results = asyncio.run(batch_generate(proteins)) +``` + +See `references/forge-api.md` for detailed Forge API documentation, authentication, rate limits, and batch processing patterns. + +## Model Selection Guide + +**ESM3 Models (Generative):** +- `esm3-sm-open-v1` (1.4B) - Open weights, local usage, good for experimentation +- `esm3-medium-2024-08` (7B) - Best balance of quality and speed (Forge only) +- `esm3-large-2024-03` (98B) - Highest quality, slower (Forge only) + +**ESM C Models (Embeddings):** +- `esmc-300m` (30 layers) - Lightweight, fast inference +- `esmc-600m` (36 layers) - Balanced performance +- `esmc-6b` (80 layers) - Maximum representation quality + +**Selection criteria:** +- **Local development/testing:** Use `esm3-sm-open-v1` or `esmc-300m` +- **Production quality:** Use `esm3-medium-2024-08` via Forge +- **Maximum accuracy:** Use `esm3-large-2024-03` or `esmc-6b` +- **High throughput:** Use Forge API with batch executor +- **Cost optimization:** Use smaller models, implement caching strategies + +## Installation + +**Basic installation:** + +```bash +pip install esm +``` + +**With Flash Attention (recommended for faster inference):** + +```bash +pip install esm +pip install flash-attn --no-build-isolation +``` + +**For Forge API access:** + +```bash +pip install esm # SDK includes Forge client +``` + +No additional dependencies needed. Obtain Forge API token at https://forge.evolutionaryscale.ai + +## Common Workflows + +For detailed examples and complete workflows, see `references/workflows.md` which includes: +- Novel GFP design with chain-of-thought +- Protein variant generation and screening +- Structure-based sequence optimization +- Function prediction pipelines +- Embedding-based clustering and analysis + +## References + +This skill includes comprehensive reference documentation: + +- `references/esm3-api.md` - ESM3 model architecture, API reference, generation parameters, and multimodal prompting +- `references/esm-c-api.md` - ESM C model details, embedding strategies, and performance optimization +- `references/forge-api.md` - Forge platform documentation, authentication, batch processing, and deployment +- `references/workflows.md` - Complete examples and common workflow patterns + +These references contain detailed API specifications, parameter descriptions, and advanced usage patterns. 
Load them as needed for specific tasks. + +## Best Practices + +**For generation tasks:** +- Start with smaller models for prototyping (`esm3-sm-open-v1`) +- Use temperature parameter to control diversity (0.0 = deterministic, 1.0 = diverse) +- Implement iterative refinement with chain-of-thought for complex designs +- Validate generated sequences with structure prediction or wet-lab experiments + +**For embedding tasks:** +- Batch process sequences when possible for efficiency +- Cache embeddings for repeated analyses +- Normalize embeddings when computing similarities +- Use appropriate model size based on downstream task requirements + +**For production deployment:** +- Use Forge API for scalability and latest models +- Implement error handling and retry logic for API calls +- Monitor token usage and implement rate limiting +- Consider AWS SageMaker deployment for dedicated infrastructure + +## Resources and Documentation + +- **GitHub Repository:** https://github.com/evolutionaryscale/esm +- **Forge Platform:** https://forge.evolutionaryscale.ai +- **Scientific Paper:** Hayes et al., Science (2025) - https://www.science.org/doi/10.1126/science.ads0018 +- **Blog Posts:** + - ESM3 Release: https://www.evolutionaryscale.ai/blog/esm3-release + - ESM C Launch: https://www.evolutionaryscale.ai/blog/esm-cambrian +- **Community:** Slack community at https://bit.ly/3FKwcWd +- **Model Weights:** HuggingFace EvolutionaryScale organization + +## Responsible Use + +ESM is designed for beneficial applications in protein engineering, drug discovery, and scientific research. Follow the Responsible Biodesign Framework (https://responsiblebiodesign.ai/) when designing novel proteins. Consider biosafety and ethical implications of protein designs before experimental validation. diff --git a/scientific-packages/esm/references/esm-c-api.md b/scientific-packages/esm/references/esm-c-api.md new file mode 100644 index 0000000..ff9aabc --- /dev/null +++ b/scientific-packages/esm/references/esm-c-api.md @@ -0,0 +1,583 @@ +# ESM C API Reference + +## Overview + +ESM C (Cambrian) is a family of protein language models optimized for representation learning and efficient embedding generation. Designed as a drop-in replacement for ESM2, ESM C provides significant improvements in speed and quality across all model sizes. + +## Model Architecture + +**ESM C Family Models:** + +| Model ID | Parameters | Layers | Best For | +|----------|-----------|--------|----------| +| `esmc-300m` | 300M | 30 | Fast inference, lightweight applications | +| `esmc-600m` | 600M | 36 | Balanced performance and quality | +| `esmc-6b` | 6B | 80 | Maximum representation quality | + +**Key Features:** +- 3x faster inference than ESM2 +- Improved perplexity and embedding quality +- Efficient architecture for production deployment +- Compatible with ESM2 workflows (drop-in replacement) +- Support for long sequences (up to 1024 residues efficiently) + +**Architecture Improvements over ESM2:** +- Optimized attention mechanisms +- Better token representation +- Enhanced training procedures +- Reduced memory footprint + +## Core API Components + +### ESMC Class + +Main interface for ESM C models. 
+ +**Model Loading:** + +```python +from esm.models.esmc import ESMC +from esm.sdk.api import ESMProtein + +# Load model with automatic device placement +model = ESMC.from_pretrained("esmc-300m").to("cuda") + +# Or specify device explicitly +model = ESMC.from_pretrained("esmc-600m").to("cpu") + +# For maximum quality +model = ESMC.from_pretrained("esmc-6b").to("cuda") +``` + +**Model Selection Criteria:** + +- **esmc-300m**: Development, real-time applications, batch processing of many sequences +- **esmc-600m**: Production deployments, good quality/speed balance +- **esmc-6b**: Research, maximum accuracy for downstream tasks + +### Basic Embedding Generation + +**Single Sequence:** + +```python +from esm.models.esmc import ESMC +from esm.sdk.api import ESMProtein + +# Load model +model = ESMC.from_pretrained("esmc-600m").to("cuda") + +# Create protein +protein = ESMProtein(sequence="MPRTKEINDAGLIVHSPQWFYK") + +# Encode to tensor +protein_tensor = model.encode(protein) + +# Generate embeddings +embeddings = model.forward(protein_tensor) + +# Get logits (per-position predictions) +logits = model.logits(embeddings) + +print(f"Embedding shape: {embeddings.shape}") +print(f"Logits shape: {logits.shape}") +``` + +**Output Shapes:** + +For a sequence of length L: +- `embeddings.shape`: `(1, L, hidden_dim)` where hidden_dim depends on model + - esmc-300m: hidden_dim = 960 + - esmc-600m: hidden_dim = 1152 + - esmc-6b: hidden_dim = 2560 +- `logits.shape`: `(1, L, 64)` - per-position amino acid predictions + +### Batch Processing + +Process multiple sequences efficiently: + +```python +import torch + +# Multiple proteins +sequences = [ + "MPRTKEINDAGLIVHSP", + "AGKWFYLTQSNHERVPM", + "DEIFKRNAVWGSLTPQY" +] + +proteins = [ESMProtein(sequence=seq) for seq in sequences] + +# Encode all +protein_tensors = [model.encode(p) for p in proteins] + +# Process batch (if same length) +# For variable lengths, process individually or pad +embeddings_list = [] +for tensor in protein_tensors: + embedding = model.forward(tensor) + embeddings_list.append(embedding) + +print(f"Processed {len(embeddings_list)} proteins") +``` + +**Efficient Batching for Variable Lengths:** + +```python +def batch_encode_variable_length(model, sequences, max_batch_size=32): + """ + Efficiently batch encode sequences of variable length. + Groups by similar length for efficiency. + """ + # Sort by length + sorted_seqs = sorted(enumerate(sequences), key=lambda x: len(x[1])) + + results = [None] * len(sequences) + batch = [] + batch_indices = [] + + for idx, seq in sorted_seqs: + batch.append(seq) + batch_indices.append(idx) + + # Process batch when full or length changes significantly + if (len(batch) >= max_batch_size or + (len(batch) > 0 and abs(len(seq) - len(batch[0])) > 10)): + + # Process current batch + proteins = [ESMProtein(sequence=s) for s in batch] + embeddings = [model.forward(model.encode(p)) for p in proteins] + + # Store results + for i, emb in zip(batch_indices, embeddings): + results[i] = emb + + batch = [] + batch_indices = [] + + # Process remaining + if batch: + proteins = [ESMProtein(sequence=s) for s in batch] + embeddings = [model.forward(model.encode(p)) for p in proteins] + for i, emb in zip(batch_indices, embeddings): + results[i] = emb + + return results +``` + +## Common Use Cases + +### 1. 
Sequence Similarity Analysis + +Compute similarity between proteins using embeddings: + +```python +import torch +import torch.nn.functional as F + +def get_sequence_embedding(model, sequence): + """Get mean-pooled sequence embedding.""" + protein = ESMProtein(sequence=sequence) + tensor = model.encode(protein) + embedding = model.forward(tensor) + + # Mean pooling over sequence length + return embedding.mean(dim=1) + +# Get embeddings +seq1_emb = get_sequence_embedding(model, "MPRTKEINDAGLIVHSP") +seq2_emb = get_sequence_embedding(model, "MPRTKEINDAGLIVHSQ") # Similar +seq3_emb = get_sequence_embedding(model, "WWWWWWWWWWWWWWWWW") # Different + +# Compute cosine similarity +sim_1_2 = F.cosine_similarity(seq1_emb, seq2_emb) +sim_1_3 = F.cosine_similarity(seq1_emb, seq3_emb) + +print(f"Similarity (1,2): {sim_1_2.item():.4f}") +print(f"Similarity (1,3): {sim_1_3.item():.4f}") +``` + +### 2. Protein Classification + +Use embeddings as features for classification: + +```python +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +# Generate embeddings for training set +def embed_dataset(model, sequences): + embeddings = [] + for seq in sequences: + protein = ESMProtein(sequence=seq) + tensor = model.encode(protein) + emb = model.forward(tensor).mean(dim=1) # Mean pooling + embeddings.append(emb.cpu().detach().numpy().flatten()) + return np.array(embeddings) + +# Example: Classify proteins by function +train_sequences = [...] # Your sequences +train_labels = [...] # Your labels + +embeddings = embed_dataset(model, train_sequences) + +# Train classifier +X_train, X_test, y_train, y_test = train_test_split( + embeddings, train_labels, test_size=0.2 +) + +classifier = LogisticRegression(max_iter=1000) +classifier.fit(X_train, y_train) + +# Evaluate +accuracy = classifier.score(X_test, y_test) +print(f"Classification accuracy: {accuracy:.4f}") +``` + +### 3. Protein Clustering + +Cluster proteins based on sequence similarity: + +```python +from sklearn.cluster import KMeans +import numpy as np + +# Generate embeddings +sequences = [...] # Your protein sequences +embeddings = embed_dataset(model, sequences) + +# Cluster +n_clusters = 5 +kmeans = KMeans(n_clusters=n_clusters, random_state=42) +cluster_labels = kmeans.fit_predict(embeddings) + +# Analyze clusters +for i in range(n_clusters): + cluster_seqs = [seq for seq, label in zip(sequences, cluster_labels) if label == i] + print(f"Cluster {i}: {len(cluster_seqs)} sequences") +``` + +### 4. 
Sequence Search and Retrieval + +Find similar sequences in a database: + +```python +import torch +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity + +def build_sequence_index(model, database_sequences): + """Build searchable index of sequence embeddings.""" + embeddings = [] + for seq in database_sequences: + emb = get_sequence_embedding(model, seq) + embeddings.append(emb.cpu().detach().numpy().flatten()) + return np.array(embeddings) + +def search_similar_sequences(model, query_seq, database_embeddings, + database_sequences, top_k=10): + """Find top-k most similar sequences.""" + query_emb = get_sequence_embedding(model, query_seq) + query_emb_np = query_emb.cpu().detach().numpy().flatten().reshape(1, -1) + + # Compute similarities + similarities = cosine_similarity(query_emb_np, database_embeddings)[0] + + # Get top-k + top_indices = np.argsort(similarities)[-top_k:][::-1] + + results = [ + (database_sequences[idx], similarities[idx]) + for idx in top_indices + ] + return results + +# Example usage +database_seqs = [...] # Large sequence database +index = build_sequence_index(model, database_seqs) + +query = "MPRTKEINDAGLIVHSP" +similar = search_similar_sequences(model, query, index, database_seqs, top_k=5) + +for seq, score in similar: + print(f"Score: {score:.4f} - {seq[:30]}...") +``` + +### 5. Feature Extraction for Downstream Models + +Use ESM C embeddings as input to custom neural networks: + +```python +import torch.nn as nn + +class ProteinPropertyPredictor(nn.Module): + """Example: Predict protein properties from ESM C embeddings.""" + + def __init__(self, embedding_dim, hidden_dim, output_dim): + super().__init__() + self.fc1 = nn.Linear(embedding_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.fc3 = nn.Linear(hidden_dim, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.3) + + def forward(self, embeddings): + # embeddings: (batch, seq_len, embedding_dim) + # Mean pool over sequence + x = embeddings.mean(dim=1) + + x = self.relu(self.fc1(x)) + x = self.dropout(x) + x = self.relu(self.fc2(x)) + x = self.dropout(x) + x = self.fc3(x) + return x + +# Use ESM C as frozen feature extractor +esm_model = ESMC.from_pretrained("esmc-600m").to("cuda") +esm_model.eval() # Freeze + +# Create task-specific model +predictor = ProteinPropertyPredictor( + embedding_dim=1152, # esmc-600m dimension + hidden_dim=512, + output_dim=1 # e.g., stability score +).to("cuda") + +# Training loop +for sequence, target in dataloader: + protein = ESMProtein(sequence=sequence) + with torch.no_grad(): + embeddings = esm_model.forward(esm_model.encode(protein)) + + prediction = predictor(embeddings) + loss = criterion(prediction, target) + # ... backprop through predictor only +``` + +### 6. 
Per-Residue Analysis + +Extract per-residue representations for detailed analysis: + +```python +def get_per_residue_embeddings(model, sequence): + """Get embedding for each residue.""" + protein = ESMProtein(sequence=sequence) + tensor = model.encode(protein) + embeddings = model.forward(tensor) + + # embeddings shape: (1, seq_len, hidden_dim) + return embeddings.squeeze(0) # (seq_len, hidden_dim) + +# Analyze specific positions +sequence = "MPRTKEINDAGLIVHSPQWFYK" +residue_embeddings = get_per_residue_embeddings(model, sequence) + +# Extract features for position 10 +position_10_features = residue_embeddings[10] +print(f"Features for residue {sequence[10]} at position 10:") +print(f"Shape: {position_10_features.shape}") + +# Compare residue representations +pos_5 = residue_embeddings[5] +pos_15 = residue_embeddings[15] +similarity = F.cosine_similarity(pos_5, pos_15, dim=0) +print(f"Residue similarity: {similarity.item():.4f}") +``` + +## Performance Optimization + +### Memory Management + +```python +import torch + +# Use half precision for memory efficiency +model = ESMC.from_pretrained("esmc-600m").to("cuda").half() + +# Process with mixed precision +with torch.cuda.amp.autocast(): + embeddings = model.forward(model.encode(protein)) + +# Clear cache between batches +torch.cuda.empty_cache() +``` + +### Batch Processing Best Practices + +```python +def efficient_batch_processing(model, sequences, batch_size=32): + """Process sequences in optimized batches.""" + results = [] + + for i in range(0, len(sequences), batch_size): + batch = sequences[i:i + batch_size] + + # Process batch + batch_embeddings = [] + for seq in batch: + protein = ESMProtein(sequence=seq) + emb = model.forward(model.encode(protein)) + batch_embeddings.append(emb) + + results.extend(batch_embeddings) + + # Periodically clear cache + if i % (batch_size * 10) == 0: + torch.cuda.empty_cache() + + return results +``` + +### Caching Embeddings + +```python +import pickle +import hashlib + +def get_cache_key(sequence): + """Generate cache key for sequence.""" + return hashlib.md5(sequence.encode()).hexdigest() + +class EmbeddingCache: + """Cache for protein embeddings.""" + + def __init__(self, cache_file="embeddings_cache.pkl"): + self.cache_file = cache_file + try: + with open(cache_file, 'rb') as f: + self.cache = pickle.load(f) + except FileNotFoundError: + self.cache = {} + + def get(self, sequence): + key = get_cache_key(sequence) + return self.cache.get(key) + + def set(self, sequence, embedding): + key = get_cache_key(sequence) + self.cache[key] = embedding + + def save(self): + with open(self.cache_file, 'wb') as f: + pickle.dump(self.cache, f) + +# Usage +cache = EmbeddingCache() + +def get_embedding_cached(model, sequence): + cached = cache.get(sequence) + if cached is not None: + return cached + + # Compute + protein = ESMProtein(sequence=sequence) + embedding = model.forward(model.encode(protein)) + cache.set(sequence, embedding) + + return embedding + +# Don't forget to save cache +cache.save() +``` + +## Comparison with ESM2 + +**Performance Improvements:** + +| Metric | ESM2-650M | ESM C-600M | Improvement | +|--------|-----------|------------|-------------| +| Inference Speed | 1.0x | 3.0x | 3x faster | +| Perplexity | Higher | Lower | Better | +| Memory Usage | 1.0x | 0.8x | 20% less | +| Embedding Quality | Baseline | Improved | +5-10% | + +**Migration from ESM2:** + +ESM C is designed as a drop-in replacement: + +```python +# Old ESM2 code +from esm import pretrained +model, alphabet = 
pretrained.esm2_t33_650M_UR50D() + +# New ESM C code (similar API) +from esm.models.esmc import ESMC +model = ESMC.from_pretrained("esmc-600m") +``` + +Key differences: +- Faster inference with same or better quality +- Simplified API through ESMProtein +- Better support for long sequences +- More efficient memory usage + +## Advanced Topics + +### Fine-tuning ESM C + +ESM C can be fine-tuned for specific tasks: + +```python +import torch.optim as optim + +# Load model +model = ESMC.from_pretrained("esmc-300m").to("cuda") + +# Unfreeze for fine-tuning +for param in model.parameters(): + param.requires_grad = True + +# Define optimizer +optimizer = optim.Adam(model.parameters(), lr=1e-5) + +# Training loop +for epoch in range(num_epochs): + for sequences, labels in dataloader: + optimizer.zero_grad() + + # Forward pass + proteins = [ESMProtein(sequence=seq) for seq in sequences] + embeddings = [model.forward(model.encode(p)) for p in proteins] + + # Your task-specific loss + loss = compute_loss(embeddings, labels) + + loss.backward() + optimizer.step() +``` + +### Attention Visualization + +Extract attention weights for interpretability: + +```python +def get_attention_weights(model, sequence): + """Extract attention weights from model.""" + protein = ESMProtein(sequence=sequence) + tensor = model.encode(protein) + + # Forward with attention output + output = model.forward(tensor, output_attentions=True) + + return output.attentions # List of attention tensors per layer + +# Visualize attention +attentions = get_attention_weights(model, "MPRTKEINDAGLIVHSP") +# Process and visualize attention patterns +``` + +## Citation + +If using ESM C in research, cite: + +``` +ESM Cambrian: https://www.evolutionaryscale.ai/blog/esm-cambrian +EvolutionaryScale (2024) +``` + +## Additional Resources + +- ESM C blog post: https://www.evolutionaryscale.ai/blog/esm-cambrian +- Model weights: HuggingFace EvolutionaryScale organization +- Comparison benchmarks: See blog post for detailed performance comparisons diff --git a/scientific-packages/esm/references/esm3-api.md b/scientific-packages/esm/references/esm3-api.md new file mode 100644 index 0000000..b979942 --- /dev/null +++ b/scientific-packages/esm/references/esm3-api.md @@ -0,0 +1,452 @@ +# ESM3 API Reference + +## Overview + +ESM3 is a frontier multimodal generative language model that reasons over the sequence, structure, and function of proteins. It uses iterative masked language modeling to simultaneously generate across these three modalities. + +## Model Architecture + +**ESM3 Family Models:** + +| Model ID | Parameters | Availability | Best For | +|----------|-----------|--------------|----------| +| `esm3-sm-open-v1` | 1.4B | Open weights (local) | Development, testing, learning | +| `esm3-medium-2024-08` | 7B | Forge API only | Production, balanced quality/speed | +| `esm3-large-2024-03` | 98B | Forge API only | Maximum quality, research | +| `esm3-medium-multimer-2024-09` | 7B | Forge API only | Protein complexes (experimental) | + +**Key Features:** +- Simultaneous reasoning across sequence, structure, and function +- Iterative generation with controllable number of steps +- Support for partial prompting across modalities +- Chain-of-thought generation for complex designs +- Temperature control for generation diversity + +## Core API Components + +### ESMProtein Class + +The central data structure representing a protein with optional sequence, structure, and function information. 
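+
+All fields are optional, which is what enables partial prompting: a prompt may carry any subset of sequence, coordinates, secondary structure, SASA, and function annotations.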
+ +**Constructor:** + +```python +from esm.sdk.api import ESMProtein + +protein = ESMProtein( + sequence="MPRTKEINDAGLIVHSP", # Amino acid sequence (optional) + coordinates=coordinates_array, # 3D structure (optional) + function_annotations=[...], # Function labels (optional) + secondary_structure="HHHEEEECCC", # SS annotations (optional) + sasa=sasa_array # Solvent accessibility (optional) +) +``` + +**Key Methods:** + +```python +# Load from PDB file +protein = ESMProtein.from_pdb("protein.pdb") + +# Export to PDB format +pdb_string = protein.to_pdb() + +# Save to file +with open("output.pdb", "w") as f: + f.write(protein.to_pdb()) +``` + +**Masking Conventions:** + +Use `_` (underscore) to represent masked positions for generation: + +```python +# Mask positions 5-10 for generation +protein = ESMProtein(sequence="MPRT______AGLIVHSP") + +# Fully masked sequence (generate from scratch) +protein = ESMProtein(sequence="_" * 200) + +# Partial structure (some coordinates None) +protein = ESMProtein( + sequence="MPRTKEIND", + coordinates=partial_coords # Some positions can be None +) +``` + +### GenerationConfig Class + +Controls generation behavior and parameters. + +**Basic Configuration:** + +```python +from esm.sdk.api import GenerationConfig + +config = GenerationConfig( + track="sequence", # Track to generate: "sequence", "structure", or "function" + num_steps=8, # Number of demasking steps + temperature=0.7, # Sampling temperature (0.0-1.0) + top_p=None, # Nucleus sampling threshold + condition_on_coordinates_only=False # For structure conditioning +) +``` + +**Parameter Details:** + +- **track**: Which modality to generate + - `"sequence"`: Generate amino acid sequence + - `"structure"`: Generate 3D coordinates + - `"function"`: Generate function annotations + +- **num_steps**: Number of iterative demasking steps + - Higher = slower but potentially better quality + - Typical range: 8-100 depending on sequence length + - For full sequence generation: approximately sequence_length / 2 + +- **temperature**: Controls randomness + - 0.0: Fully deterministic (greedy decoding) + - 0.5-0.7: Balanced exploration + - 1.0: Maximum diversity + - Higher values increase novelty but may reduce quality + +- **top_p**: Nucleus sampling parameter + - Limits sampling to top probability mass + - Values: 0.0-1.0 (e.g., 0.9 = sample from top 90% probability mass) + - Use for controlled diversity without extreme sampling + +- **condition_on_coordinates_only**: Structure conditioning mode + - `True`: Condition only on backbone coordinates (ignore sequence) + - Useful for inverse folding tasks + +### ESM3InferenceClient Interface + +The unified interface for both local and remote inference. + +**Local Model Loading:** + +```python +from esm.models.esm3 import ESM3 + +# Load with automatic device placement +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") + +# Or explicitly specify device +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cpu") +``` + +**Generation Method:** + +```python +# Basic generation +protein_output = model.generate(protein_input, config) + +# With explicit track specification +protein_output = model.generate( + protein_input, + GenerationConfig(track="sequence", num_steps=16, temperature=0.6) +) +``` + +**Forward Pass (Advanced):** + +```python +# Get raw model logits for custom sampling +protein_tensor = model.encode(protein) +output = model.forward(protein_tensor) +logits = model.decode(output) +``` + +## Common Usage Patterns + +### 1. 
Sequence Completion + +Fill in masked regions of a protein sequence: + +```python +# Define partial sequence +protein = ESMProtein(sequence="MPRTK____LIVHSP____END") + +# Generate missing positions +config = GenerationConfig(track="sequence", num_steps=12, temperature=0.5) +completed = model.generate(protein, config) + +print(f"Original: {protein.sequence}") +print(f"Completed: {completed.sequence}") +``` + +### 2. Structure Prediction + +Predict 3D structure from sequence: + +```python +# Input: sequence only +protein = ESMProtein(sequence="MPRTKEINDAGLIVHSPQWFYK") + +# Generate structure +config = GenerationConfig(track="structure", num_steps=len(protein.sequence)) +protein_with_structure = model.generate(protein, config) + +# Save as PDB +with open("predicted_structure.pdb", "w") as f: + f.write(protein_with_structure.to_pdb()) +``` + +### 3. Inverse Folding + +Design sequence for a target structure: + +```python +# Load target structure +target = ESMProtein.from_pdb("target.pdb") + +# Remove sequence, keep structure +target.sequence = None + +# Generate sequence that folds to this structure +config = GenerationConfig( + track="sequence", + num_steps=50, + temperature=0.7, + condition_on_coordinates_only=True +) +designed = model.generate(target, config) + +print(f"Designed sequence: {designed.sequence}") +``` + +### 4. Function-Conditioned Generation + +Generate protein with specific function: + +```python +from esm.sdk.api import FunctionAnnotation + +# Specify desired function +protein = ESMProtein( + sequence="_" * 150, + function_annotations=[ + FunctionAnnotation( + label="enzymatic_activity", + start=30, + end=90 + ) + ] +) + +# Generate sequence with this function +config = GenerationConfig(track="sequence", num_steps=75, temperature=0.6) +functional_protein = model.generate(protein, config) +``` + +### 5. Multi-Track Generation (Chain-of-Thought) + +Iteratively generate across multiple tracks: + +```python +# Start with partial sequence +protein = ESMProtein(sequence="MPRT" + "_" * 100) + +# Step 1: Complete sequence +protein = model.generate( + protein, + GenerationConfig(track="sequence", num_steps=50, temperature=0.6) +) + +# Step 2: Predict structure for completed sequence +protein = model.generate( + protein, + GenerationConfig(track="structure", num_steps=50) +) + +# Step 3: Predict function +protein = model.generate( + protein, + GenerationConfig(track="function", num_steps=20) +) + +print(f"Final sequence: {protein.sequence}") +print(f"Functions: {protein.function_annotations}") +``` + +### 6. 
Variant Generation + +Generate multiple variants of a protein: + +```python +import numpy as np + +base_sequence = "MPRTKEINDAGLIVHSPQWFYK" +variants = [] + +for i in range(10): + # Mask random positions + seq_list = list(base_sequence) + mask_indices = np.random.choice(len(seq_list), size=5, replace=False) + for idx in mask_indices: + seq_list[idx] = '_' + + protein = ESMProtein(sequence=''.join(seq_list)) + + # Generate variant + variant = model.generate( + protein, + GenerationConfig(track="sequence", num_steps=8, temperature=0.8) + ) + variants.append(variant.sequence) + +print(f"Generated {len(variants)} variants") +``` + +## Advanced Topics + +### Temperature Scheduling + +Vary temperature during generation for better control: + +```python +def generate_with_temperature_schedule(model, protein, temperatures): + """Generate with decreasing temperature for annealing.""" + current = protein + steps_per_temp = 10 + + for temp in temperatures: + config = GenerationConfig( + track="sequence", + num_steps=steps_per_temp, + temperature=temp + ) + current = model.generate(current, config) + + return current + +# Example: Start diverse, end deterministic +result = generate_with_temperature_schedule( + model, + protein, + temperatures=[1.0, 0.8, 0.6, 0.4, 0.2] +) +``` + +### Constrained Generation + +Preserve specific regions during generation: + +```python +# Keep active site residues fixed +def mask_except_active_site(sequence, active_site_positions): + """Mask everything except specified positions.""" + seq_list = ['_'] * len(sequence) + for pos in active_site_positions: + seq_list[pos] = sequence[pos] + return ''.join(seq_list) + +# Define active site +active_site = [23, 24, 25, 45, 46, 89] +constrained_seq = mask_except_active_site(original_sequence, active_site) + +protein = ESMProtein(sequence=constrained_seq) +result = model.generate(protein, GenerationConfig(track="sequence", num_steps=50)) +``` + +### Secondary Structure Conditioning + +Use secondary structure information in generation: + +```python +# Define secondary structure (H=helix, E=sheet, C=coil) +protein = ESMProtein( + sequence="_" * 80, + secondary_structure="CCHHHHHHHEEEEECCCHHHHHHCC" + "C" * 55 +) + +# Generate sequence with this structure +result = model.generate( + protein, + GenerationConfig(track="sequence", num_steps=40, temperature=0.6) +) +``` + +## Performance Optimization + +### Memory Management + +For large proteins or batch processing: + +```python +import torch + +# Clear CUDA cache between generations +torch.cuda.empty_cache() + +# Use half precision for memory efficiency +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda").half() + +# Process in chunks for very long sequences +def chunk_generate(model, long_sequence, chunk_size=500): + chunks = [long_sequence[i:i+chunk_size] + for i in range(0, len(long_sequence), chunk_size)] + results = [] + + for chunk in chunks: + protein = ESMProtein(sequence=chunk) + result = model.generate(protein, GenerationConfig(track="sequence")) + results.append(result.sequence) + + return ''.join(results) +``` + +### Batch Processing Tips + +When processing multiple proteins: + +1. Sort by sequence length for efficient batching +2. Use padding for similar-length sequences +3. Process on GPU when available +4. Implement checkpointing for long-running jobs +5. 
Use Forge API for large-scale processing (see `forge-api.md`) + +## Error Handling + +```python +try: + protein = model.generate(protein_input, config) +except ValueError as e: + print(f"Invalid input: {e}") + # Handle invalid sequence or structure +except RuntimeError as e: + print(f"Generation failed: {e}") + # Handle model errors +except torch.cuda.OutOfMemoryError: + print("GPU out of memory - try smaller model or CPU") + # Fallback to CPU or smaller model +``` + +## Model-Specific Considerations + +**esm3-sm-open-v1:** +- Suitable for development and testing +- Lower quality than larger models +- Fast inference on consumer GPUs +- Open weights allow fine-tuning + +**esm3-medium-2024-08:** +- Production quality +- Good balance of speed and accuracy +- Requires Forge API access +- Recommended for most applications + +**esm3-large-2024-03:** +- State-of-the-art quality +- Slowest inference +- Use for critical applications +- Best for novel protein design + +## Citation + +If using ESM3 in research, cite: + +``` +Hayes, T. et al. (2025). Simulating 500 million years of evolution with a language model. +Science. DOI: 10.1126/science.ads0018 +``` diff --git a/scientific-packages/esm/references/forge-api.md b/scientific-packages/esm/references/forge-api.md new file mode 100644 index 0000000..ba4e0ca --- /dev/null +++ b/scientific-packages/esm/references/forge-api.md @@ -0,0 +1,657 @@ +# Forge API Reference + +## Overview + +Forge is EvolutionaryScale's cloud platform for scalable protein design and inference. It provides API access to the full ESM3 model family, including large models not available for local execution. + +**Key Benefits:** +- Access to all ESM3 models including 98B parameter version +- No local GPU requirements +- Scalable batch processing +- Automatic updates to latest models +- Production-ready infrastructure +- Async/concurrent request support + +## Getting Started + +### 1. Obtain API Token + +Sign up and get your API token at: https://forge.evolutionaryscale.ai + +### 2. Install ESM SDK + +```bash +pip install esm +``` + +The Forge client is included in the standard ESM package. + +### 3. 
Basic Connection + +```python +from esm.sdk.forge import ESM3ForgeInferenceClient +from esm.sdk.api import ESMProtein, GenerationConfig + +# Initialize client +client = ESM3ForgeInferenceClient( + model="esm3-medium-2024-08", + url="https://forge.evolutionaryscale.ai", + token="" +) + +# Test connection +protein = ESMProtein(sequence="MPRT___KEND") +result = client.generate(protein, GenerationConfig(track="sequence", num_steps=8)) +print(result.sequence) +``` + +## Available Models + +| Model ID | Parameters | Speed | Quality | Use Case | +|----------|-----------|-------|---------|----------| +| `esm3-small-2024-08` | 1.4B | Fastest | Good | Rapid prototyping, testing | +| `esm3-medium-2024-08` | 7B | Fast | Excellent | Production, most applications | +| `esm3-large-2024-03` | 98B | Slower | Best | Research, critical designs | +| `esm3-medium-multimer-2024-09` | 7B | Fast | Experimental | Protein complexes | + +**Model Selection Guidelines:** + +- **Development/Testing**: Use `esm3-small-2024-08` for quick iteration +- **Production**: Use `esm3-medium-2024-08` for best balance +- **Research/Critical**: Use `esm3-large-2024-03` for highest quality +- **Complexes**: Use `esm3-medium-multimer-2024-09` (experimental) + +## ESM3ForgeInferenceClient API + +### Initialization + +```python +from esm.sdk.forge import ESM3ForgeInferenceClient + +# Basic initialization +client = ESM3ForgeInferenceClient( + model="esm3-medium-2024-08", + token="" +) + +# With custom URL (for enterprise deployments) +client = ESM3ForgeInferenceClient( + model="esm3-medium-2024-08", + url="https://custom.forge.instance.com", + token="" +) + +# With timeout configuration +client = ESM3ForgeInferenceClient( + model="esm3-medium-2024-08", + token="", + timeout=300 # 5 minutes +) +``` + +### Synchronous Generation + +Standard blocking generation calls: + +```python +from esm.sdk.api import ESMProtein, GenerationConfig + +# Basic generation +protein = ESMProtein(sequence="MPRT___KEND") +config = GenerationConfig(track="sequence", num_steps=8) + +result = client.generate(protein, config) +print(f"Generated: {result.sequence}") +``` + +### Asynchronous Generation + +For concurrent processing of multiple proteins: + +```python +import asyncio +from esm.sdk.api import ESMProtein, GenerationConfig + +async def generate_many(client, proteins): + """Generate multiple proteins concurrently.""" + tasks = [] + + for protein in proteins: + task = client.async_generate( + protein, + GenerationConfig(track="sequence", num_steps=8) + ) + tasks.append(task) + + results = await asyncio.gather(*tasks) + return results + +# Usage +proteins = [ + ESMProtein(sequence=f"MPRT{'_' * 10}KEND"), + ESMProtein(sequence=f"AGLV{'_' * 10}HSPQ"), + ESMProtein(sequence=f"KEIT{'_' * 10}NDFL") +] + +results = asyncio.run(generate_many(client, proteins)) +print(f"Generated {len(results)} proteins") +``` + +### Batch Processing with BatchExecutor + +For large-scale processing with automatic concurrency management: + +```python +from esm.sdk.forge import BatchExecutor +from esm.sdk.api import ESMProtein, GenerationConfig + +# Create batch executor +executor = BatchExecutor( + client=client, + max_concurrent=10 # Process 10 requests concurrently +) + +# Prepare batch of proteins +proteins = [ESMProtein(sequence=f"MPRT{'_' * 50}KEND") for _ in range(100)] +config = GenerationConfig(track="sequence", num_steps=25) + +# Submit batch +batch_results = executor.submit_batch( + proteins=proteins, + config=config, + progress_callback=lambda i, total: print(f"Processed 
{i}/{total}") +) + +print(f"Completed {len(batch_results)} generations") +``` + +## Rate Limiting and Quotas + +### Understanding Limits + +Forge implements rate limiting based on: +- Requests per minute (RPM) +- Tokens per minute (TPM) +- Concurrent requests + +**Typical Limits (subject to change):** +- Free tier: 60 RPM, 5 concurrent +- Pro tier: 300 RPM, 20 concurrent +- Enterprise: Custom limits + +### Handling Rate Limits + +```python +import time +from requests.exceptions import HTTPError + +def generate_with_retry(client, protein, config, max_retries=3): + """Generate with automatic retry on rate limit.""" + for attempt in range(max_retries): + try: + return client.generate(protein, config) + except HTTPError as e: + if e.response.status_code == 429: # Rate limit + wait_time = 2 ** attempt # Exponential backoff + print(f"Rate limited, waiting {wait_time}s...") + time.sleep(wait_time) + else: + raise + raise Exception("Max retries exceeded") + +# Usage +result = generate_with_retry(client, protein, config) +``` + +### Implementing Custom Rate Limiter + +```python +import time +from collections import deque + +class RateLimiter: + """Simple rate limiter for API calls.""" + + def __init__(self, max_per_minute=60): + self.max_per_minute = max_per_minute + self.calls = deque() + + def wait_if_needed(self): + """Wait if rate limit would be exceeded.""" + now = time.time() + + # Remove old calls + while self.calls and self.calls[0] < now - 60: + self.calls.popleft() + + # Wait if at limit + if len(self.calls) >= self.max_per_minute: + sleep_time = 60 - (now - self.calls[0]) + if sleep_time > 0: + time.sleep(sleep_time) + self.calls.popleft() + + self.calls.append(now) + +# Usage +limiter = RateLimiter(max_per_minute=60) + +for protein in proteins: + limiter.wait_if_needed() + result = client.generate(protein, config) +``` + +## Advanced Patterns + +### Streaming Results + +Process results as they complete: + +```python +import asyncio +from concurrent.futures import ThreadPoolExecutor + +async def stream_generate(client, proteins, config): + """Stream results as they complete.""" + pending = { + asyncio.create_task(client.async_generate(p, config)): i + for i, p in enumerate(proteins) + } + + results = [None] * len(proteins) + + while pending: + done, pending = await asyncio.wait( + pending.keys(), + return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + idx = pending.pop(task) + result = await task + results[idx] = result + yield idx, result + +# Usage +async def process_stream(): + async for idx, result in stream_generate(client, proteins, config): + print(f"Completed protein {idx}: {result.sequence[:20]}...") + +asyncio.run(process_stream()) +``` + +### Batch with Progress Tracking + +```python +from tqdm import tqdm +import asyncio + +async def batch_with_progress(client, proteins, config): + """Process batch with progress bar.""" + results = [] + + with tqdm(total=len(proteins)) as pbar: + for protein in proteins: + result = await client.async_generate(protein, config) + results.append(result) + pbar.update(1) + + return results + +# Usage +results = asyncio.run(batch_with_progress(client, proteins, config)) +``` + +### Checkpoint and Resume + +For long-running batch jobs: + +```python +import pickle +import os + +class CheckpointedBatchProcessor: + """Batch processor with checkpoint/resume capability.""" + + def __init__(self, client, checkpoint_file="checkpoint.pkl"): + self.client = client + self.checkpoint_file = checkpoint_file + self.completed = self.load_checkpoint() + 
+ def load_checkpoint(self): + if os.path.exists(self.checkpoint_file): + with open(self.checkpoint_file, 'rb') as f: + return pickle.load(f) + return {} + + def save_checkpoint(self): + with open(self.checkpoint_file, 'wb') as f: + pickle.dump(self.completed, f) + + def process_batch(self, proteins, config): + """Process batch with checkpointing.""" + results = {} + + for i, protein in enumerate(proteins): + # Skip if already completed + if i in self.completed: + results[i] = self.completed[i] + continue + + try: + result = self.client.generate(protein, config) + results[i] = result + self.completed[i] = result + + # Save checkpoint every 10 items + if i % 10 == 0: + self.save_checkpoint() + + except Exception as e: + print(f"Error processing {i}: {e}") + self.save_checkpoint() + raise + + self.save_checkpoint() + return results + +# Usage +processor = CheckpointedBatchProcessor(client) +results = processor.process_batch(proteins, config) +``` + +## Error Handling + +### Common Errors and Solutions + +```python +from requests.exceptions import HTTPError, ConnectionError, Timeout + +def robust_generate(client, protein, config): + """Generate with comprehensive error handling.""" + try: + return client.generate(protein, config) + + except HTTPError as e: + if e.response.status_code == 401: + raise ValueError("Invalid API token") + elif e.response.status_code == 429: + raise ValueError("Rate limit exceeded - slow down requests") + elif e.response.status_code == 500: + raise ValueError("Server error - try again later") + else: + raise + + except ConnectionError: + raise ValueError("Network error - check internet connection") + + except Timeout: + raise ValueError("Request timeout - try smaller protein or increase timeout") + + except Exception as e: + raise ValueError(f"Unexpected error: {str(e)}") + +# Usage with retry logic +def generate_with_full_retry(client, protein, config, max_retries=3): + """Combine error handling with retry logic.""" + for attempt in range(max_retries): + try: + return robust_generate(client, protein, config) + except ValueError as e: + if "rate limit" in str(e).lower() and attempt < max_retries - 1: + time.sleep(2 ** attempt) + continue + raise +``` + +## Cost Optimization + +### Strategies to Reduce Costs + +**1. Use Appropriate Model Size:** + +```python +# Use smaller model for testing +dev_client = ESM3ForgeInferenceClient( + model="esm3-small-2024-08", + token=token +) + +# Use larger model only for final generation +prod_client = ESM3ForgeInferenceClient( + model="esm3-large-2024-03", + token=token +) +``` + +**2. 
Cache Results:** + +```python +import hashlib +import json + +class ForgeCache: + """Cache Forge API results locally.""" + + def __init__(self, cache_dir="forge_cache"): + self.cache_dir = cache_dir + os.makedirs(cache_dir, exist_ok=True) + + def get_cache_key(self, protein, config): + """Generate cache key from inputs.""" + data = { + 'sequence': protein.sequence, + 'config': str(config) + } + return hashlib.md5(json.dumps(data, sort_keys=True).encode()).hexdigest() + + def get(self, protein, config): + """Get cached result.""" + key = self.get_cache_key(protein, config) + path = os.path.join(self.cache_dir, f"{key}.pkl") + + if os.path.exists(path): + with open(path, 'rb') as f: + return pickle.load(f) + return None + + def set(self, protein, config, result): + """Cache result.""" + key = self.get_cache_key(protein, config) + path = os.path.join(self.cache_dir, f"{key}.pkl") + + with open(path, 'wb') as f: + pickle.dump(result, f) + +# Usage +cache = ForgeCache() + +def cached_generate(client, protein, config): + """Generate with caching.""" + cached = cache.get(protein, config) + if cached: + return cached + + result = client.generate(protein, config) + cache.set(protein, config, result) + return result +``` + +**3. Batch Similar Requests:** + +Group similar generation tasks to reduce overhead: + +```python +def batch_similar_tasks(proteins, max_batch_size=50): + """Group proteins by similar properties.""" + # Sort by length for efficient processing + sorted_proteins = sorted(proteins, key=lambda p: len(p.sequence)) + + batches = [] + current_batch = [] + + for protein in sorted_proteins: + current_batch.append(protein) + + if len(current_batch) >= max_batch_size: + batches.append(current_batch) + current_batch = [] + + if current_batch: + batches.append(current_batch) + + return batches +``` + +## Monitoring and Logging + +### Track API Usage + +```python +import logging +from datetime import datetime + +class ForgeMonitor: + """Monitor Forge API usage.""" + + def __init__(self): + self.calls = [] + self.errors = [] + + def log_call(self, model, protein_length, duration, success=True, error=None): + """Log API call.""" + entry = { + 'timestamp': datetime.now(), + 'model': model, + 'protein_length': protein_length, + 'duration': duration, + 'success': success, + 'error': str(error) if error else None + } + + if success: + self.calls.append(entry) + else: + self.errors.append(entry) + + def get_stats(self): + """Get usage statistics.""" + total_calls = len(self.calls) + len(self.errors) + success_rate = len(self.calls) / total_calls if total_calls > 0 else 0 + avg_duration = sum(c['duration'] for c in self.calls) / len(self.calls) if self.calls else 0 + + return { + 'total_calls': total_calls, + 'successful': len(self.calls), + 'failed': len(self.errors), + 'success_rate': success_rate, + 'avg_duration': avg_duration + } + +# Usage +monitor = ForgeMonitor() + +def monitored_generate(client, protein, config): + """Generate with monitoring.""" + start = time.time() + + try: + result = client.generate(protein, config) + duration = time.time() - start + monitor.log_call( + model=client.model, + protein_length=len(protein.sequence), + duration=duration, + success=True + ) + return result + + except Exception as e: + duration = time.time() - start + monitor.log_call( + model=client.model, + protein_length=len(protein.sequence), + duration=duration, + success=False, + error=e + ) + raise + +# Check stats +print(monitor.get_stats()) +``` + +## AWS SageMaker Deployment + +For dedicated 
infrastructure and enterprise use: + +### Deployment Options + +1. **AWS Marketplace Listing**: Deploy ESM3 via AWS SageMaker Marketplace +2. **Custom Endpoint**: Configure dedicated inference endpoint +3. **Batch Transform**: Use SageMaker Batch Transform for large-scale processing + +### Benefits + +- Dedicated compute resources +- No rate limiting beyond your infrastructure +- Data stays in your AWS environment +- Integration with AWS services +- Custom instance types and scaling + +**More Information:** +- AWS Marketplace: https://aws.amazon.com/marketplace/seller-profile?id=seller-iw2nbscescndm +- Contact EvolutionaryScale for enterprise licensing + +## Best Practices Summary + +1. **Authentication**: Store tokens securely (environment variables, secrets manager) +2. **Rate Limiting**: Implement exponential backoff and respect limits +3. **Error Handling**: Always handle network errors and retries +4. **Caching**: Cache results for repeated queries +5. **Model Selection**: Use appropriate model size for task +6. **Batch Processing**: Use async/batch processing for multiple proteins +7. **Monitoring**: Track usage and costs +8. **Checkpointing**: Save progress for long-running jobs + +## Troubleshooting + +### Connection Issues + +```python +# Test connection +try: + client = ESM3ForgeInferenceClient(model="esm3-medium-2024-08", token=token) + test_protein = ESMProtein(sequence="MPRTK") + result = client.generate(test_protein, GenerationConfig(track="sequence", num_steps=1)) + print("Connection successful!") +except Exception as e: + print(f"Connection failed: {e}") +``` + +### Token Validation + +```python +def validate_token(token): + """Validate API token.""" + try: + client = ESM3ForgeInferenceClient( + model="esm3-small-2024-08", + token=token + ) + # Make minimal test call + test = ESMProtein(sequence="MPR") + client.generate(test, GenerationConfig(track="sequence", num_steps=1)) + return True + except HTTPError as e: + if e.response.status_code == 401: + return False + raise +``` + +## Additional Resources + +- **Forge Platform**: https://forge.evolutionaryscale.ai +- **API Documentation**: Check Forge dashboard for latest API specs +- **Community Support**: Slack community at https://bit.ly/3FKwcWd +- **Enterprise Contact**: Contact EvolutionaryScale for custom deployments diff --git a/scientific-packages/esm/references/workflows.md b/scientific-packages/esm/references/workflows.md new file mode 100644 index 0000000..8c3716e --- /dev/null +++ b/scientific-packages/esm/references/workflows.md @@ -0,0 +1,685 @@ +# ESM Workflows and Examples + +## Overview + +This document provides complete, end-to-end examples of common workflows using ESM3 and ESM C. Each workflow includes setup, execution, and analysis code. + +## Workflow 1: Novel GFP Design with Chain-of-Thought + +Design a novel fluorescent protein using ESM3's multimodal generation capabilities. + +### Objective + +Generate a green fluorescent protein (GFP) with specific properties using chain-of-thought reasoning across sequence, structure, and function. 
+ +### Complete Implementation + +```python +from esm.models.esm3 import ESM3 +from esm.sdk.api import ESMProtein, GenerationConfig, FunctionAnnotation +import matplotlib.pyplot as plt + +# Setup +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") + +# Step 1: Define target properties +print("Step 1: Defining target GFP properties...") + +# Create protein with desired function +target_length = 238 # Typical GFP length +protein = ESMProtein( + sequence="_" * target_length, + function_annotations=[ + FunctionAnnotation( + label="green_fluorescent_protein", + start=65, + end=75 # Chromophore region + ) + ] +) + +# Step 2: Generate initial sequence with function conditioning +print("Step 2: Generating initial sequence...") + +config = GenerationConfig( + track="sequence", + num_steps=target_length // 3, # Gradual generation + temperature=0.7 # Moderate diversity +) +protein = model.generate(protein, config) +print(f"Generated sequence: {protein.sequence[:50]}...") + +# Step 3: Predict structure +print("Step 3: Predicting structure...") + +config = GenerationConfig( + track="structure", + num_steps=target_length // 2 +) +protein = model.generate(protein, config) +print(f"Structure predicted, coordinates shape: {protein.coordinates.shape}") + +# Step 4: Refine sequence based on structure +print("Step 4: Refining sequence based on structure...") + +# Mask regions for refinement (e.g., surface residues) +sequence_list = list(protein.sequence) +# Keep chromophore region, refine others +for i in range(0, 65): + if i % 3 == 0: # Refine every third position + sequence_list[i] = '_' +for i in range(75, target_length): + if i % 3 == 0: + sequence_list[i] = '_' + +protein.sequence = ''.join(sequence_list) + +config = GenerationConfig( + track="sequence", + num_steps=50, + temperature=0.5 # Lower temperature for refinement +) +protein = model.generate(protein, config) + +# Step 5: Final validation +print("Step 5: Final validation...") + +# Predict final structure +config = GenerationConfig(track="structure", num_steps=30) +protein = model.generate(protein, config) + +# Save results +with open("novel_gfp.pdb", "w") as f: + f.write(protein.to_pdb()) + +with open("novel_gfp_sequence.txt", "w") as f: + f.write(f">Novel_GFP\n{protein.sequence}\n") + +print(f"\nFinal GFP sequence:\n{protein.sequence}") +print(f"\nFunction annotations: {protein.function_annotations}") +print(f"Structure saved to: novel_gfp.pdb") +``` + +### Validation Steps + +```python +# Analyze designed GFP +def analyze_gfp(protein): + """Analyze generated GFP properties.""" + + # Check chromophore region (should be around Ser65-Tyr66-Gly67) + chromophore_region = protein.sequence[64:68] + print(f"Chromophore region: {chromophore_region}") + + # Check barrel structure (GFPs have beta-barrel) + # Analyze secondary structure if available + if protein.secondary_structure: + beta_content = protein.secondary_structure.count('E') / len(protein.sequence) + print(f"Beta sheet content: {beta_content:.2%}") + + # Check sequence similarity to known GFPs + # (Would require BLAST or alignment tool in practice) + + return { + 'length': len(protein.sequence), + 'chromophore': chromophore_region, + 'coordinates_available': protein.coordinates is not None + } + +analysis = analyze_gfp(protein) +print(f"\nAnalysis results: {analysis}") +``` + +## Workflow 2: Protein Variant Library Generation + +Generate and analyze a library of protein variants for directed evolution. 
+ +### Objective + +Create variants of a parent protein by targeted mutagenesis while maintaining structural integrity. + +### Complete Implementation + +```python +from esm.models.esm3 import ESM3 +from esm.sdk.api import ESMProtein, GenerationConfig +import numpy as np +from sklearn.cluster import KMeans + +# Setup +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") + +# Parent protein +parent_sequence = "MPRTKEINDAGLIVHSPQWFYKARNDTESLGKIVHEFPM" +parent_protein = ESMProtein(sequence=parent_sequence) + +# Define mutation parameters +num_variants = 50 +positions_to_mutate = 5 # Number of positions per variant + +# Step 1: Generate variant library +print("Generating variant library...") + +variants = [] +for i in range(num_variants): + # Create masked sequence with random positions + seq_list = list(parent_sequence) + + # Select random positions to mutate + mutation_positions = np.random.choice( + len(seq_list), + size=positions_to_mutate, + replace=False + ) + + for pos in mutation_positions: + seq_list[pos] = '_' + + # Generate variant + variant_protein = ESMProtein(sequence=''.join(seq_list)) + + config = GenerationConfig( + track="sequence", + num_steps=positions_to_mutate * 2, + temperature=0.8 # Higher diversity + ) + + variant = model.generate(variant_protein, config) + variants.append(variant.sequence) + + if (i + 1) % 10 == 0: + print(f"Generated {i + 1}/{num_variants} variants") + +print(f"\nGenerated {len(variants)} variants") + +# Step 2: Predict structures for variants +print("\nPredicting structures...") + +variant_proteins_with_structure = [] +for i, seq in enumerate(variants): + protein = ESMProtein(sequence=seq) + + config = GenerationConfig( + track="structure", + num_steps=len(seq) // 2 + ) + + protein_with_structure = model.generate(protein, config) + variant_proteins_with_structure.append(protein_with_structure) + + if (i + 1) % 10 == 0: + print(f"Predicted structures for {i + 1}/{len(variants)} variants") + +# Step 3: Analyze variant diversity +print("\nAnalyzing variant diversity...") + +# Calculate Hamming distances from parent +def hamming_distance(seq1, seq2): + """Calculate Hamming distance between sequences.""" + return sum(c1 != c2 for c1, c2 in zip(seq1, seq2)) + +distances = [hamming_distance(parent_sequence, var) for var in variants] +print(f"Average mutations per variant: {np.mean(distances):.1f}") +print(f"Mutation range: {min(distances)}-{max(distances)}") + +# Step 4: Get embeddings for clustering +print("\nGenerating embeddings for clustering...") + +from esm.models.esmc import ESMC + +embedding_model = ESMC.from_pretrained("esmc-300m").to("cuda") + +def get_embedding(sequence): + """Get mean-pooled embedding for sequence.""" + protein = ESMProtein(sequence=sequence) + tensor = embedding_model.encode(protein) + emb = embedding_model.forward(tensor) + return emb.mean(dim=1).cpu().detach().numpy().flatten() + +variant_embeddings = np.array([get_embedding(seq) for seq in variants]) + +# Step 5: Cluster variants +print("Clustering variants...") + +n_clusters = 5 +kmeans = KMeans(n_clusters=n_clusters, random_state=42) +cluster_labels = kmeans.fit_predict(variant_embeddings) + +# Analyze clusters +print("\nCluster analysis:") +for i in range(n_clusters): + cluster_variants = [var for var, label in zip(variants, cluster_labels) if label == i] + cluster_distances = [hamming_distance(parent_sequence, var) for var in cluster_variants] + + print(f"\nCluster {i}:") + print(f" Size: {len(cluster_variants)}") + print(f" Avg distance from parent: 
{np.mean(cluster_distances):.1f}") + print(f" Representative: {cluster_variants[0][:40]}...") + +# Step 6: Select diverse representatives +print("\nSelecting diverse representatives...") + +representatives = [] +for i in range(n_clusters): + # Get centroid + cluster_indices = np.where(cluster_labels == i)[0] + cluster_embs = variant_embeddings[cluster_indices] + + # Find closest to centroid + centroid = cluster_embs.mean(axis=0) + distances_to_centroid = np.linalg.norm(cluster_embs - centroid, axis=1) + rep_idx = cluster_indices[np.argmin(distances_to_centroid)] + + representatives.append(variants[rep_idx]) + +# Save results +print("\nSaving results...") + +with open("variant_library.fasta", "w") as f: + f.write(f">Parent\n{parent_sequence}\n\n") + for i, var in enumerate(variants): + f.write(f">Variant_{i+1}_Cluster_{cluster_labels[i]}\n{var}\n") + +with open("representative_variants.fasta", "w") as f: + for i, rep in enumerate(representatives): + f.write(f">Representative_Cluster_{i}\n{rep}\n") + +print("Variant library saved to: variant_library.fasta") +print("Representatives saved to: representative_variants.fasta") +``` + +## Workflow 3: Structure-Based Sequence Optimization + +Optimize a protein sequence to improve stability while maintaining function. + +### Objective + +Given a protein structure, design sequences that maintain the fold but have improved properties. + +### Complete Implementation + +```python +from esm.models.esm3 import ESM3 +from esm.sdk.api import ESMProtein, GenerationConfig +import numpy as np + +# Setup +model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") + +# Load target structure (e.g., from PDB) +target_protein = ESMProtein.from_pdb("target_structure.pdb") +original_sequence = target_protein.sequence + +print(f"Original sequence: {original_sequence}") +print(f"Structure loaded: {target_protein.coordinates.shape}") + +# Step 1: Generate multiple sequence designs +print("\nGenerating optimized sequences...") + +num_designs = 20 +optimized_sequences = [] + +for i in range(num_designs): + # Start with structure, remove sequence + design_protein = ESMProtein( + coordinates=target_protein.coordinates.copy(), + secondary_structure=target_protein.secondary_structure + ) + + # Generate sequence for this structure + config = GenerationConfig( + track="sequence", + num_steps=len(original_sequence), + temperature=0.7, + condition_on_coordinates_only=True + ) + + designed = model.generate(design_protein, config) + optimized_sequences.append(designed.sequence) + + if (i + 1) % 5 == 0: + print(f"Generated {i + 1}/{num_designs} designs") + +# Step 2: Validate structural compatibility +print("\nValidating structural compatibility...") + +validated_designs = [] + +for seq in optimized_sequences: + # Predict structure for designed sequence + test_protein = ESMProtein(sequence=seq) + + config = GenerationConfig( + track="structure", + num_steps=len(seq) // 2 + ) + + predicted = model.generate(test_protein, config) + + # Calculate RMSD (simplified - in practice use proper alignment) + # Here we just check if structure prediction succeeds + if predicted.coordinates is not None: + validated_designs.append(seq) + +print(f"Validated {len(validated_designs)}/{num_designs} designs") + +# Step 3: Analyze sequence properties +print("\nAnalyzing sequence properties...") + +def calculate_properties(sequence): + """Calculate basic sequence properties.""" + # Hydrophobicity (simplified) + hydrophobic = "AILMFWYV" + hydrophobic_fraction = sum(1 for aa in sequence if aa in 
hydrophobic) / len(sequence) + + # Charge + positive = "KR" + negative = "DE" + net_charge = sum(1 for aa in sequence if aa in positive) - sum(1 for aa in sequence if aa in negative) + + # Aromatic content + aromatic = "FWY" + aromatic_fraction = sum(1 for aa in sequence if aa in aromatic) / len(sequence) + + return { + 'hydrophobic_fraction': hydrophobic_fraction, + 'net_charge': net_charge, + 'aromatic_fraction': aromatic_fraction + } + +# Compare to original +original_props = calculate_properties(original_sequence) +print(f"\nOriginal properties:") +print(f" Hydrophobic: {original_props['hydrophobic_fraction']:.2%}") +print(f" Net charge: {original_props['net_charge']:+d}") +print(f" Aromatic: {original_props['aromatic_fraction']:.2%}") + +# Analyze designs +design_properties = [calculate_properties(seq) for seq in validated_designs] + +avg_hydrophobic = np.mean([p['hydrophobic_fraction'] for p in design_properties]) +avg_charge = np.mean([p['net_charge'] for p in design_properties]) +avg_aromatic = np.mean([p['aromatic_fraction'] for p in design_properties]) + +print(f"\nDesigned sequences (average):") +print(f" Hydrophobic: {avg_hydrophobic:.2%}") +print(f" Net charge: {avg_charge:+.1f}") +print(f" Aromatic: {avg_aromatic:.2%}") + +# Step 4: Rank designs +print("\nRanking designs...") + +def score_design(sequence, original_props): + """Score design based on desired properties.""" + props = calculate_properties(sequence) + + # Prefer higher hydrophobic content (for stability) + hydrophobic_score = props['hydrophobic_fraction'] + + # Prefer similar charge to original + charge_score = 1.0 / (1.0 + abs(props['net_charge'] - original_props['net_charge'])) + + # Combined score + return hydrophobic_score * 0.6 + charge_score * 0.4 + +scores = [(seq, score_design(seq, original_props)) for seq in validated_designs] +scores.sort(key=lambda x: x[1], reverse=True) + +print("\nTop 5 designs:") +for i, (seq, score) in enumerate(scores[:5]): + print(f"\n{i+1}. Score: {score:.3f}") + print(f" Sequence: {seq[:40]}...") + +# Step 5: Save results +print("\nSaving results...") + +with open("optimized_sequences.fasta", "w") as f: + f.write(f">Original\n{original_sequence}\n\n") + + for i, (seq, score) in enumerate(scores): + props = calculate_properties(seq) + f.write(f">Design_{i+1}_Score_{score:.3f}\n") + f.write(f"# Hydrophobic: {props['hydrophobic_fraction']:.2%}, ") + f.write(f"Charge: {props['net_charge']:+d}, ") + f.write(f"Aromatic: {props['aromatic_fraction']:.2%}\n") + f.write(f"{seq}\n\n") + +print("Results saved to: optimized_sequences.fasta") +``` + +## Workflow 4: Function Prediction Pipeline + +Predict protein function from sequence using ESM3 and ESM C. + +### Objective + +Build a pipeline that predicts protein function using both generative (ESM3) and embedding (ESM C) approaches. 
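The embedding-based branch of this pipeline needs a classifier trained on labeled embeddings, a step the implementation below skips for lack of training data. A hedged sketch of that training step is given here: the random feature matrix stands in for real mean-pooled ESM C embeddings, and the feature dimension and binary labels are illustrative placeholders only.

```python
# Sketch of training the `function_classifier` used by the embedding branch.
# Random placeholder features stand in for real mean-pooled ESM C embeddings;
# the feature dimension and binary labels are illustrative only.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 1152))   # placeholder embeddings (200 proteins)
y_train = rng.integers(0, 2, size=200)   # placeholder labels: enzyme vs. not

function_classifier = RandomForestClassifier(n_estimators=200, random_state=0)
cv_scores = cross_val_score(function_classifier, X_train, y_train, cv=5)
print(f"Cross-validated accuracy: {cv_scores.mean():.2f}")
function_classifier.fit(X_train, y_train)
```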
+ +### Complete Implementation + +```python +from esm.models.esm3 import ESM3 +from esm.models.esmc import ESMC +from esm.sdk.api import ESMProtein, GenerationConfig +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_val_score + +# Setup models +esm3_model = ESM3.from_pretrained("esm3-sm-open-v1").to("cuda") +esmc_model = ESMC.from_pretrained("esmc-600m").to("cuda") + +# Example: Predict if protein is an enzyme +# (In practice, you'd have a labeled training set) + +def predict_function_generative(sequence): + """Predict function using ESM3 generative approach.""" + + protein = ESMProtein(sequence=sequence) + + # Generate function annotations + config = GenerationConfig( + track="function", + num_steps=20, + temperature=0.3 # Low temperature for confident predictions + ) + + protein_with_function = esm3_model.generate(protein, config) + + return protein_with_function.function_annotations + +def predict_function_embedding(sequence, function_classifier): + """Predict function using ESM C embeddings + classifier.""" + + # Get embedding + protein = ESMProtein(sequence=sequence) + tensor = esmc_model.encode(protein) + embedding = esmc_model.forward(tensor) + + # Mean pool + embedding_pooled = embedding.mean(dim=1).cpu().detach().numpy() + + # Predict with classifier + prediction = function_classifier.predict(embedding_pooled) + probability = function_classifier.predict_proba(embedding_pooled) + + return prediction[0], probability[0] + +# Example workflow with test sequences +test_sequences = { + "kinase": "MPRTKEINDAGLIVHSPQWFYKARNDTESLGKIVHEF", + "protease": "AGLIVHSPQWFYKARNDTESLGKIVHEFPMCDEGH", + "transporter": "KTEFLNDGRPMLIVHSPQWFYKARNDTESLGKIVH" +} + +print("Predicting functions...\n") + +for name, sequence in test_sequences.items(): + print(f"{name.upper()}:") + print(f"Sequence: {sequence[:30]}...") + + # Method 1: Generative + functions = predict_function_generative(sequence) + print(f" Generative predictions: {functions}") + + # Method 2: Embedding-based would require trained classifier + # (Skipped in this example as it needs training data) + + print() +``` + +## Workflow 5: Embedding-Based Clustering and Analysis + +Cluster and analyze a large protein dataset using ESM C embeddings. + +### Complete Implementation + +```python +from esm.models.esmc import ESMC +from esm.sdk.api import ESMProtein +import numpy as np +from sklearn.cluster import DBSCAN +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt + +# Setup +model = ESMC.from_pretrained("esmc-600m").to("cuda") + +# Load protein dataset (example) +sequences = [ + # In practice, load from FASTA or database + "MPRTKEINDAGLIVHSPQWFYK", + "AGLIVHSPQWFYKARNDTESL", + # ... 
more sequences +] + +print(f"Loaded {len(sequences)} sequences") + +# Step 1: Generate embeddings +print("Generating embeddings...") + +embeddings = [] +for i, seq in enumerate(sequences): + protein = ESMProtein(sequence=seq) + tensor = model.encode(protein) + emb = model.forward(tensor) + + # Mean pooling + emb_pooled = emb.mean(dim=1).cpu().detach().numpy().flatten() + embeddings.append(emb_pooled) + + if (i + 1) % 100 == 0: + print(f"Processed {i + 1}/{len(sequences)}") + +embeddings = np.array(embeddings) +print(f"Embeddings shape: {embeddings.shape}") + +# Step 2: Dimensionality reduction for visualization +print("\nReducing dimensionality...") + +# PCA for initial reduction +pca = PCA(n_components=50) +embeddings_pca = pca.fit_transform(embeddings) +print(f"PCA explained variance: {pca.explained_variance_ratio_[:10].sum():.2%}") + +# t-SNE for visualization +tsne = TSNE(n_components=2, random_state=42) +embeddings_2d = tsne.fit_transform(embeddings_pca) + +# Step 3: Clustering +print("\nClustering...") + +# DBSCAN for density-based clustering +clustering = DBSCAN(eps=0.5, min_samples=5) +cluster_labels = clustering.fit_predict(embeddings) + +n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) +n_noise = list(cluster_labels).count(-1) + +print(f"Number of clusters: {n_clusters}") +print(f"Number of noise points: {n_noise}") + +# Step 4: Visualize +print("\nGenerating visualization...") + +plt.figure(figsize=(12, 8)) +scatter = plt.scatter( + embeddings_2d[:, 0], + embeddings_2d[:, 1], + c=cluster_labels, + cmap='viridis', + alpha=0.6 +) +plt.colorbar(scatter) +plt.title("Protein Sequence Clustering (ESM C Embeddings)") +plt.xlabel("t-SNE 1") +plt.ylabel("t-SNE 2") +plt.savefig("protein_clusters.png", dpi=300, bbox_inches='tight') +print("Visualization saved to: protein_clusters.png") + +# Step 5: Analyze clusters +print("\nCluster analysis:") + +for cluster_id in range(n_clusters): + cluster_indices = np.where(cluster_labels == cluster_id)[0] + cluster_seqs = [sequences[i] for i in cluster_indices] + + print(f"\nCluster {cluster_id}:") + print(f" Size: {len(cluster_seqs)}") + print(f" Avg length: {np.mean([len(s) for s in cluster_seqs]):.1f}") + print(f" Example: {cluster_seqs[0][:40]}...") + +# Save cluster assignments +with open("cluster_assignments.txt", "w") as f: + for i, (seq, label) in enumerate(zip(sequences, cluster_labels)): + f.write(f"Sequence_{i}\tCluster_{label}\t{seq}\n") + +print("\nCluster assignments saved to: cluster_assignments.txt") +``` + +## Additional Workflow Tips + +### Memory Management for Large Datasets + +```python +def process_large_dataset(sequences, batch_size=32): + """Process large dataset with memory management.""" + import gc + import torch + + results = [] + + for i in range(0, len(sequences), batch_size): + batch = sequences[i:i + batch_size] + + # Process batch + batch_results = [process_sequence(seq) for seq in batch] + results.extend(batch_results) + + # Clear memory + torch.cuda.empty_cache() + gc.collect() + + if (i + batch_size) % 100 == 0: + print(f"Processed {min(i + batch_size, len(sequences))}/{len(sequences)}") + + return results +``` + +### Parallel Processing + +```python +from concurrent.futures import ThreadPoolExecutor +import asyncio + +def parallel_workflow(sequences, n_workers=4): + """Process sequences in parallel.""" + + with ThreadPoolExecutor(max_workers=n_workers) as executor: + results = list(executor.map(process_sequence, sequences)) + + return results +``` + +These workflows provide comprehensive 
examples for common ESM use cases. Adapt them to your specific needs and always validate results with appropriate biological experiments. diff --git a/scientific-packages/scvi-tools/references/theoretical-foundations.md b/scientific-packages/scvi-tools/references/theoretical-foundations.md new file mode 100644 index 0000000..ededc05 --- /dev/null +++ b/scientific-packages/scvi-tools/references/theoretical-foundations.md @@ -0,0 +1,438 @@ +# Theoretical Foundations of scvi-tools + +This document explains the mathematical and statistical principles underlying scvi-tools. + +## Core Concepts + +### Variational Inference + +**What is it?** +Variational inference is a technique for approximating complex probability distributions. In single-cell analysis, we want to understand the posterior distribution p(z|x) - the probability of latent variables z given observed data x. + +**Why use it?** +- Exact inference is computationally intractable for complex models +- Scales to large datasets (millions of cells) +- Provides uncertainty quantification +- Enables Bayesian reasoning about cell states + +**How does it work?** +1. Define a simpler approximate distribution q(z|x) with learnable parameters +2. Minimize the KL divergence between q(z|x) and true posterior p(z|x) +3. Equivalent to maximizing the Evidence Lower Bound (ELBO) + +**ELBO Objective**: +``` +ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)) + ↑ ↑ + Reconstruction Regularization +``` + +- **Reconstruction term**: Model should generate data similar to observed +- **Regularization term**: Latent representation should match prior + +### Variational Autoencoders (VAEs) + +**Architecture**: +``` +x (observed data) + ↓ +[Encoder Neural Network] + ↓ +z (latent representation) + ↓ +[Decoder Neural Network] + ↓ +x̂ (reconstructed data) +``` + +**Encoder**: Maps cells (x) to latent space (z) +- Learns q(z|x), the approximate posterior +- Parameterized by neural network with learnable weights +- Outputs mean and variance of latent distribution + +**Decoder**: Maps latent space (z) back to gene space +- Learns p(x|z), the likelihood +- Generates gene expression from latent representation +- Models count distributions (Negative Binomial, Zero-Inflated NB) + +**Reparameterization Trick**: +- Allows backpropagation through stochastic sampling +- Sample z = μ + σ ⊙ ε, where ε ~ N(0,1) +- Enables end-to-end training with gradient descent + +### Amortized Inference + +**Concept**: Share encoder parameters across all cells. + +**Traditional inference**: Learn separate latent variables for each cell +- n_cells × n_latent parameters +- Doesn't scale to large datasets + +**Amortized inference**: Learn single encoder for all cells +- Fixed number of parameters regardless of cell count +- Enables fast inference on new cells +- Transfers learned patterns across dataset + +**Benefits**: +- Scalable to millions of cells +- Fast inference on query data +- Leverages shared structure across cells +- Enables few-shot learning + +## Statistical Modeling + +### Count Data Distributions + +Single-cell data are counts (integer-valued), requiring appropriate distributions. 
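The key property separating these distributions is how much variance they allow at a given mean. The short simulation below uses synthetic counts (not real scRNA-seq data) to illustrate the overdispersion that motivates the Negative Binomial and its zero-inflated variant described next.

```python
# Synthetic illustration of overdispersion (not real scRNA-seq data).
# A gamma-Poisson mixture with dispersion theta gives Var = mu + mu^2/theta,
# which exceeds the Poisson variance (Var = mu) for any finite theta.
import numpy as np

rng = np.random.default_rng(0)
mu, theta, n = 5.0, 2.0, 100_000

poisson_counts = rng.poisson(mu, size=n)
nb_counts = rng.poisson(rng.gamma(shape=theta, scale=mu / theta, size=n))

print(f"Poisson: mean={poisson_counts.mean():.2f}, var={poisson_counts.var():.2f}")
print(f"NB:      mean={nb_counts.mean():.2f}, var={nb_counts.var():.2f} "
      f"(theory: {mu + mu**2 / theta:.2f})")
```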
+ +#### Negative Binomial (NB) +``` +x ~ NB(μ, θ) +``` +- **μ (mean)**: Expected expression level +- **θ (dispersion)**: Controls variance +- **Variance**: Var(x) = μ + μ²/θ + +**When to use**: Gene expression without zero-inflation +- More flexible than Poisson (allows overdispersion) +- Models technical and biological variation + +#### Zero-Inflated Negative Binomial (ZINB) +``` +x ~ π·δ₀ + (1-π)·NB(μ, θ) +``` +- **π (dropout rate)**: Probability of technical zero +- **δ₀**: Point mass at zero +- **NB(μ, θ)**: Expression when not dropped out + +**When to use**: Sparse scRNA-seq data +- Models technical dropout separately from biological zeros +- Better fit for highly sparse data (e.g., 10x data) + +#### Poisson +``` +x ~ Poisson(μ) +``` +- Simplest count distribution +- Mean equals variance: Var(x) = μ + +**When to use**: Less common; ATAC-seq fragment counts +- More restrictive than NB +- Faster computation + +### Batch Correction Framework + +**Problem**: Technical variation confounds biological signal +- Different sequencing runs, protocols, labs +- Must remove technical effects while preserving biology + +**scvi-tools approach**: +1. Encode batch as categorical variable s +2. Include s in generative model +3. Latent space z is batch-invariant +4. Decoder conditions on s for batch-specific effects + +**Mathematical formulation**: +``` +Encoder: q(z|x, s) - batch-aware encoding +Latent: z - batch-corrected representation +Decoder: p(x|z, s) - batch-specific decoding +``` + +**Key insight**: Batch info flows through decoder, not latent space +- z captures biological variation +- s explains technical variation +- Separable biology and batch effects + +### Deep Generative Modeling + +**Generative model**: Learns p(x), the data distribution + +**Process**: +1. Sample latent variable: z ~ p(z) = N(0, I) +2. Generate expression: x ~ p(x|z) +3. Joint distribution: p(x, z) = p(x|z)p(z) + +**Benefits**: +- Generate synthetic cells +- Impute missing values +- Quantify uncertainty +- Perform counterfactual predictions + +**Inference network**: Inverts generative process +- Given x, infer z +- q(z|x) approximates true posterior p(z|x) + +## Model Architecture Details + +### scVI Architecture + +**Input**: Gene expression counts x ∈ ℕ^G (G genes) + +**Encoder**: +``` +h = ReLU(W₁·x + b₁) +μ_z = W₂·h + b₂ +log σ²_z = W₃·h + b₃ +z ~ N(μ_z, σ²_z) +``` + +**Latent space**: z ∈ ℝ^d (typically d=10-30) + +**Decoder**: +``` +h = ReLU(W₄·z + b₄) +μ = softmax(W₅·h + b₅) · library_size +θ = exp(W₆·h + b₆) +π = sigmoid(W₇·h + b₇) # for ZINB +x ~ ZINB(μ, θ, π) +``` + +**Loss function (ELBO)**: +``` +L = E_q[log p(x|z)] - KL(q(z|x) || N(0,I)) +``` + +### Handling Covariates + +**Categorical covariates** (batch, donor, etc.): +- One-hot encoded: s ∈ {0,1}^K +- Concatenate with latent: [z, s] +- Or use conditional layers + +**Continuous covariates** (library size, percent_mito): +- Standardize to zero mean, unit variance +- Include in encoder and/or decoder + +**Covariate injection strategies**: +- **Concatenation**: [z, s] fed to decoder +- **Deep injection**: s added at multiple layers +- **Conditional batch norm**: Batch-specific normalization + +## Advanced Theoretical Concepts + +### Transfer Learning (scArches) + +**Concept**: Use pretrained model as initialization for new data + +**Process**: +1. Train reference model on large dataset +2. Freeze encoder parameters +3. Fine-tune decoder on query data +4. 
Or fine-tune all with lower learning rate + +**Why it works**: +- Encoder learns general cellular representations +- Decoder adapts to query-specific characteristics +- Prevents catastrophic forgetting + +**Applications**: +- Query-to-reference mapping +- Few-shot learning for rare cell types +- Rapid analysis of new datasets + +### Multi-Resolution Modeling (MrVI) + +**Idea**: Separate shared and sample-specific variation + +**Latent space decomposition**: +``` +z = z_shared + z_sample +``` +- **z_shared**: Common across samples +- **z_sample**: Sample-specific effects + +**Hierarchical structure**: +``` +Sample level: ρ_s ~ N(0, I) +Cell level: z_i ~ N(ρ_{s(i)}, σ²) +``` + +**Benefits**: +- Disentangle biological sources of variation +- Compare samples at different resolutions +- Identify sample-specific cell states + +### Counterfactual Prediction + +**Goal**: Predict outcome under different conditions + +**Example**: "What would this cell look like if from different batch?" + +**Method**: +1. Encode cell to latent: z = Encoder(x, s_original) +2. Decode with new condition: x_new = Decoder(z, s_new) +3. x_new is counterfactual prediction + +**Applications**: +- Batch effect assessment +- Predicting treatment response +- In silico perturbation studies + +### Posterior Predictive Distribution + +**Definition**: Distribution of new data given observed data + +``` +p(x_new | x_observed) = ∫ p(x_new|z) q(z|x_observed) dz +``` + +**Estimation**: Sample z from q(z|x), generate x_new from p(x_new|z) + +**Uses**: +- Uncertainty quantification +- Robust predictions +- Outlier detection + +## Differential Expression Framework + +### Bayesian Approach + +**Traditional methods**: Compare point estimates +- Wilcoxon, t-test, etc. +- Ignore uncertainty +- Require pseudocounts + +**scvi-tools approach**: Compare distributions +- Sample from posterior: μ_A ~ p(μ|x_A), μ_B ~ p(μ|x_B) +- Compute log fold-change: LFC = log(μ_B) - log(μ_A) +- Posterior distribution of LFC quantifies uncertainty + +### Bayes Factor + +**Definition**: Ratio of posterior odds to prior odds + +``` +BF = P(H₁|data) / P(H₀|data) + ───────────────────────── + P(H₁) / P(H₀) +``` + +**Interpretation**: +- BF > 3: Moderate evidence for H₁ +- BF > 10: Strong evidence +- BF > 100: Decisive evidence + +**In scvi-tools**: Used to rank genes by evidence for DE + +### False Discovery Proportion (FDP) + +**Goal**: Control expected false discovery rate + +**Procedure**: +1. For each gene, compute posterior probability of DE +2. Rank genes by evidence (Bayes factor) +3. Select top k genes such that E[FDP] ≤ α + +**Advantage over p-values**: +- Fully Bayesian +- Natural for posterior inference +- No arbitrary thresholds + +## Implementation Details + +### Optimization + +**Optimizer**: Adam (adaptive learning rates) +- Default lr = 0.001 +- Momentum parameters: β₁=0.9, β₂=0.999 + +**Training loop**: +1. Sample mini-batch of cells +2. Compute ELBO loss +3. Backpropagate gradients +4. Update parameters with Adam +5. 
Repeat until convergence + +**Convergence criteria**: +- ELBO plateaus on validation set +- Early stopping prevents overfitting +- Typically 200-500 epochs + +### Regularization + +**KL annealing**: Gradually increase KL weight +- Prevents posterior collapse +- Starts at 0, increases to 1 over epochs + +**Dropout**: Random neuron dropping during training +- Default: 0.1 dropout rate +- Prevents overfitting +- Improves generalization + +**Weight decay**: L2 regularization on weights +- Prevents large weights +- Improves stability + +### Scalability + +**Mini-batch training**: +- Process subset of cells per iteration +- Batch size: 64-256 cells +- Enables scaling to millions of cells + +**Stochastic optimization**: +- Estimates ELBO on mini-batches +- Unbiased gradient estimates +- Converges to optimal solution + +**GPU acceleration**: +- Neural networks naturally parallelize +- Order of magnitude speedup +- Essential for large datasets + +## Connections to Other Methods + +### vs. PCA +- **PCA**: Linear, deterministic +- **scVI**: Nonlinear, probabilistic +- **Advantage**: scVI captures complex structure, handles counts + +### vs. t-SNE/UMAP +- **t-SNE/UMAP**: Visualization-focused +- **scVI**: Full generative model +- **Advantage**: scVI enables downstream tasks (DE, imputation) + +### vs. Seurat Integration +- **Seurat**: Anchor-based alignment +- **scVI**: Probabilistic modeling +- **Advantage**: scVI provides uncertainty, works for multiple batches + +### vs. Harmony +- **Harmony**: PCA + batch correction +- **scVI**: VAE-based +- **Advantage**: scVI handles counts natively, more flexible + +## Mathematical Notation + +**Common symbols**: +- x: Observed gene expression (counts) +- z: Latent representation +- θ: Model parameters +- q(z|x): Approximate posterior (encoder) +- p(x|z): Likelihood (decoder) +- p(z): Prior on latent variables +- μ, σ²: Mean and variance +- π: Dropout probability (ZINB) +- θ (in NB): Dispersion parameter +- s: Batch/covariate indicator + +## Further Reading + +**Key Papers**: +1. Lopez et al. (2018): "Deep generative modeling for single-cell transcriptomics" +2. Xu et al. (2021): "Probabilistic harmonization and annotation of single-cell transcriptomics" +3. Boyeau et al. (2019): "Deep generative models for detecting differential expression in single cells" + +**Concepts to explore**: +- Variational inference in machine learning +- Bayesian deep learning +- Information theory (KL divergence, mutual information) +- Generative models (GANs, normalizing flows, diffusion models) +- Probabilistic programming (Pyro, PyTorch) + +**Mathematical background**: +- Probability theory and statistics +- Linear algebra and calculus +- Optimization theory +- Information theory