Initial commit

This commit is contained in:
Timothy Kassis
2025-10-19 14:01:29 -07:00
parent d85386c32b
commit 152d0d54de
15 changed files with 4569 additions and 0 deletions

View File

@@ -0,0 +1,415 @@
---
name: arboreto
description: Toolkit for gene regulatory network (GRN) inference from expression data using machine learning. Use this skill when working with gene expression matrices to infer regulatory relationships, performing single-cell RNA-seq analysis, or integrating with pySCENIC workflows. Supports both GRNBoost2 (fast gradient boosting) and GENIE3 (Random Forest) algorithms with distributed computing via Dask.
---
# Arboreto - Gene Regulatory Network Inference
## Overview
Arboreto is a Python library for inferring gene regulatory networks (GRNs) from gene expression data using machine learning algorithms. It enables scalable GRN inference from single machines to multi-node clusters using Dask for distributed computing. The skill provides comprehensive support for both GRNBoost2 (fast gradient boosting) and GENIE3 (Random Forest) algorithms.
## When to Use This Skill
Apply this skill when:
- Inferring regulatory relationships between genes from expression data
- Analyzing single-cell or bulk RNA-seq data to identify transcription factor targets
- Building the GRN inference component of a pySCENIC pipeline
- Comparing GRNBoost2 and GENIE3 algorithm performance
- Setting up distributed computing for large-scale genomic analyses
- Troubleshooting arboreto installation or runtime issues
## Core Capabilities
### 1. Basic GRN Inference
For standard gene regulatory network inference tasks:
**Key considerations:**
- Expression data format: Rows = observations (cells/samples), Columns = genes
- If data has genes as rows, transpose it first: `expression_df.T`
- Always include `seed` parameter for reproducible results
- Transcription factor list is optional but recommended for focused analysis
**Typical workflow:**
```python
import pandas as pd
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
# Load expression data (ensure correct orientation)
expression_data = pd.read_csv('expression_data.tsv', sep='\t', index_col=0)
# Optional: Load TF names
tf_names = load_tf_names('transcription_factors.txt')
# Run inference
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
seed=42 # For reproducibility
)
# Save results
network.to_csv('network_output.tsv', sep='\t', index=False)
```
**Output format:**
- DataFrame with columns: `['TF', 'target', 'importance']`
- Higher importance scores indicate stronger predicted regulatory relationships
- Typically sorted by importance (descending)
**Multiprocessing requirement:**
All arboreto code must include `if __name__ == '__main__':` protection due to Dask's multiprocessing requirements:
```python
if __name__ == '__main__':
# Arboreto code goes here
network = grnboost2(expression_data=expr_data, seed=42)
```
### 2. Algorithm Selection
**GRNBoost2 (Recommended for most cases):**
- ~10-100x faster than GENIE3
- Uses stochastic gradient boosting with early-stopping
- Best for: Large datasets (>10k observations), time-sensitive analyses
- Function: `arboreto.algo.grnboost2()`
**GENIE3:**
- Uses Random Forest regression
- More established, classical approach
- Best for: Small datasets, methodological comparisons, reproducing published results
- Function: `arboreto.algo.genie3()`
**When to compare both algorithms:**
Use the provided `compare_algorithms.py` script when:
- Validating results for critical analyses
- Benchmarking performance on new datasets
- Publishing research requiring methodological comparisons
### 3. Distributed Computing
**Local execution (default):**
Arboreto automatically creates a local Dask client. No configuration needed:
```python
network = grnboost2(expression_data=expr_data)
```
**Custom local cluster (recommended for better control):**
```python
from dask.distributed import Client, LocalCluster
# Configure cluster
cluster = LocalCluster(
n_workers=4,
threads_per_worker=2,
memory_limit='4GB',
diagnostics_port=8787 # Dashboard at http://localhost:8787
)
client = Client(cluster)
# Run inference
network = grnboost2(
expression_data=expr_data,
client_or_address=client
)
# Clean up
client.close()
cluster.close()
```
**Distributed cluster (multi-node):**
On scheduler node:
```bash
dask-scheduler --no-bokeh
```
On worker nodes:
```bash
dask-worker scheduler-address:8786 --local-dir /tmp
```
In Python:
```python
from dask.distributed import Client
client = Client('scheduler-address:8786')
network = grnboost2(expression_data=expr_data, client_or_address=client)
```
### 4. Data Preparation
**Common data format issues:**
1. **Transposed data** (genes as rows instead of columns):
```python
# If genes are rows, transpose
expression_data = pd.read_csv('data.tsv', sep='\t', index_col=0).T
```
2. **Missing gene names:**
```python
# Provide gene names if using numpy array
network = grnboost2(
expression_data=expr_array,
gene_names=['Gene1', 'Gene2', 'Gene3', ...],
seed=42
)
```
3. **Transcription factor specification:**
```python
# Option 1: Python list
tf_names = ['Sox2', 'Oct4', 'Nanog', 'Klf4']
# Option 2: Load from file (one TF per line)
from arboreto.utils import load_tf_names
tf_names = load_tf_names('tf_names.txt')
```
### 5. Reproducibility
Always specify a seed for consistent results:
```python
network = grnboost2(expression_data=expr_data, seed=42)
```
Without a seed, results will vary between runs due to algorithm randomness.
### 6. Result Interpretation
**Understanding the output:**
- `TF`: Transcription factor (regulator) gene
- `target`: Target gene being regulated
- `importance`: Strength of predicted regulatory relationship
**Typical post-processing:**
```python
# Filter by importance threshold
high_confidence = network[network['importance'] > 10]
# Get top N predictions
top_predictions = network.head(1000)
# Find all targets of a specific TF
sox2_targets = network[network['TF'] == 'Sox2']
# Count regulations per TF
tf_counts = network['TF'].value_counts()
```
## Installation
**Recommended (via conda):**
```bash
conda install -c bioconda arboreto
```
**Via pip:**
```bash
pip install arboreto
```
**From source:**
```bash
git clone https://github.com/tmoerman/arboreto.git
cd arboreto
pip install .
```
**Dependencies:**
- pandas
- numpy
- scikit-learn
- scipy
- dask
- distributed
## Troubleshooting
### Issue: Bokeh error when launching Dask scheduler
**Error:** `TypeError: got an unexpected keyword argument 'host'`
**Solutions:**
- Use `dask-scheduler --no-bokeh` to disable Bokeh
- Upgrade to Dask distributed >= 0.20.0
### Issue: Workers not connecting to scheduler
**Symptoms:** Worker processes start but fail to establish connections
**Solutions:**
- Remove `dask-worker-space` directory before restarting workers
- Specify adequate `local_dir` when creating cluster:
```python
cluster = LocalCluster(
worker_kwargs={'local_dir': '/tmp'}
)
```
### Issue: Memory errors with large datasets
**Solutions:**
- Increase worker memory limits: `memory_limit='8GB'`
- Distribute across more nodes
- Reduce dataset size through preprocessing (e.g., feature selection)
- Ensure expression matrix fits in available RAM
### Issue: Inconsistent results across runs
**Solution:** Always specify a `seed` parameter:
```python
network = grnboost2(expression_data=expr_data, seed=42)
```
### Issue: Import errors or missing dependencies
**Solution:** Use conda installation to handle numerical library dependencies:
```bash
conda create --name arboreto-env
conda activate arboreto-env
conda install -c bioconda arboreto
```
## Provided Scripts
This skill includes ready-to-use scripts for common workflows:
### scripts/basic_grn_inference.py
Command-line tool for standard GRN inference workflow.
**Usage:**
```bash
python scripts/basic_grn_inference.py expression_data.tsv \
-t tf_names.txt \
-o network.tsv \
-s 42 \
--transpose # if genes are rows
```
**Features:**
- Automatic data loading and validation
- Optional TF list specification
- Configurable output format
- Data transposition support
- Summary statistics
### scripts/distributed_inference.py
GRN inference with custom Dask cluster configuration.
**Usage:**
```bash
python scripts/distributed_inference.py expression_data.tsv \
-t tf_names.txt \
-w 8 \
-m 4GB \
--threads 2 \
--dashboard-port 8787
```
**Features:**
- Configurable worker count and memory limits
- Dask dashboard integration
- Thread configuration
- Resource monitoring
### scripts/compare_algorithms.py
Compare GRNBoost2 and GENIE3 side-by-side.
**Usage:**
```bash
python scripts/compare_algorithms.py expression_data.tsv \
-t tf_names.txt \
--top-n 100
```
**Features:**
- Runtime comparison
- Network statistics
- Prediction overlap analysis
- Top prediction comparison
## Reference Documentation
Detailed API documentation is available in [references/api_reference.md](references/api_reference.md), including:
- Complete parameter descriptions for all functions
- Data format specifications
- Distributed computing configuration
- Performance optimization tips
- Integration with pySCENIC
- Comprehensive examples
Load this reference when:
- Working with advanced Dask configurations
- Troubleshooting complex deployment scenarios
- Understanding algorithm internals
- Optimizing performance for specific use cases
## Integration with pySCENIC
Arboreto is the first step in the pySCENIC single-cell analysis pipeline:
1. **GRN Inference (arboreto)** ← This skill
- Input: Expression matrix
- Output: Regulatory network
2. **Regulon Prediction (pySCENIC)**
- Input: Network from arboreto
- Output: Refined regulons
3. **Cell Type Identification (pySCENIC)**
- Input: Regulons
- Output: Cell type scores
When working with pySCENIC, use arboreto to generate the initial network, then pass results to the pySCENIC pipeline.
## Best Practices
1. **Always use seed parameter** for reproducible research
2. **Validate data orientation** (rows = observations, columns = genes)
3. **Specify TF list** when known to focus inference and improve speed
4. **Monitor with Dask dashboard** for distributed computing
5. **Save intermediate results** to avoid re-running long computations
6. **Filter results** by importance threshold for downstream analysis
7. **Use GRNBoost2 by default** unless specifically requiring GENIE3
8. **Include multiprocessing guard** (`if __name__ == '__main__':`) in all scripts
## Quick Reference
**Basic inference:**
```python
from arboreto.algo import grnboost2
network = grnboost2(expression_data=expr_df, seed=42)
```
**With TF specification:**
```python
network = grnboost2(expression_data=expr_df, tf_names=tf_list, seed=42)
```
**With custom Dask client:**
```python
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=4)
client = Client(cluster)
network = grnboost2(expression_data=expr_df, client_or_address=client, seed=42)
client.close()
cluster.close()
```
**Load TF names:**
```python
from arboreto.utils import load_tf_names
tf_names = load_tf_names('transcription_factors.txt')
```
**Transpose data:**
```python
expression_df = pd.read_csv('data.tsv', sep='\t', index_col=0).T
```

View File

@@ -0,0 +1,271 @@
# Arboreto API Reference
This document provides comprehensive API documentation for the arboreto package, a Python library for gene regulatory network (GRN) inference.
## Overview
Arboreto enables inference of gene regulatory networks from expression data using machine learning algorithms. It supports distributed computing via Dask for scalability from single machines to multi-node clusters.
**Current Version:** 0.1.5
**GitHub:** https://github.com/tmoerman/arboreto
**License:** BSD 3-Clause
## Core Algorithms
### GRNBoost2
The flagship algorithm for fast gene regulatory network inference using stochastic gradient boosting.
**Function:** `arboreto.algo.grnboost2()`
**Parameters:**
- `expression_data` (pandas.DataFrame or numpy.ndarray): Expression matrix where rows are observations (cells/samples) and columns are genes. Required.
- `gene_names` (list, optional): List of gene names matching column order. If None, uses DataFrame column names.
- `tf_names` (list, optional): List of transcription factor names to consider as regulators. If None, all genes are considered potential regulators.
- `seed` (int, optional): Random seed for reproducibility. Recommended when consistent results are needed across runs.
- `client_or_address` (dask.distributed.Client or str, optional): Custom Dask client or scheduler address for distributed computing. If None, creates a default local client.
- `verbose` (bool, optional): Enable verbose output for debugging.
**Returns:**
- pandas.DataFrame with columns `['TF', 'target', 'importance']` representing inferred regulatory links. Each row represents a regulatory relationship with an importance score.
**Algorithm Details:**
- Uses stochastic gradient boosting with early-stopping regularization
- Much faster than GENIE3, especially for large datasets (tens of thousands of observations)
- Extracts important features from trained regression models to identify regulatory relationships
- Recommended as the default choice for most use cases
**Example:**
```python
from arboreto.algo import grnboost2
import pandas as pd
# Load expression data
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
tf_list = ['TF1', 'TF2', 'TF3'] # Optional: specify TFs
# Run inference
network = grnboost2(
expression_data=expression_matrix,
tf_names=tf_list,
seed=42 # For reproducibility
)
# Save results
network.to_csv('output_network.tsv', sep='\t', index=False)
```
### GENIE3
Classical gene regulatory network inference using Random Forest regression.
**Function:** `arboreto.algo.genie3()`
**Parameters:**
Same as GRNBoost2 (see above).
**Returns:**
Same format as GRNBoost2 (see above).
**Algorithm Details:**
- Uses Random Forest or ExtraTrees regression models
- Blueprint for multiple regression GRN inference strategy
- More computationally expensive than GRNBoost2
- Better suited for smaller datasets or when maximum accuracy is needed
**When to Use GENIE3 vs GRNBoost2:**
- **Use GRNBoost2:** For large datasets, faster results, or when computational resources are limited
- **Use GENIE3:** For smaller datasets, when following established protocols, or for comparison with published results
## Module Structure
### arboreto.algo
Primary module for typical users. Contains high-level inference functions.
**Main Functions:**
- `grnboost2()` - Fast GRN inference using gradient boosting
- `genie3()` - Classical GRN inference using Random Forest
### arboreto.core
Advanced module for power users. Contains low-level framework components for custom implementations.
**Use cases:**
- Custom inference pipelines
- Algorithm modifications
- Performance tuning
### arboreto.utils
Utility functions for common data processing tasks.
**Key Functions:**
- `load_tf_names(filename)` - Load transcription factor names from file
- Reads a text file with one TF name per line
- Returns a list of TF names
- Example: `tf_names = load_tf_names('transcription_factors.txt')`
## Data Format Requirements
### Input Format
**Expression Matrix:**
- **Format:** pandas DataFrame or numpy ndarray
- **Orientation:** Rows = observations (cells/samples), Columns = genes
- **Convention:** Follows scikit-learn format
- **Gene Names:** Column names (DataFrame) or separate `gene_names` parameter
- **Data Type:** Numeric (float or int)
**Common Mistake:** If data is transposed (genes as rows), use pandas to transpose:
```python
expression_df = pd.read_csv('data.tsv', sep='\t', index_col=0).T
```
**Transcription Factor List:**
- **Format:** Python list of strings or text file (one TF per line)
- **Optional:** If not provided, all genes are considered potential regulators
- **Example:** `['Sox2', 'Oct4', 'Nanog']`
### Output Format
**Network DataFrame:**
- **Columns:**
- `TF` (str): Transcription factor (regulator) gene name
- `target` (str): Target gene name
- `importance` (float): Importance score of the regulatory relationship
- **Interpretation:** Higher importance scores indicate stronger predicted regulatory relationships
- **Sorting:** Typically sorted by importance (descending) for prioritization
**Example Output:**
```
TF target importance
Sox2 Gene1 15.234
Oct4 Gene1 12.456
Sox2 Gene2 8.901
```
## Distributed Computing with Dask
### Local Execution (Default)
Arboreto automatically creates a local Dask client if none is provided:
```python
network = grnboost2(expression_data=expr_matrix, tf_names=tf_list)
```
### Custom Local Cluster
For better control over resources or multiple inferences:
```python
from dask.distributed import Client, LocalCluster
# Configure cluster
cluster = LocalCluster(
n_workers=4,
threads_per_worker=2,
memory_limit='4GB'
)
client = Client(cluster)
# Run inference
network = grnboost2(
expression_data=expr_matrix,
tf_names=tf_list,
client_or_address=client
)
# Clean up
client.close()
cluster.close()
```
### Distributed Cluster
For multi-node computation:
**On scheduler node:**
```bash
dask-scheduler --no-bokeh # Use --no-bokeh to avoid Bokeh errors
```
**On worker nodes:**
```bash
dask-worker scheduler-address:8786 --local-dir /tmp
```
**In Python script:**
```python
from dask.distributed import Client
client = Client('scheduler-address:8786')
network = grnboost2(
expression_data=expr_matrix,
tf_names=tf_list,
client_or_address=client
)
```
### Dask Dashboard
Monitor computation progress via the Dask dashboard:
```python
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(diagnostics_port=8787)
client = Client(cluster)
# Dashboard available at: http://localhost:8787
```
## Reproducibility
To ensure reproducible results across runs:
```python
network = grnboost2(
expression_data=expr_matrix,
tf_names=tf_list,
seed=42 # Fixed seed ensures identical results
)
```
**Note:** Without a seed parameter, results may vary slightly between runs due to randomness in the algorithms.
## Performance Considerations
### Memory Management
- Expression matrices should fit in memory (RAM)
- For very large datasets, consider:
- Using a machine with more RAM
- Distributing across multiple nodes
- Preprocessing to reduce dimensionality
### Worker Configuration
- **Local execution:** Number of workers = number of CPU cores (default)
- **Custom cluster:** Balance workers and threads based on available resources
- **Distributed execution:** Ensure adequate `local_dir` space on worker nodes
### Algorithm Choice
- **GRNBoost2:** ~10-100x faster than GENIE3 for large datasets
- **GENIE3:** More established but slower, better for small datasets (<10k observations)
## Integration with pySCENIC
Arboreto is a core component of the pySCENIC pipeline for single-cell RNA sequencing analysis:
1. **GRN Inference (Arboreto):** Infer regulatory networks using GRNBoost2
2. **Regulon Prediction:** Prune network and identify regulons
3. **Cell Type Identification:** Score regulons across cells
For pySCENIC workflows, arboreto is typically used in the first step to generate the initial regulatory network.
## Common Issues and Solutions
See the main SKILL.md for troubleshooting guidance.

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Basic GRN inference script using arboreto GRNBoost2.
This script demonstrates the standard workflow for gene regulatory network inference:
1. Load expression data
2. Optionally load transcription factor names
3. Run GRNBoost2 inference
4. Save results
Usage:
python basic_grn_inference.py <expression_file> [options]
Example:
python basic_grn_inference.py expression_data.tsv -t tf_names.txt -o network.tsv
"""
import argparse
import pandas as pd
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def main():
parser = argparse.ArgumentParser(
description='Infer gene regulatory network using GRNBoost2'
)
parser.add_argument(
'expression_file',
help='Path to expression data file (TSV/CSV format)'
)
parser.add_argument(
'-t', '--tf-file',
help='Path to file containing transcription factor names (one per line)',
default=None
)
parser.add_argument(
'-o', '--output',
help='Output file path for network results',
default='network_output.tsv'
)
parser.add_argument(
'-s', '--seed',
type=int,
help='Random seed for reproducibility',
default=42
)
parser.add_argument(
'--sep',
help='Separator for input file (default: tab)',
default='\t'
)
parser.add_argument(
'--transpose',
action='store_true',
help='Transpose the expression matrix (use if genes are rows)'
)
args = parser.parse_args()
# Load expression data
print(f"Loading expression data from {args.expression_file}...")
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
# Transpose if needed
if args.transpose:
print("Transposing expression matrix...")
expression_data = expression_data.T
print(f"Expression data shape: {expression_data.shape}")
print(f" Observations (rows): {expression_data.shape[0]}")
print(f" Genes (columns): {expression_data.shape[1]}")
# Load TF names if provided
tf_names = None
if args.tf_file:
print(f"Loading transcription factor names from {args.tf_file}...")
tf_names = load_tf_names(args.tf_file)
print(f" Found {len(tf_names)} transcription factors")
else:
print("No TF file provided. Using all genes as potential regulators.")
# Run GRNBoost2
print("\nRunning GRNBoost2 inference...")
print(" (This may take a while depending on dataset size)")
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
seed=args.seed
)
print(f"\nInference complete!")
print(f" Total regulatory links inferred: {len(network)}")
print(f" Unique TFs: {network['TF'].nunique()}")
print(f" Unique targets: {network['target'].nunique()}")
# Save results
print(f"\nSaving results to {args.output}...")
network.to_csv(args.output, sep='\t', index=False)
# Display top 10 predictions
print("\nTop 10 predicted regulatory relationships:")
print(network.head(10).to_string(index=False))
print("\nDone!")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
"""
Compare GRNBoost2 and GENIE3 algorithms on the same dataset.
This script runs both algorithms on the same expression data and compares:
- Runtime
- Number of predicted links
- Top predicted relationships
- Overlap between predictions
Usage:
python compare_algorithms.py <expression_file> [options]
Example:
python compare_algorithms.py expression_data.tsv -t tf_names.txt
"""
import argparse
import time
import pandas as pd
from arboreto.algo import grnboost2, genie3
from arboreto.utils import load_tf_names
def compare_networks(network1, network2, name1, name2, top_n=100):
"""Compare two inferred networks."""
print(f"\n{'='*60}")
print("Network Comparison")
print(f"{'='*60}")
# Basic statistics
print(f"\n{name1} Statistics:")
print(f" Total links: {len(network1)}")
print(f" Unique TFs: {network1['TF'].nunique()}")
print(f" Unique targets: {network1['target'].nunique()}")
print(f" Importance range: [{network1['importance'].min():.3f}, {network1['importance'].max():.3f}]")
print(f"\n{name2} Statistics:")
print(f" Total links: {len(network2)}")
print(f" Unique TFs: {network2['TF'].nunique()}")
print(f" Unique targets: {network2['target'].nunique()}")
print(f" Importance range: [{network2['importance'].min():.3f}, {network2['importance'].max():.3f}]")
# Compare top predictions
print(f"\nTop {top_n} Predictions Overlap:")
# Create edge sets for top N predictions
top_edges1 = set(
zip(network1.head(top_n)['TF'], network1.head(top_n)['target'])
)
top_edges2 = set(
zip(network2.head(top_n)['TF'], network2.head(top_n)['target'])
)
# Calculate overlap
overlap = top_edges1 & top_edges2
only_net1 = top_edges1 - top_edges2
only_net2 = top_edges2 - top_edges1
overlap_pct = (len(overlap) / top_n) * 100
print(f" Shared edges: {len(overlap)} ({overlap_pct:.1f}%)")
print(f" Only in {name1}: {len(only_net1)}")
print(f" Only in {name2}: {len(only_net2)}")
# Show some example overlapping edges
if overlap:
print(f"\nExample overlapping predictions:")
for i, (tf, target) in enumerate(list(overlap)[:5], 1):
print(f" {i}. {tf} -> {target}")
def main():
parser = argparse.ArgumentParser(
description='Compare GRNBoost2 and GENIE3 algorithms'
)
parser.add_argument(
'expression_file',
help='Path to expression data file (TSV/CSV format)'
)
parser.add_argument(
'-t', '--tf-file',
help='Path to file containing transcription factor names (one per line)',
default=None
)
parser.add_argument(
'--grnboost2-output',
help='Output file path for GRNBoost2 results',
default='grnboost2_network.tsv'
)
parser.add_argument(
'--genie3-output',
help='Output file path for GENIE3 results',
default='genie3_network.tsv'
)
parser.add_argument(
'-s', '--seed',
type=int,
help='Random seed for reproducibility',
default=42
)
parser.add_argument(
'--sep',
help='Separator for input file (default: tab)',
default='\t'
)
parser.add_argument(
'--transpose',
action='store_true',
help='Transpose the expression matrix (use if genes are rows)'
)
parser.add_argument(
'--top-n',
type=int,
help='Number of top predictions to compare (default: 100)',
default=100
)
args = parser.parse_args()
# Load expression data
print(f"Loading expression data from {args.expression_file}...")
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
# Transpose if needed
if args.transpose:
print("Transposing expression matrix...")
expression_data = expression_data.T
print(f"Expression data shape: {expression_data.shape}")
print(f" Observations (rows): {expression_data.shape[0]}")
print(f" Genes (columns): {expression_data.shape[1]}")
# Load TF names if provided
tf_names = None
if args.tf_file:
print(f"Loading transcription factor names from {args.tf_file}...")
tf_names = load_tf_names(args.tf_file)
print(f" Found {len(tf_names)} transcription factors")
else:
print("No TF file provided. Using all genes as potential regulators.")
# Run GRNBoost2
print("\n" + "="*60)
print("Running GRNBoost2...")
print("="*60)
start_time = time.time()
grnboost2_network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
seed=args.seed
)
grnboost2_time = time.time() - start_time
print(f"GRNBoost2 completed in {grnboost2_time:.2f} seconds")
# Save GRNBoost2 results
grnboost2_network.to_csv(args.grnboost2_output, sep='\t', index=False)
print(f"Results saved to {args.grnboost2_output}")
# Run GENIE3
print("\n" + "="*60)
print("Running GENIE3...")
print("="*60)
start_time = time.time()
genie3_network = genie3(
expression_data=expression_data,
tf_names=tf_names,
seed=args.seed
)
genie3_time = time.time() - start_time
print(f"GENIE3 completed in {genie3_time:.2f} seconds")
# Save GENIE3 results
genie3_network.to_csv(args.genie3_output, sep='\t', index=False)
print(f"Results saved to {args.genie3_output}")
# Compare runtimes
print("\n" + "="*60)
print("Runtime Comparison")
print("="*60)
print(f"GRNBoost2: {grnboost2_time:.2f} seconds")
print(f"GENIE3: {genie3_time:.2f} seconds")
speedup = genie3_time / grnboost2_time
print(f"Speedup: {speedup:.2f}x (GRNBoost2 is {speedup:.2f}x faster)")
# Compare networks
compare_networks(
grnboost2_network,
genie3_network,
"GRNBoost2",
"GENIE3",
top_n=args.top_n
)
print("\n" + "="*60)
print("Comparison complete!")
print("="*60)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
Distributed GRN inference script using arboreto with custom Dask configuration.
This script demonstrates how to use arboreto with a custom Dask LocalCluster
for better control over computational resources.
Usage:
python distributed_inference.py <expression_file> [options]
Example:
python distributed_inference.py expression_data.tsv -t tf_names.txt -w 8 -m 4GB
"""
import argparse
import pandas as pd
from dask.distributed import Client, LocalCluster
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def main():
parser = argparse.ArgumentParser(
description='Distributed GRN inference using GRNBoost2 with custom Dask cluster'
)
parser.add_argument(
'expression_file',
help='Path to expression data file (TSV/CSV format)'
)
parser.add_argument(
'-t', '--tf-file',
help='Path to file containing transcription factor names (one per line)',
default=None
)
parser.add_argument(
'-o', '--output',
help='Output file path for network results',
default='network_output.tsv'
)
parser.add_argument(
'-s', '--seed',
type=int,
help='Random seed for reproducibility',
default=42
)
parser.add_argument(
'-w', '--workers',
type=int,
help='Number of Dask workers',
default=4
)
parser.add_argument(
'-m', '--memory-limit',
help='Memory limit per worker (e.g., "4GB", "2000MB")',
default='4GB'
)
parser.add_argument(
'--threads',
type=int,
help='Threads per worker',
default=2
)
parser.add_argument(
'--dashboard-port',
type=int,
help='Port for Dask dashboard (default: 8787)',
default=8787
)
parser.add_argument(
'--sep',
help='Separator for input file (default: tab)',
default='\t'
)
parser.add_argument(
'--transpose',
action='store_true',
help='Transpose the expression matrix (use if genes are rows)'
)
args = parser.parse_args()
# Load expression data
print(f"Loading expression data from {args.expression_file}...")
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
# Transpose if needed
if args.transpose:
print("Transposing expression matrix...")
expression_data = expression_data.T
print(f"Expression data shape: {expression_data.shape}")
print(f" Observations (rows): {expression_data.shape[0]}")
print(f" Genes (columns): {expression_data.shape[1]}")
# Load TF names if provided
tf_names = None
if args.tf_file:
print(f"Loading transcription factor names from {args.tf_file}...")
tf_names = load_tf_names(args.tf_file)
print(f" Found {len(tf_names)} transcription factors")
else:
print("No TF file provided. Using all genes as potential regulators.")
# Set up Dask cluster
print(f"\nSetting up Dask LocalCluster...")
print(f" Workers: {args.workers}")
print(f" Threads per worker: {args.threads}")
print(f" Memory limit per worker: {args.memory_limit}")
print(f" Dashboard: http://localhost:{args.dashboard_port}")
cluster = LocalCluster(
n_workers=args.workers,
threads_per_worker=args.threads,
memory_limit=args.memory_limit,
diagnostics_port=args.dashboard_port
)
client = Client(cluster)
print(f"\nDask cluster ready!")
print(f" Dashboard available at: {client.dashboard_link}")
# Run GRNBoost2
print("\nRunning GRNBoost2 inference with distributed computation...")
print(" (Monitor progress via the Dask dashboard)")
try:
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
seed=args.seed,
client_or_address=client
)
print(f"\nInference complete!")
print(f" Total regulatory links inferred: {len(network)}")
print(f" Unique TFs: {network['TF'].nunique()}")
print(f" Unique targets: {network['target'].nunique()}")
# Save results
print(f"\nSaving results to {args.output}...")
network.to_csv(args.output, sep='\t', index=False)
# Display top 10 predictions
print("\nTop 10 predicted regulatory relationships:")
print(network.head(10).to_string(index=False))
print("\nDone!")
finally:
# Clean up Dask resources
print("\nClosing Dask cluster...")
client.close()
cluster.close()
if __name__ == '__main__':
main()