Improve the arboreto skill

This commit is contained in:
Timothy Kassis
2025-11-03 16:21:26 -08:00
parent 6ddea4786e
commit 537edff2a1
8 changed files with 758 additions and 1038 deletions

View File

@@ -1,18 +1,18 @@
#!/usr/bin/env python3
"""
Basic GRN inference script using arboreto GRNBoost2.
Basic GRN inference example using Arboreto.
This script demonstrates the standard workflow for gene regulatory network inference:
1. Load expression data
2. Optionally load transcription factor names
3. Run GRNBoost2 inference
4. Save results
This script demonstrates the standard workflow for inferring gene regulatory
networks from expression data using GRNBoost2.
Usage:
python basic_grn_inference.py <expression_file> [options]
python basic_grn_inference.py <expression_file> <output_file> [--tf-file TF_FILE] [--seed SEED]
Example:
python basic_grn_inference.py expression_data.tsv -t tf_names.txt -o network.tsv
Arguments:
expression_file: Path to expression matrix (TSV format, genes as columns)
output_file: Path for output network (TSV format)
--tf-file: Optional path to transcription factors file (one per line)
--seed: Random seed for reproducibility (default: 777)
"""
import argparse
@@ -21,90 +21,77 @@ from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def main():
def run_grn_inference(expression_file, output_file, tf_file=None, seed=777):
    """
    Infer a gene regulatory network with GRNBoost2 and write it to disk.

    Progress messages and a preview of the first ten links are printed to
    stdout; nothing is returned.

    Args:
        expression_file: Path to a tab-separated expression matrix
            (rows are reported as observations, columns as genes).
        output_file: Destination path; the network is written tab-separated
            without a header row and without the index.
        tf_file: Optional path to a file of transcription factor names.
            When omitted (or otherwise falsy), 'all' is passed to GRNBoost2.
        seed: Random seed forwarded to GRNBoost2.
    """
    print(f"Loading expression data from {expression_file}...")
    matrix = pd.read_csv(expression_file, sep='\t')
    n_observations, n_genes = matrix.shape
    print(f"Expression matrix shape: {matrix.shape}")
    print(f"Number of genes: {n_genes}")
    print(f"Number of observations: {n_observations}")

    # Candidate regulators: an explicit TF list if a file was given, else 'all'.
    if not tf_file:
        regulators = 'all'
    else:
        print(f"Loading transcription factors from {tf_file}...")
        regulators = load_tf_names(tf_file)
        print(f"Number of TFs: {len(regulators)}")

    # Inference step.
    print(f"Running GRNBoost2 with seed={seed}...")
    inference_kwargs = {
        'expression_data': matrix,
        'tf_names': regulators,
        'seed': seed,
        'verbose': True,
    }
    inferred = grnboost2(**inference_kwargs)

    # Persist the result, then summarise it on stdout.
    print(f"Saving network to {output_file}...")
    inferred.to_csv(output_file, sep='\t', index=False, header=False)
    link_count = len(inferred)
    print(f"Done! Network contains {link_count} regulatory links.")
    print(f"\nTop 10 regulatory links:")
    preview = inferred.head(10)
    print(preview.to_string(index=False))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Infer gene regulatory network using GRNBoost2'
)
parser.add_argument(
'expression_file',
help='Path to expression data file (TSV/CSV format)'
help='Path to expression matrix (TSV format, genes as columns)'
)
parser.add_argument(
'-t', '--tf-file',
help='Path to file containing transcription factor names (one per line)',
'output_file',
help='Path for output network (TSV format)'
)
parser.add_argument(
'--tf-file',
help='Path to transcription factors file (one per line)',
default=None
)
parser.add_argument(
'-o', '--output',
help='Output file path for network results',
default='network_output.tsv'
)
parser.add_argument(
'-s', '--seed',
'--seed',
help='Random seed for reproducibility (default: 777)',
type=int,
help='Random seed for reproducibility',
default=42
)
parser.add_argument(
'--sep',
help='Separator for input file (default: tab)',
default='\t'
)
parser.add_argument(
'--transpose',
action='store_true',
help='Transpose the expression matrix (use if genes are rows)'
default=777
)
args = parser.parse_args()
# Load expression data
print(f"Loading expression data from {args.expression_file}...")
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
# Transpose if needed
if args.transpose:
print("Transposing expression matrix...")
expression_data = expression_data.T
print(f"Expression data shape: {expression_data.shape}")
print(f" Observations (rows): {expression_data.shape[0]}")
print(f" Genes (columns): {expression_data.shape[1]}")
# Load TF names if provided
tf_names = None
if args.tf_file:
print(f"Loading transcription factor names from {args.tf_file}...")
tf_names = load_tf_names(args.tf_file)
print(f" Found {len(tf_names)} transcription factors")
else:
print("No TF file provided. Using all genes as potential regulators.")
# Run GRNBoost2
print("\nRunning GRNBoost2 inference...")
print(" (This may take a while depending on dataset size)")
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
run_grn_inference(
expression_file=args.expression_file,
output_file=args.output_file,
tf_file=args.tf_file,
seed=args.seed
)
print(f"\nInference complete!")
print(f" Total regulatory links inferred: {len(network)}")
print(f" Unique TFs: {network['TF'].nunique()}")
print(f" Unique targets: {network['target'].nunique()}")
# Save results
print(f"\nSaving results to {args.output}...")
network.to_csv(args.output, sep='\t', index=False)
# Display top 10 predictions
print("\nTop 10 predicted regulatory relationships:")
print(network.head(10).to_string(index=False))
print("\nDone!")
if __name__ == '__main__':
main()

View File

@@ -1,205 +0,0 @@
#!/usr/bin/env python3
"""
Compare GRNBoost2 and GENIE3 algorithms on the same dataset.
This script runs both algorithms on the same expression data and compares:
- Runtime
- Number of predicted links
- Top predicted relationships
- Overlap between predictions
Usage:
python compare_algorithms.py <expression_file> [options]
Example:
python compare_algorithms.py expression_data.tsv -t tf_names.txt
"""
import argparse
import time
import pandas as pd
from arboreto.algo import grnboost2, genie3
from arboreto.utils import load_tf_names
def _print_network_stats(name, network):
    """Print link/TF/target counts and the importance range of one network."""
    importance = network['importance']
    print(f"\n{name} Statistics:")
    print(f" Total links: {len(network)}")
    print(f" Unique TFs: {network['TF'].nunique()}")
    print(f" Unique targets: {network['target'].nunique()}")
    print(f" Importance range: [{importance.min():.3f}, {importance.max():.3f}]")


def compare_networks(network1, network2, name1, name2, top_n=100):
    """Print summary statistics for two networks and the overlap of their top edges.

    Args:
        network1: DataFrame with 'TF', 'target' and 'importance' columns.
        network2: DataFrame with the same columns as ``network1``.
        name1: Display name used for ``network1`` in the report.
        name2: Display name used for ``network2`` in the report.
        top_n: Number of leading rows of each network to compare.

    Returns:
        None. The report is written to stdout.

    Note:
        The "top" predictions are simply the first ``top_n`` rows of each
        frame, so this presumes the rows are already ordered by descending
        importance (TODO confirm for inputs that are not raw arboreto output).
    """
    print(f"\n{'='*60}")
    print("Network Comparison")
    print(f"{'='*60}")

    # Basic statistics
    _print_network_stats(name1, network1)
    _print_network_stats(name2, network2)

    # Compare top predictions
    print(f"\nTop {top_n} Predictions Overlap:")

    # Slice each network once; keep network1's edges in rank order so the
    # example listing below is reproducible between runs.
    top1 = network1.head(top_n)
    top2 = network2.head(top_n)
    ranked_edges1 = list(zip(top1['TF'], top1['target']))
    top_edges1 = set(ranked_edges1)
    top_edges2 = set(zip(top2['TF'], top2['target']))

    # Calculate overlap
    overlap = top_edges1 & top_edges2
    only_net1 = top_edges1 - top_edges2
    only_net2 = top_edges2 - top_edges1

    # Normalise by the number of edges actually compared rather than by top_n:
    # this equals top_n whenever a network supplies top_n distinct edges, but
    # stays meaningful for shorter networks and cannot divide by zero.
    compared = max(len(top_edges1), len(top_edges2))
    overlap_pct = (len(overlap) / compared) * 100 if compared else 0.0

    print(f" Shared edges: {len(overlap)} ({overlap_pct:.1f}%)")
    print(f" Only in {name1}: {len(only_net1)}")
    print(f" Only in {name2}: {len(only_net2)}")

    # Show some example overlapping edges, highest-ranked in network1 first
    # (set iteration order would differ from run to run).
    if overlap:
        shared_ranked = list(
            dict.fromkeys(edge for edge in ranked_edges1 if edge in overlap)
        )
        print("\nExample overlapping predictions:")
        for i, (tf, target) in enumerate(shared_ranked[:5], 1):
            print(f" {i}. {tf} -> {target}")
def _build_compare_parser():
    """Build the command-line parser for the GRNBoost2 vs GENIE3 comparison."""
    parser = argparse.ArgumentParser(
        description='Compare GRNBoost2 and GENIE3 algorithms'
    )
    parser.add_argument(
        'expression_file',
        help='Path to expression data file (TSV/CSV format)'
    )
    parser.add_argument(
        '-t', '--tf-file',
        help='Path to file containing transcription factor names (one per line)',
        default=None
    )
    parser.add_argument(
        '--grnboost2-output',
        help='Output file path for GRNBoost2 results',
        default='grnboost2_network.tsv'
    )
    parser.add_argument(
        '--genie3-output',
        help='Output file path for GENIE3 results',
        default='genie3_network.tsv'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        help='Random seed for reproducibility',
        default=42
    )
    parser.add_argument(
        '--sep',
        help='Separator for input file (default: tab)',
        default='\t'
    )
    parser.add_argument(
        '--transpose',
        action='store_true',
        help='Transpose the expression matrix (use if genes are rows)'
    )
    parser.add_argument(
        '--top-n',
        type=int,
        help='Number of top predictions to compare (default: 100)',
        default=100
    )
    return parser


def _timed_inference(label, infer, output_path, **infer_kwargs):
    """Run one inference algorithm, save its network, return (network, seconds)."""
    banner = "=" * 60
    print("\n" + banner)
    print(f"Running {label}...")
    print(banner)

    # perf_counter is monotonic, so the interval cannot be skewed (or turn
    # negative) if the system clock is adjusted while the algorithm runs.
    started = time.perf_counter()
    network = infer(**infer_kwargs)
    elapsed = time.perf_counter() - started
    print(f"{label} completed in {elapsed:.2f} seconds")

    network.to_csv(output_path, sep='\t', index=False)
    print(f"Results saved to {output_path}")
    return network, elapsed


def main(argv=None):
    """Run GRNBoost2 and GENIE3 on one dataset and print a comparison.

    Loads the expression matrix (optionally transposed), optionally loads a
    transcription factor list, runs both algorithms with the same seed, writes
    each network to its own TSV file, and reports runtimes and the overlap of
    the top predictions.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            falls back to ``sys.argv[1:]``.
    """
    args = _build_compare_parser().parse_args(argv)

    # Load expression data
    print(f"Loading expression data from {args.expression_file}...")
    expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)

    # Transpose if needed
    if args.transpose:
        print("Transposing expression matrix...")
        expression_data = expression_data.T

    print(f"Expression data shape: {expression_data.shape}")
    print(f" Observations (rows): {expression_data.shape[0]}")
    print(f" Genes (columns): {expression_data.shape[1]}")

    # Load TF names if provided
    tf_names = None
    if args.tf_file:
        print(f"Loading transcription factor names from {args.tf_file}...")
        tf_names = load_tf_names(args.tf_file)
        print(f" Found {len(tf_names)} transcription factors")
    else:
        print("No TF file provided. Using all genes as potential regulators.")

    # Both algorithms receive identical inputs so the comparison is like-for-like.
    shared_kwargs = {
        'expression_data': expression_data,
        'tf_names': tf_names,
        'seed': args.seed,
    }
    grnboost2_network, grnboost2_time = _timed_inference(
        "GRNBoost2", grnboost2, args.grnboost2_output, **shared_kwargs
    )
    genie3_network, genie3_time = _timed_inference(
        "GENIE3", genie3, args.genie3_output, **shared_kwargs
    )

    # Compare runtimes
    print("\n" + "=" * 60)
    print("Runtime Comparison")
    print("=" * 60)
    print(f"GRNBoost2: {grnboost2_time:.2f} seconds")
    print(f"GENIE3: {genie3_time:.2f} seconds")
    if grnboost2_time > 0 and genie3_time > 0:
        speedup = genie3_time / grnboost2_time
        # Name whichever algorithm was actually faster instead of always
        # crediting GRNBoost2.
        if speedup >= 1:
            print(f"Speedup: {speedup:.2f}x (GRNBoost2 is {speedup:.2f}x faster)")
        else:
            print(f"Speedup: {speedup:.2f}x (GENIE3 is {1 / speedup:.2f}x faster)")
    else:
        # A zero-length interval would make the ratio undefined.
        print("Speedup: n/a (runtime too short to measure)")

    # Compare networks
    compare_networks(
        grnboost2_network,
        genie3_network,
        "GRNBoost2",
        "GENIE3",
        top_n=args.top_n
    )

    print("\n" + "=" * 60)
    print("Comparison complete!")
    print("=" * 60)


if __name__ == '__main__':
    main()

View File

@@ -1,157 +0,0 @@
#!/usr/bin/env python3
"""
Distributed GRN inference script using arboreto with custom Dask configuration.
This script demonstrates how to use arboreto with a custom Dask LocalCluster
for better control over computational resources.
Usage:
python distributed_inference.py <expression_file> [options]
Example:
python distributed_inference.py expression_data.tsv -t tf_names.txt -w 8 -m 4GB
"""
import argparse
import pandas as pd
from dask.distributed import Client, LocalCluster
from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def _build_distributed_parser():
    """Build the command-line parser for the distributed inference script."""
    parser = argparse.ArgumentParser(
        description='Distributed GRN inference using GRNBoost2 with custom Dask cluster'
    )
    parser.add_argument(
        'expression_file',
        help='Path to expression data file (TSV/CSV format)'
    )
    parser.add_argument(
        '-t', '--tf-file',
        help='Path to file containing transcription factor names (one per line)',
        default=None
    )
    parser.add_argument(
        '-o', '--output',
        help='Output file path for network results',
        default='network_output.tsv'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        help='Random seed for reproducibility',
        default=42
    )
    parser.add_argument(
        '-w', '--workers',
        type=int,
        help='Number of Dask workers',
        default=4
    )
    parser.add_argument(
        '-m', '--memory-limit',
        help='Memory limit per worker (e.g., "4GB", "2000MB")',
        default='4GB'
    )
    parser.add_argument(
        '--threads',
        type=int,
        help='Threads per worker',
        default=2
    )
    parser.add_argument(
        '--dashboard-port',
        type=int,
        help='Port for Dask dashboard (default: 8787)',
        default=8787
    )
    parser.add_argument(
        '--sep',
        help='Separator for input file (default: tab)',
        default='\t'
    )
    parser.add_argument(
        '--transpose',
        action='store_true',
        help='Transpose the expression matrix (use if genes are rows)'
    )
    return parser


def main(argv=None):
    """Run GRNBoost2 on a user-configured Dask LocalCluster.

    Loads the expression matrix (optionally transposed), optionally loads a
    transcription factor list, starts a LocalCluster with the requested
    workers/threads/memory limit, runs GRNBoost2 against it, writes the
    network to a TSV file and prints a short summary. The client and cluster
    are always shut down, even when inference fails.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            falls back to ``sys.argv[1:]``.
    """
    args = _build_distributed_parser().parse_args(argv)

    # Load expression data
    print(f"Loading expression data from {args.expression_file}...")
    expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)

    # Transpose if needed
    if args.transpose:
        print("Transposing expression matrix...")
        expression_data = expression_data.T

    print(f"Expression data shape: {expression_data.shape}")
    print(f" Observations (rows): {expression_data.shape[0]}")
    print(f" Genes (columns): {expression_data.shape[1]}")

    # Load TF names if provided
    tf_names = None
    if args.tf_file:
        print(f"Loading transcription factor names from {args.tf_file}...")
        tf_names = load_tf_names(args.tf_file)
        print(f" Found {len(tf_names)} transcription factors")
    else:
        print("No TF file provided. Using all genes as potential regulators.")

    # Set up Dask cluster
    print("\nSetting up Dask LocalCluster...")
    print(f" Workers: {args.workers}")
    print(f" Threads per worker: {args.threads}")
    print(f" Memory limit per worker: {args.memory_limit}")
    print(f" Dashboard: http://localhost:{args.dashboard_port}")

    # `dashboard_address` replaces the legacy `diagnostics_port` keyword,
    # which (as far as the distributed docs indicate) current releases no
    # longer accept. ":<port>" binds the dashboard on that port.
    # Using both objects as context managers guarantees the cluster is torn
    # down even if creating the client itself fails.
    with LocalCluster(
        n_workers=args.workers,
        threads_per_worker=args.threads,
        memory_limit=args.memory_limit,
        dashboard_address=f":{args.dashboard_port}"
    ) as cluster, Client(cluster) as client:
        print("\nDask cluster ready!")
        print(f" Dashboard available at: {client.dashboard_link}")

        # Run GRNBoost2
        print("\nRunning GRNBoost2 inference with distributed computation...")
        print(" (Monitor progress via the Dask dashboard)")

        try:
            network = grnboost2(
                expression_data=expression_data,
                tf_names=tf_names,
                seed=args.seed,
                client_or_address=client
            )

            print("\nInference complete!")
            print(f" Total regulatory links inferred: {len(network)}")
            print(f" Unique TFs: {network['TF'].nunique()}")
            print(f" Unique targets: {network['target'].nunique()}")

            # Save results
            print(f"\nSaving results to {args.output}...")
            network.to_csv(args.output, sep='\t', index=False)

            # Display top 10 predictions
            print("\nTop 10 predicted regulatory relationships:")
            print(network.head(10).to_string(index=False))

            print("\nDone!")
        finally:
            # The enclosing `with` closes the client first, then the cluster.
            print("\nClosing Dask cluster...")


if __name__ == '__main__':
    main()