Improve the arboreto skill

Timothy Kassis
2025-11-03 16:21:26 -08:00
parent 6ddea4786e
commit 537edff2a1
8 changed files with 758 additions and 1038 deletions

View File: SKILL.md (modified)

@@ -1,415 +1,250 @@
---
name: arboreto
description: Infer gene regulatory networks (GRNs) from gene expression data using scalable algorithms (GRNBoost2, GENIE3). Use when analyzing transcriptomics data (bulk RNA-seq, single-cell RNA-seq) to identify transcription factor-target gene relationships and regulatory interactions. Supports distributed computation for large-scale datasets.
---

# Arboreto

## Overview

Arboreto is a computational library for inferring gene regulatory networks (GRNs) from gene expression data using parallelized algorithms that scale from single machines to multi-node clusters.

**Core capability**: Identify which transcription factors (TFs) regulate which target genes based on expression patterns across observations (cells, samples, conditions).

## Quick Start

Install arboreto:
```bash
pip install arboreto
```

Basic GRN inference:
```python
import pandas as pd
from arboreto.algo import grnboost2

if __name__ == '__main__':
    # Load expression data (genes as columns)
    expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')

    # Infer regulatory network
    network = grnboost2(expression_data=expression_matrix)

    # Save results (TF, target, importance)
    network.to_csv('network.tsv', sep='\t', index=False, header=False)
```

**Critical**: Always use `if __name__ == '__main__':` guard because Dask spawns new processes.
## Core Capabilities

### 1. Basic GRN Inference

For standard GRN inference workflows including:
- Input data preparation (Pandas DataFrame or NumPy array)
- Running inference with GRNBoost2 or GENIE3
- Filtering by transcription factors
- Output format and interpretation

**See**: `references/basic_inference.md`

**Use the ready-to-run script**: `scripts/basic_grn_inference.py` for standard inference tasks:
```bash
python scripts/basic_grn_inference.py expression_data.tsv output_network.tsv --tf-file tfs.txt --seed 777
```
### 2. Algorithm Selection

Arboreto provides two algorithms:

**GRNBoost2 (Recommended)**:
- Fast gradient boosting-based inference
- Optimized for large datasets (10k+ observations)
- Default choice for most analyses

**GENIE3**:
- Random Forest-based inference
- Original multiple regression approach
- Use for comparison or validation

**For detailed algorithm comparison, parameters, and selection guidance**: `references/algorithms.md`
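When GENIE3 is used to cross-check GRNBoost2 output, the overlap of top-ranked edges is a quick sanity metric. A minimal sketch of that comparison — the input file and the top-1000 cutoff are illustrative assumptions, not part of arboreto:
```python
import pandas as pd
from arboreto.algo import grnboost2, genie3

if __name__ == '__main__':
    expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')

    # Run both algorithms on the same data with the same seed
    net_gb = grnboost2(expression_data=expression_matrix, seed=42)
    net_g3 = genie3(expression_data=expression_matrix, seed=42)

    # Compare the 1000 strongest edges from each network
    top_gb = net_gb.nlargest(1000, 'importance')
    top_g3 = net_g3.nlargest(1000, 'importance')
    edges_gb = set(zip(top_gb['TF'], top_gb['target']))
    edges_g3 = set(zip(top_g3['TF'], top_g3['target']))
    print(f"Top-edge overlap: {len(edges_gb & edges_g3) / 1000:.1%}")
```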
### 3. Distributed Computing

Scale inference from local multi-core to cluster environments:

**Local (default)** - Uses all available cores automatically:
```python
network = grnboost2(expression_data=matrix)
```

**Custom local client** - Control resources:
```python
from distributed import LocalCluster, Client

local_cluster = LocalCluster(n_workers=10, memory_limit='8GB')
client = Client(local_cluster)

network = grnboost2(expression_data=matrix, client_or_address=client)

client.close()
local_cluster.close()
```

**Cluster computing** - Connect to remote Dask scheduler:
```python
from distributed import Client

client = Client('tcp://scheduler:8786')
network = grnboost2(expression_data=matrix, client_or_address=client)
```

**For cluster setup, performance optimization, and large-scale workflows**: `references/distributed_computing.md`
## Installation

**Recommended (Conda)**:
```bash
conda install -c bioconda arboreto
```

**Alternative (pip)**:
```bash
pip install arboreto
```

**For isolated environment**:
```bash
conda create --name arboreto-env
conda activate arboreto-env
conda install -c bioconda arboreto
```

**Dependencies**: scipy, scikit-learn, numpy, pandas, dask, distributed

## Common Use Cases

### Single-Cell RNA-seq Analysis
```python
import pandas as pd
from arboreto.algo import grnboost2

if __name__ == '__main__':
    # Load single-cell expression matrix (cells x genes)
    sc_data = pd.read_csv('scrna_counts.tsv', sep='\t')

    # Infer cell-type-specific regulatory network
    network = grnboost2(expression_data=sc_data, seed=42)

    # Filter high-confidence links
    high_confidence = network[network['importance'] > 0.5]
    high_confidence.to_csv('grn_high_confidence.tsv', sep='\t', index=False)
```

### Bulk RNA-seq with TF Filtering
```python
import pandas as pd
from arboreto.utils import load_tf_names
from arboreto.algo import grnboost2

if __name__ == '__main__':
    # Load data
    expression_data = pd.read_csv('rnaseq_tpm.tsv', sep='\t')
    tf_names = load_tf_names('human_tfs.txt')

    # Infer with TF restriction
    network = grnboost2(
        expression_data=expression_data,
        tf_names=tf_names,
        seed=123
    )

    network.to_csv('tf_target_network.tsv', sep='\t', index=False)
```

### Comparative Analysis (Multiple Conditions)
```python
import pandas as pd
from arboreto.algo import grnboost2

if __name__ == '__main__':
    # Infer networks for different conditions
    conditions = ['control', 'treatment_24h', 'treatment_48h']

    for condition in conditions:
        data = pd.read_csv(f'{condition}_expression.tsv', sep='\t')
        network = grnboost2(expression_data=data, seed=42)
        network.to_csv(f'{condition}_network.tsv', sep='\t', index=False)
```

## Output Interpretation

Arboreto returns a DataFrame with regulatory links:

| Column | Description |
|--------|-------------|
| `TF` | Transcription factor (regulator) |
| `target` | Target gene |
| `importance` | Regulatory importance score (higher = stronger) |

**Filtering strategy** (the first two are sketched below):
- Top N links per target gene
- Importance threshold (e.g., > 0.5)
- Statistical significance testing (permutation tests)
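The first two strategies map directly onto pandas operations; a minimal sketch (file name and cutoffs are illustrative):
```python
import pandas as pd

# Load an inferred network (TF, target, importance)
network = pd.read_csv('tf_target_network.tsv', sep='\t')

# Top N links per target gene: keep the 5 strongest regulators of each target
top_per_target = (network.sort_values('importance', ascending=False)
                         .groupby('target')
                         .head(5))

# Importance threshold: keep links scoring above a fixed cutoff
high_confidence = network[network['importance'] > 0.5]
```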
## Integration with pySCENIC

Arboreto is a core component of the SCENIC pipeline for single-cell regulatory network analysis:

```python
# Step 1: Use arboreto for GRN inference
from arboreto.algo import grnboost2
network = grnboost2(expression_data=sc_data, tf_names=tf_list)

# Step 2: Use pySCENIC for regulon identification and activity scoring
# (See pySCENIC documentation for downstream analysis)
```

## Reproducibility

Always set a seed for reproducible results:
```python
network = grnboost2(expression_data=matrix, seed=777)
```

Run multiple seeds for robustness analysis:
```python
from distributed import LocalCluster, Client
from arboreto.algo import grnboost2

if __name__ == '__main__':
    client = Client(LocalCluster())

    seeds = [42, 123, 777]
    networks = []
    for seed in seeds:
        net = grnboost2(expression_data=matrix, client_or_address=client, seed=seed)
        networks.append(net)

    # Combine networks and filter consensus links
    consensus = analyze_consensus(networks)
```

## Troubleshooting

**Memory errors**: Reduce dataset size by filtering low-variance genes or use distributed computing

**Slow performance**: Use GRNBoost2 instead of GENIE3, enable distributed client, filter TF list

**Dask errors**: Ensure `if __name__ == '__main__':` guard is present in scripts

**Empty results**: Check data format (genes as columns), verify TF names match gene names
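Note that `analyze_consensus` in the robustness example above is not an arboreto function. One possible implementation, assuming a consensus link is one ranked among the strongest edges in every seeded run (the `top_n` cutoff is an illustrative choice):
```python
import pandas as pd

def analyze_consensus(networks, top_n=10000):
    """Keep links present in the top_n of every network; report mean importance.

    `networks` holds DataFrames with TF/target/importance columns,
    as returned by grnboost2.
    """
    # Strongest top_n edges from each seeded run
    tops = [net.nlargest(top_n, 'importance') for net in networks]

    # Inner-join on (TF, target): only edges found in every run survive
    consensus = tops[0][['TF', 'target']]
    for net in tops[1:]:
        consensus = consensus.merge(net[['TF', 'target']], on=['TF', 'target'])

    # Attach the mean importance across runs
    means = (pd.concat(tops)
               .groupby(['TF', 'target'], as_index=False)['importance']
               .mean())
    return consensus.merge(means, on=['TF', 'target'])
```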

View File: references/algorithms.md (new)

@@ -0,0 +1,138 @@
# GRN Inference Algorithms
Arboreto provides two algorithms for gene regulatory network (GRN) inference, both based on the multiple regression approach.
## Algorithm Overview
Both algorithms follow the same inference strategy:
1. For each target gene in the dataset, train a regression model
2. Identify the most important features (potential regulators) from the model
3. Emit these features as candidate regulators with importance scores
The key difference is **computational efficiency** and the underlying regression method.
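In outline, the shared strategy looks like the sketch below — a conceptual illustration using scikit-learn directly, not arboreto's internal code:
```python
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor

def infer_links_for_target(expression, target, tf_names):
    """Regress one target gene on candidate TFs; emit importance scores."""
    # Candidate regulators: every TF except the target itself
    regulators = [tf for tf in tf_names if tf != target]

    # Step 1: train a regression model for this target gene
    model = GradientBoostingRegressor(n_estimators=100)
    model.fit(expression[regulators].values, expression[target].values)

    # Steps 2-3: feature importances become candidate TF -> target links
    return pd.DataFrame({'TF': regulators,
                         'target': target,
                         'importance': model.feature_importances_})
```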
## GRNBoost2 (Recommended)
**Purpose**: Fast GRN inference for large-scale datasets using gradient boosting.
### When to Use
- **Large datasets**: Tens of thousands of observations (e.g., single-cell RNA-seq)
- **Time-constrained analysis**: Need faster results than GENIE3
- **Default choice**: GRNBoost2 is the flagship algorithm and recommended for most use cases
### Technical Details
- **Method**: Stochastic gradient boosting with early-stopping regularization
- **Performance**: Significantly faster than GENIE3 on large datasets
- **Output**: Same format as GENIE3 (TF-target-importance triplets)
### Usage
```python
from arboreto.algo import grnboost2
network = grnboost2(
    expression_data=expression_matrix,
    tf_names=tf_names,
    seed=42  # For reproducibility
)
```
### Parameters
```python
grnboost2(
    expression_data,            # Required: pandas DataFrame or numpy array
    gene_names=None,            # Required for numpy arrays
    tf_names='all',             # List of TF names or 'all'
    verbose=False,              # Print progress messages
    client_or_address='local',  # Dask client or scheduler address
    seed=None                   # Random seed for reproducibility
)
```
## GENIE3
**Purpose**: Classic Random Forest-based GRN inference, serving as the conceptual blueprint.
### When to Use
- **Smaller datasets**: When dataset size allows for longer computation
- **Comparison studies**: When comparing with published GENIE3 results
- **Validation**: To validate GRNBoost2 results
### Technical Details
- **Method**: Random Forest or ExtraTrees regression
- **Foundation**: Original multiple regression GRN inference strategy
- **Trade-off**: More computationally expensive but well-established
### Usage
```python
from arboreto.algo import genie3
network = genie3(
    expression_data=expression_matrix,
    tf_names=tf_names,
    seed=42
)
```
### Parameters
```python
genie3(
    expression_data,            # Required: pandas DataFrame or numpy array
    gene_names=None,            # Required for numpy arrays
    tf_names='all',             # List of TF names or 'all'
    verbose=False,              # Print progress messages
    client_or_address='local',  # Dask client or scheduler address
    seed=None                   # Random seed for reproducibility
)
```
## Algorithm Comparison
| Feature | GRNBoost2 | GENIE3 |
|---------|-----------|--------|
| **Speed** | Fast (optimized for large data) | Slower |
| **Method** | Gradient boosting | Random Forest |
| **Best for** | Large-scale data (10k+ observations) | Small-medium datasets |
| **Output format** | Same | Same |
| **Inference strategy** | Multiple regression | Multiple regression |
| **Recommended** | Yes (default choice) | For comparison/validation |
## Advanced: Custom Regressor Parameters
For advanced users, pass custom scikit-learn regressor parameters:
```python
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor

# Custom GRNBoost2 parameters
custom_grnboost2 = grnboost2(
    expression_data=expression_matrix,
    regressor_type='GBM',
    regressor_kwargs={
        'n_estimators': 100,
        'max_depth': 5,
        'learning_rate': 0.1
    }
)

# Custom GENIE3 parameters
custom_genie3 = genie3(
    expression_data=expression_matrix,
    regressor_type='RF',
    regressor_kwargs={
        'n_estimators': 1000,
        'max_features': 'sqrt'
    }
)
```
## Choosing the Right Algorithm
**Decision guide**:
1. **Start with GRNBoost2** - It's faster and handles large datasets better
2. **Use GENIE3 if**:
- Comparing with existing GENIE3 publications
- Dataset is small-medium sized
- Validating GRNBoost2 results
Both algorithms produce comparable regulatory networks with the same output format, making them interchangeable for most analyses.

View File: references/api_reference.md (deleted)

@@ -1,271 +0,0 @@
# Arboreto API Reference
This document provides comprehensive API documentation for the arboreto package, a Python library for gene regulatory network (GRN) inference.
## Overview
Arboreto enables inference of gene regulatory networks from expression data using machine learning algorithms. It supports distributed computing via Dask for scalability from single machines to multi-node clusters.
**Current Version:** 0.1.5
**GitHub:** https://github.com/tmoerman/arboreto
**License:** BSD 3-Clause
## Core Algorithms
### GRNBoost2
The flagship algorithm for fast gene regulatory network inference using stochastic gradient boosting.
**Function:** `arboreto.algo.grnboost2()`
**Parameters:**
- `expression_data` (pandas.DataFrame or numpy.ndarray): Expression matrix where rows are observations (cells/samples) and columns are genes. Required.
- `gene_names` (list, optional): List of gene names matching column order. If None, uses DataFrame column names.
- `tf_names` (list, optional): List of transcription factor names to consider as regulators. If None, all genes are considered potential regulators.
- `seed` (int, optional): Random seed for reproducibility. Recommended when consistent results are needed across runs.
- `client_or_address` (dask.distributed.Client or str, optional): Custom Dask client or scheduler address for distributed computing. If None, creates a default local client.
- `verbose` (bool, optional): Enable verbose output for debugging.
**Returns:**
- pandas.DataFrame with columns `['TF', 'target', 'importance']` representing inferred regulatory links. Each row represents a regulatory relationship with an importance score.
**Algorithm Details:**
- Uses stochastic gradient boosting with early-stopping regularization
- Much faster than GENIE3, especially for large datasets (tens of thousands of observations)
- Extracts important features from trained regression models to identify regulatory relationships
- Recommended as the default choice for most use cases
**Example:**
```python
from arboreto.algo import grnboost2
import pandas as pd

# Load expression data
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
tf_list = ['TF1', 'TF2', 'TF3']  # Optional: specify TFs

# Run inference
network = grnboost2(
    expression_data=expression_matrix,
    tf_names=tf_list,
    seed=42  # For reproducibility
)

# Save results
network.to_csv('output_network.tsv', sep='\t', index=False)
```
### GENIE3
Classical gene regulatory network inference using Random Forest regression.
**Function:** `arboreto.algo.genie3()`
**Parameters:**
Same as GRNBoost2 (see above).
**Returns:**
Same format as GRNBoost2 (see above).
**Algorithm Details:**
- Uses Random Forest or ExtraTrees regression models
- Blueprint for multiple regression GRN inference strategy
- More computationally expensive than GRNBoost2
- Better suited for smaller datasets or when maximum accuracy is needed
**When to Use GENIE3 vs GRNBoost2:**
- **Use GRNBoost2:** For large datasets, faster results, or when computational resources are limited
- **Use GENIE3:** For smaller datasets, when following established protocols, or for comparison with published results
## Module Structure
### arboreto.algo
Primary module for typical users. Contains high-level inference functions.
**Main Functions:**
- `grnboost2()` - Fast GRN inference using gradient boosting
- `genie3()` - Classical GRN inference using Random Forest
### arboreto.core
Advanced module for power users. Contains low-level framework components for custom implementations.
**Use cases:**
- Custom inference pipelines
- Algorithm modifications
- Performance tuning
### arboreto.utils
Utility functions for common data processing tasks.
**Key Functions:**
- `load_tf_names(filename)` - Load transcription factor names from file
- Reads a text file with one TF name per line
- Returns a list of TF names
- Example: `tf_names = load_tf_names('transcription_factors.txt')`
## Data Format Requirements
### Input Format
**Expression Matrix:**
- **Format:** pandas DataFrame or numpy ndarray
- **Orientation:** Rows = observations (cells/samples), Columns = genes
- **Convention:** Follows scikit-learn format
- **Gene Names:** Column names (DataFrame) or separate `gene_names` parameter
- **Data Type:** Numeric (float or int)
**Common Mistake:** If data is transposed (genes as rows), use pandas to transpose:
```python
expression_df = pd.read_csv('data.tsv', sep='\t', index_col=0).T
```
**Transcription Factor List:**
- **Format:** Python list of strings or text file (one TF per line)
- **Optional:** If not provided, all genes are considered potential regulators
- **Example:** `['Sox2', 'Oct4', 'Nanog']`
### Output Format
**Network DataFrame:**
- **Columns:**
- `TF` (str): Transcription factor (regulator) gene name
- `target` (str): Target gene name
- `importance` (float): Importance score of the regulatory relationship
- **Interpretation:** Higher importance scores indicate stronger predicted regulatory relationships
- **Sorting:** Typically sorted by importance (descending) for prioritization
**Example Output:**
```
TF target importance
Sox2 Gene1 15.234
Oct4 Gene1 12.456
Sox2 Gene2 8.901
```
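Because the result is an ordinary DataFrame (as in the example above, where `network` is the return value of `grnboost2()`), prioritization is standard pandas:
```python
# Keep the 1000 strongest predicted links
top_links = network.sort_values('importance', ascending=False).head(1000)

# All predicted targets of one regulator
sox2_targets = network[network['TF'] == 'Sox2']
```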
## Distributed Computing with Dask
### Local Execution (Default)
Arboreto automatically creates a local Dask client if none is provided:
```python
network = grnboost2(expression_data=expr_matrix, tf_names=tf_list)
```
### Custom Local Cluster
For better control over resources or multiple inferences:
```python
from dask.distributed import Client, LocalCluster

# Configure cluster
cluster = LocalCluster(
    n_workers=4,
    threads_per_worker=2,
    memory_limit='4GB'
)
client = Client(cluster)

# Run inference
network = grnboost2(
    expression_data=expr_matrix,
    tf_names=tf_list,
    client_or_address=client
)

# Clean up
client.close()
cluster.close()
```
### Distributed Cluster
For multi-node computation:
**On scheduler node:**
```bash
dask-scheduler --no-bokeh # Use --no-bokeh to avoid Bokeh errors
```
**On worker nodes:**
```bash
dask-worker scheduler-address:8786 --local-dir /tmp
```
**In Python script:**
```python
from dask.distributed import Client

client = Client('scheduler-address:8786')

network = grnboost2(
    expression_data=expr_matrix,
    tf_names=tf_list,
    client_or_address=client
)
```
### Dask Dashboard
Monitor computation progress via the Dask dashboard:
```python
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(diagnostics_port=8787)
client = Client(cluster)
# Dashboard available at: http://localhost:8787
```
## Reproducibility
To ensure reproducible results across runs:
```python
network = grnboost2(
    expression_data=expr_matrix,
    tf_names=tf_list,
    seed=42  # Fixed seed ensures identical results
)
```
**Note:** Without a seed parameter, results may vary slightly between runs due to randomness in the algorithms.
## Performance Considerations
### Memory Management
- Expression matrices should fit in memory (RAM)
- For very large datasets, consider:
- Using a machine with more RAM
- Distributing across multiple nodes
- Preprocessing to reduce dimensionality
### Worker Configuration
- **Local execution:** Number of workers = number of CPU cores (default)
- **Custom cluster:** Balance workers and threads based on available resources
- **Distributed execution:** Ensure adequate `local_dir` space on worker nodes
### Algorithm Choice
- **GRNBoost2:** ~10-100x faster than GENIE3 for large datasets
- **GENIE3:** More established but slower, better for small datasets (<10k observations)
## Integration with pySCENIC
Arboreto is a core component of the pySCENIC pipeline for single-cell RNA sequencing analysis:
1. **GRN Inference (Arboreto):** Infer regulatory networks using GRNBoost2
2. **Regulon Prediction:** Prune network and identify regulons
3. **Cell Type Identification:** Score regulons across cells
For pySCENIC workflows, arboreto is typically used in the first step to generate the initial regulatory network.
## Common Issues and Solutions
See the main SKILL.md for troubleshooting guidance.

View File: references/basic_inference.md (new)

@@ -0,0 +1,151 @@
# Basic GRN Inference with Arboreto
## Input Data Requirements
Arboreto requires gene expression data in one of two formats:
### Pandas DataFrame (Recommended)
- **Rows**: Observations (cells, samples, conditions)
- **Columns**: Genes (with gene names as column headers)
- **Format**: Numeric expression values
Example:
```python
import pandas as pd
# Load expression matrix with genes as columns
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
# Columns: ['gene1', 'gene2', 'gene3', ...]
# Rows: observation data
```
### NumPy Array
- **Shape**: (observations, genes)
- **Requirement**: Separately provide gene names list matching column order
Example:
```python
import numpy as np
expression_matrix = np.genfromtxt('expression_data.tsv', delimiter='\t', skip_header=1)
with open('expression_data.tsv') as f:
    gene_names = [gene.strip() for gene in f.readline().split('\t')]

assert expression_matrix.shape[1] == len(gene_names)
```
## Transcription Factors (TFs)
Optionally provide a list of transcription factor names to restrict regulatory inference:
```python
from arboreto.utils import load_tf_names
# Load from file (one TF per line)
tf_names = load_tf_names('transcription_factors.txt')
# Or define directly
tf_names = ['TF1', 'TF2', 'TF3']
```
If not provided, all genes are considered potential regulators.
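Because TF names are matched against the expression matrix's column names, a quick sanity check before inference can catch naming mismatches early. A minimal sketch (file names are illustrative):
```python
import pandas as pd
from arboreto.utils import load_tf_names

expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
tf_names = load_tf_names('transcription_factors.txt')

# TFs missing from the matrix often signal naming or species mismatches
missing = sorted(set(tf_names) - set(expression_matrix.columns))
if missing:
    print(f"{len(missing)} TFs not found in expression data, e.g. {missing[:5]}")
```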
## Basic Inference Workflow
### Using Pandas DataFrame
```python
import pandas as pd
from arboreto.utils import load_tf_names
from arboreto.algo import grnboost2
if __name__ == '__main__':
    # Load expression data
    expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')

    # Load transcription factors (optional)
    tf_names = load_tf_names('tf_list.txt')

    # Run GRN inference
    network = grnboost2(
        expression_data=expression_matrix,
        tf_names=tf_names  # Optional
    )

    # Save results
    network.to_csv('network_output.tsv', sep='\t', index=False, header=False)
```
**Critical**: The `if __name__ == '__main__':` guard is required because Dask spawns new processes internally.
### Using NumPy Array
```python
import numpy as np
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names

if __name__ == '__main__':
    # Load expression matrix
    expression_matrix = np.genfromtxt('expression_data.tsv', delimiter='\t', skip_header=1)

    # Extract gene names from header
    with open('expression_data.tsv') as f:
        gene_names = [gene.strip() for gene in f.readline().split('\t')]

    # Verify dimensions match
    assert expression_matrix.shape[1] == len(gene_names)

    # Load transcription factors (optional)
    tf_names = load_tf_names('tf_list.txt')

    # Run inference with explicit gene names
    network = grnboost2(
        expression_data=expression_matrix,
        gene_names=gene_names,
        tf_names=tf_names
    )

    network.to_csv('network_output.tsv', sep='\t', index=False, header=False)
```
## Output Format
Arboreto returns a Pandas DataFrame with three columns:
| Column | Description |
|--------|-------------|
| `TF` | Transcription factor (regulator) gene name |
| `target` | Target gene name |
| `importance` | Regulatory importance score (higher = stronger regulation) |
Example output:
```
TF1 gene5 0.856
TF2 gene12 0.743
TF1 gene8 0.621
```
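Note that the workflows above write the network with `header=False`; when reading such a file back, supply the column names explicitly. A small sketch under that assumption:
```python
import pandas as pd

# Re-load a headerless network file produced by the workflow above
network = pd.read_csv('network_output.tsv', sep='\t',
                      names=['TF', 'target', 'importance'])
```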
## Setting Random Seed
For reproducible results, provide a seed parameter:
```python
network = grnboost2(
    expression_data=expression_matrix,
    tf_names=tf_names,
    seed=777
)
```
## Algorithm Selection
Use `grnboost2()` for most cases (faster, handles large datasets):
```python
from arboreto.algo import grnboost2
network = grnboost2(expression_data=expression_matrix)
```
Use `genie3()` for comparison or specific requirements:
```python
from arboreto.algo import genie3
network = genie3(expression_data=expression_matrix)
```
See `references/algorithms.md` for detailed algorithm comparison.

View File: references/distributed_computing.md (new)

@@ -0,0 +1,242 @@
# Distributed Computing with Arboreto
Arboreto leverages Dask for parallelized computation, enabling efficient GRN inference from single-machine multi-core processing to multi-node cluster environments.
## Computation Architecture
GRN inference is inherently parallelizable:
- Each target gene's regression model can be trained independently
- Arboreto represents computation as a Dask task graph
- Tasks are distributed across available computational resources
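To make that structure concrete, the sketch below expresses per-target parallelism with `dask.delayed`. It mirrors the idea only — arboreto builds its task graph internally, and the correlation scoring here is a stand-in for the real regression step:
```python
import dask
import pandas as pd

@dask.delayed
def score_target(expression, target):
    """One independent unit of work: score regulators for a single target."""
    # Stand-in scoring: absolute correlation with the target gene
    scores = expression.drop(columns=[target]).corrwith(expression[target]).abs()
    return pd.DataFrame({'TF': scores.index,
                         'target': target,
                         'importance': scores.values})

if __name__ == '__main__':
    expression = pd.read_csv('expression_data.tsv', sep='\t')

    # One task per target gene; Dask schedules them across workers
    tasks = [score_target(expression, gene) for gene in expression.columns]
    network = pd.concat(dask.compute(*tasks), ignore_index=True)
```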
## Local Multi-Core Processing (Default)
By default, arboreto uses all available CPU cores on the local machine:
```python
from arboreto.algo import grnboost2
# Automatically uses all local cores
network = grnboost2(expression_data=expression_matrix, tf_names=tf_names)
```
This is sufficient for most use cases and requires no additional configuration.
## Custom Local Dask Client
For fine-grained control over local resources, create a custom Dask client:
```python
from distributed import LocalCluster, Client
from arboreto.algo import grnboost2
if __name__ == '__main__':
    # Configure local cluster
    local_cluster = LocalCluster(
        n_workers=10,          # Number of worker processes
        threads_per_worker=1,  # Threads per worker
        memory_limit='8GB'     # Memory limit per worker
    )

    # Create client
    custom_client = Client(local_cluster)

    # Run inference with custom client
    network = grnboost2(
        expression_data=expression_matrix,
        tf_names=tf_names,
        client_or_address=custom_client
    )

    # Clean up
    custom_client.close()
    local_cluster.close()
```
### Benefits of Custom Client
- **Resource control**: Limit CPU and memory usage
- **Multiple runs**: Reuse same client for different parameter sets
- **Monitoring**: Access Dask dashboard for performance insights
## Multiple Inference Runs with Same Client
Reuse a single Dask client for multiple inference runs with different parameters:
```python
from distributed import LocalCluster, Client
from arboreto.algo import grnboost2
if __name__ == '__main__':
    # Initialize client once
    local_cluster = LocalCluster(n_workers=8, threads_per_worker=1)
    client = Client(local_cluster)

    # Run multiple inferences
    network_seed1 = grnboost2(
        expression_data=expression_matrix,
        tf_names=tf_names,
        client_or_address=client,
        seed=666
    )

    network_seed2 = grnboost2(
        expression_data=expression_matrix,
        tf_names=tf_names,
        client_or_address=client,
        seed=777
    )

    # Different algorithms with same client
    from arboreto.algo import genie3
    network_genie3 = genie3(
        expression_data=expression_matrix,
        tf_names=tf_names,
        client_or_address=client
    )

    # Clean up once
    client.close()
    local_cluster.close()
```
## Distributed Cluster Computing
For very large datasets, connect to a remote Dask distributed scheduler running on a cluster:
### Step 1: Set Up Dask Scheduler (on cluster head node)
```bash
dask-scheduler
# Output: Scheduler at tcp://10.118.224.134:8786
```
### Step 2: Start Dask Workers (on cluster compute nodes)
```bash
dask-worker tcp://10.118.224.134:8786
```
### Step 3: Connect from Client
```python
from distributed import Client
from arboreto.algo import grnboost2
if __name__ == '__main__':
    # Connect to remote scheduler
    scheduler_address = 'tcp://10.118.224.134:8786'
    cluster_client = Client(scheduler_address)

    # Run inference on cluster
    network = grnboost2(
        expression_data=expression_matrix,
        tf_names=tf_names,
        client_or_address=cluster_client
    )

    cluster_client.close()
```
### Cluster Configuration Best Practices
**Worker configuration**:
```bash
# 4 processes per node, 1 thread per process, 16GB memory per process
dask-worker tcp://scheduler:8786 \
    --nprocs 4 \
    --nthreads 1 \
    --memory-limit 16GB
```
**For large-scale inference**:
- Use more workers with moderate memory rather than fewer workers with large memory
- Set `threads_per_worker=1` to avoid GIL contention in scikit-learn
- Monitor memory usage to prevent workers from being killed
## Monitoring and Debugging
### Dask Dashboard
Access the Dask dashboard for real-time monitoring:
```python
from distributed import Client
client = Client() # Prints dashboard URL
# Dashboard available at: http://localhost:8787/status
```
The dashboard shows:
- **Task progress**: Number of tasks completed/pending
- **Resource usage**: CPU, memory per worker
- **Task stream**: Real-time visualization of computation
- **Performance**: Bottleneck identification
### Verbose Output
Enable verbose logging to track inference progress:
```python
network = grnboost2(
    expression_data=expression_matrix,
    tf_names=tf_names,
    verbose=True
)
```
## Performance Optimization Tips
### 1. Data Format
- **Use Pandas DataFrame when possible**: More efficient than NumPy for Dask operations
- **Reduce data size**: Filter low-variance genes before inference
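For the data-size point above, a common preprocessing step is dropping near-constant genes before inference (the variance cutoff is illustrative):
```python
import pandas as pd

expression = pd.read_csv('expression_data.tsv', sep='\t')

# Keep only genes whose variance across observations exceeds a small threshold
variances = expression.var(axis=0)
filtered = expression.loc[:, variances > 0.1]
print(f"Kept {filtered.shape[1]} of {expression.shape[1]} genes")
```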
### 2. Worker Configuration
- **CPU-bound tasks**: Set `threads_per_worker=1`, increase `n_workers`
- **Memory-bound tasks**: Increase `memory_limit` per worker
### 3. Cluster Setup
- **Network**: Ensure high-bandwidth, low-latency network between nodes
- **Storage**: Use shared filesystem or object storage for large datasets
- **Scheduling**: Allocate dedicated nodes to avoid resource contention
### 4. Transcription Factor Filtering
- **Limit TF list**: Providing specific TF names reduces computation
```python
# Full search (slow)
network = grnboost2(expression_data=matrix)
# Filtered search (faster)
network = grnboost2(expression_data=matrix, tf_names=known_tfs)
```
## Example: Large-Scale Single-Cell Analysis
Complete workflow for processing single-cell RNA-seq data on a cluster:
```python
from distributed import Client
from arboreto.algo import grnboost2
import pandas as pd
if __name__ == '__main__':
    # Connect to cluster
    client = Client('tcp://cluster-scheduler:8786')

    # Load large single-cell dataset (50,000 cells x 20,000 genes)
    expression_data = pd.read_csv('scrnaseq_data.tsv', sep='\t')

    # Load cell-type-specific TFs
    tf_names = pd.read_csv('tf_list.txt', header=None)[0].tolist()

    # Run distributed inference
    network = grnboost2(
        expression_data=expression_data,
        tf_names=tf_names,
        client_or_address=client,
        verbose=True,
        seed=42
    )

    # Save results
    network.to_csv('grn_results.tsv', sep='\t', index=False)

    client.close()
```
This approach enables analysis of datasets that would be impractical on a single machine.

View File: scripts/basic_grn_inference.py (modified)

@@ -1,18 +1,18 @@
#!/usr/bin/env python3
"""
Basic GRN inference example using Arboreto.

This script demonstrates the standard workflow for inferring gene regulatory
networks from expression data using GRNBoost2.

Usage:
    python basic_grn_inference.py <expression_file> <output_file> [--tf-file TF_FILE] [--seed SEED]

Arguments:
    expression_file: Path to expression matrix (TSV format, genes as columns)
    output_file: Path for output network (TSV format)
    --tf-file: Optional path to transcription factors file (one per line)
    --seed: Random seed for reproducibility (default: 777)
"""
import argparse
@@ -21,90 +21,77 @@ from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names


def run_grn_inference(expression_file, output_file, tf_file=None, seed=777):
    """
    Run GRN inference using GRNBoost2.

    Args:
        expression_file: Path to expression matrix TSV file
        output_file: Path for output network file
        tf_file: Optional path to TF names file
        seed: Random seed for reproducibility
    """
    print(f"Loading expression data from {expression_file}...")
    expression_data = pd.read_csv(expression_file, sep='\t')
    print(f"Expression matrix shape: {expression_data.shape}")
    print(f"Number of genes: {expression_data.shape[1]}")
    print(f"Number of observations: {expression_data.shape[0]}")

    # Load TF names if provided
    tf_names = 'all'
    if tf_file:
        print(f"Loading transcription factors from {tf_file}...")
        tf_names = load_tf_names(tf_file)
        print(f"Number of TFs: {len(tf_names)}")

    # Run GRN inference
    print(f"Running GRNBoost2 with seed={seed}...")
    network = grnboost2(
        expression_data=expression_data,
        tf_names=tf_names,
        seed=seed,
        verbose=True
    )

    # Save results
    print(f"Saving network to {output_file}...")
    network.to_csv(output_file, sep='\t', index=False, header=False)
    print(f"Done! Network contains {len(network)} regulatory links.")
    print("\nTop 10 regulatory links:")
    print(network.head(10).to_string(index=False))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Infer gene regulatory network using GRNBoost2'
    )
    parser.add_argument(
        'expression_file',
        help='Path to expression matrix (TSV format, genes as columns)'
    )
    parser.add_argument(
        'output_file',
        help='Path for output network (TSV format)'
    )
    parser.add_argument(
        '--tf-file',
        help='Path to transcription factors file (one per line)',
        default=None
    )
    parser.add_argument(
        '--seed',
        help='Random seed for reproducibility (default: 777)',
        type=int,
        default=777
    )
    args = parser.parse_args()

    run_grn_inference(
        expression_file=args.expression_file,
        output_file=args.output_file,
        tf_file=args.tf_file,
        seed=args.seed
    )

View File: scripts/compare_algorithms.py (deleted)

@@ -1,205 +0,0 @@
#!/usr/bin/env python3
"""
Compare GRNBoost2 and GENIE3 algorithms on the same dataset.
This script runs both algorithms on the same expression data and compares:
- Runtime
- Number of predicted links
- Top predicted relationships
- Overlap between predictions
Usage:
python compare_algorithms.py <expression_file> [options]
Example:
python compare_algorithms.py expression_data.tsv -t tf_names.txt
"""
import argparse
import time
import pandas as pd
from arboreto.algo import grnboost2, genie3
from arboreto.utils import load_tf_names
def compare_networks(network1, network2, name1, name2, top_n=100):
    """Compare two inferred networks."""
    print(f"\n{'='*60}")
    print("Network Comparison")
    print(f"{'='*60}")

    # Basic statistics
    print(f"\n{name1} Statistics:")
    print(f"  Total links: {len(network1)}")
    print(f"  Unique TFs: {network1['TF'].nunique()}")
    print(f"  Unique targets: {network1['target'].nunique()}")
    print(f"  Importance range: [{network1['importance'].min():.3f}, {network1['importance'].max():.3f}]")

    print(f"\n{name2} Statistics:")
    print(f"  Total links: {len(network2)}")
    print(f"  Unique TFs: {network2['TF'].nunique()}")
    print(f"  Unique targets: {network2['target'].nunique()}")
    print(f"  Importance range: [{network2['importance'].min():.3f}, {network2['importance'].max():.3f}]")

    # Compare top predictions
    print(f"\nTop {top_n} Predictions Overlap:")

    # Create edge sets for top N predictions
    top_edges1 = set(
        zip(network1.head(top_n)['TF'], network1.head(top_n)['target'])
    )
    top_edges2 = set(
        zip(network2.head(top_n)['TF'], network2.head(top_n)['target'])
    )

    # Calculate overlap
    overlap = top_edges1 & top_edges2
    only_net1 = top_edges1 - top_edges2
    only_net2 = top_edges2 - top_edges1
    overlap_pct = (len(overlap) / top_n) * 100

    print(f"  Shared edges: {len(overlap)} ({overlap_pct:.1f}%)")
    print(f"  Only in {name1}: {len(only_net1)}")
    print(f"  Only in {name2}: {len(only_net2)}")

    # Show some example overlapping edges
    if overlap:
        print("\nExample overlapping predictions:")
        for i, (tf, target) in enumerate(list(overlap)[:5], 1):
            print(f"  {i}. {tf} -> {target}")
def main():
    parser = argparse.ArgumentParser(
        description='Compare GRNBoost2 and GENIE3 algorithms'
    )
    parser.add_argument(
        'expression_file',
        help='Path to expression data file (TSV/CSV format)'
    )
    parser.add_argument(
        '-t', '--tf-file',
        help='Path to file containing transcription factor names (one per line)',
        default=None
    )
    parser.add_argument(
        '--grnboost2-output',
        help='Output file path for GRNBoost2 results',
        default='grnboost2_network.tsv'
    )
    parser.add_argument(
        '--genie3-output',
        help='Output file path for GENIE3 results',
        default='genie3_network.tsv'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        help='Random seed for reproducibility',
        default=42
    )
    parser.add_argument(
        '--sep',
        help='Separator for input file (default: tab)',
        default='\t'
    )
    parser.add_argument(
        '--transpose',
        action='store_true',
        help='Transpose the expression matrix (use if genes are rows)'
    )
    parser.add_argument(
        '--top-n',
        type=int,
        help='Number of top predictions to compare (default: 100)',
        default=100
    )
    args = parser.parse_args()

    # Load expression data
    print(f"Loading expression data from {args.expression_file}...")
    expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)

    # Transpose if needed
    if args.transpose:
        print("Transposing expression matrix...")
        expression_data = expression_data.T

    print(f"Expression data shape: {expression_data.shape}")
    print(f"  Observations (rows): {expression_data.shape[0]}")
    print(f"  Genes (columns): {expression_data.shape[1]}")

    # Load TF names if provided
    tf_names = None
    if args.tf_file:
        print(f"Loading transcription factor names from {args.tf_file}...")
        tf_names = load_tf_names(args.tf_file)
        print(f"  Found {len(tf_names)} transcription factors")
    else:
        print("No TF file provided. Using all genes as potential regulators.")

    # Run GRNBoost2
    print("\n" + "="*60)
    print("Running GRNBoost2...")
    print("="*60)
    start_time = time.time()
    grnboost2_network = grnboost2(
        expression_data=expression_data,
        tf_names=tf_names,
        seed=args.seed
    )
    grnboost2_time = time.time() - start_time
    print(f"GRNBoost2 completed in {grnboost2_time:.2f} seconds")

    # Save GRNBoost2 results
    grnboost2_network.to_csv(args.grnboost2_output, sep='\t', index=False)
    print(f"Results saved to {args.grnboost2_output}")

    # Run GENIE3
    print("\n" + "="*60)
    print("Running GENIE3...")
    print("="*60)
    start_time = time.time()
    genie3_network = genie3(
        expression_data=expression_data,
        tf_names=tf_names,
        seed=args.seed
    )
    genie3_time = time.time() - start_time
    print(f"GENIE3 completed in {genie3_time:.2f} seconds")

    # Save GENIE3 results
    genie3_network.to_csv(args.genie3_output, sep='\t', index=False)
    print(f"Results saved to {args.genie3_output}")

    # Compare runtimes
    print("\n" + "="*60)
    print("Runtime Comparison")
    print("="*60)
    print(f"GRNBoost2: {grnboost2_time:.2f} seconds")
    print(f"GENIE3: {genie3_time:.2f} seconds")
    speedup = genie3_time / grnboost2_time
    print(f"Speedup: {speedup:.2f}x (GRNBoost2 is {speedup:.2f}x faster)")

    # Compare networks
    compare_networks(
        grnboost2_network,
        genie3_network,
        "GRNBoost2",
        "GENIE3",
        top_n=args.top_n
    )

    print("\n" + "="*60)
    print("Comparison complete!")
    print("="*60)


if __name__ == '__main__':
    main()

View File: scripts/distributed_inference.py (deleted)

@@ -1,157 +0,0 @@
#!/usr/bin/env python3
"""
Distributed GRN inference script using arboreto with custom Dask configuration.
This script demonstrates how to use arboreto with a custom Dask LocalCluster
for better control over computational resources.
Usage:
python distributed_inference.py <expression_file> [options]
Example:
python distributed_inference.py expression_data.tsv -t tf_names.txt -w 8 -m 4GB
"""
import argparse
import pandas as pd
from dask.distributed import Client, LocalCluster
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def main():
    parser = argparse.ArgumentParser(
        description='Distributed GRN inference using GRNBoost2 with custom Dask cluster'
    )
    parser.add_argument(
        'expression_file',
        help='Path to expression data file (TSV/CSV format)'
    )
    parser.add_argument(
        '-t', '--tf-file',
        help='Path to file containing transcription factor names (one per line)',
        default=None
    )
    parser.add_argument(
        '-o', '--output',
        help='Output file path for network results',
        default='network_output.tsv'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        help='Random seed for reproducibility',
        default=42
    )
    parser.add_argument(
        '-w', '--workers',
        type=int,
        help='Number of Dask workers',
        default=4
    )
    parser.add_argument(
        '-m', '--memory-limit',
        help='Memory limit per worker (e.g., "4GB", "2000MB")',
        default='4GB'
    )
    parser.add_argument(
        '--threads',
        type=int,
        help='Threads per worker',
        default=2
    )
    parser.add_argument(
        '--dashboard-port',
        type=int,
        help='Port for Dask dashboard (default: 8787)',
        default=8787
    )
    parser.add_argument(
        '--sep',
        help='Separator for input file (default: tab)',
        default='\t'
    )
    parser.add_argument(
        '--transpose',
        action='store_true',
        help='Transpose the expression matrix (use if genes are rows)'
    )
    args = parser.parse_args()

    # Load expression data
    print(f"Loading expression data from {args.expression_file}...")
    expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)

    # Transpose if needed
    if args.transpose:
        print("Transposing expression matrix...")
        expression_data = expression_data.T

    print(f"Expression data shape: {expression_data.shape}")
    print(f"  Observations (rows): {expression_data.shape[0]}")
    print(f"  Genes (columns): {expression_data.shape[1]}")

    # Load TF names if provided
    tf_names = None
    if args.tf_file:
        print(f"Loading transcription factor names from {args.tf_file}...")
        tf_names = load_tf_names(args.tf_file)
        print(f"  Found {len(tf_names)} transcription factors")
    else:
        print("No TF file provided. Using all genes as potential regulators.")

    # Set up Dask cluster
    print("\nSetting up Dask LocalCluster...")
    print(f"  Workers: {args.workers}")
    print(f"  Threads per worker: {args.threads}")
    print(f"  Memory limit per worker: {args.memory_limit}")
    print(f"  Dashboard: http://localhost:{args.dashboard_port}")

    cluster = LocalCluster(
        n_workers=args.workers,
        threads_per_worker=args.threads,
        memory_limit=args.memory_limit,
        diagnostics_port=args.dashboard_port
    )
    client = Client(cluster)
    print("\nDask cluster ready!")
    print(f"  Dashboard available at: {client.dashboard_link}")

    # Run GRNBoost2
    print("\nRunning GRNBoost2 inference with distributed computation...")
    print("  (Monitor progress via the Dask dashboard)")
    try:
        network = grnboost2(
            expression_data=expression_data,
            tf_names=tf_names,
            seed=args.seed,
            client_or_address=client
        )

        print("\nInference complete!")
        print(f"  Total regulatory links inferred: {len(network)}")
        print(f"  Unique TFs: {network['TF'].nunique()}")
        print(f"  Unique targets: {network['target'].nunique()}")

        # Save results
        print(f"\nSaving results to {args.output}...")
        network.to_csv(args.output, sep='\t', index=False)

        # Display top 10 predictions
        print("\nTop 10 predicted regulatory relationships:")
        print(network.head(10).to_string(index=False))
        print("\nDone!")
    finally:
        # Clean up Dask resources
        print("\nClosing Dask cluster...")
        client.close()
        cluster.close()


if __name__ == '__main__':
    main()