Improve the arboreto skill

This commit is contained in:
Timothy Kassis
2025-11-03 16:21:26 -08:00
parent 6ddea4786e
commit 537edff2a1
8 changed files with 758 additions and 1038 deletions

View File

@@ -1,18 +1,18 @@
#!/usr/bin/env python3
"""
Basic GRN inference script using arboreto GRNBoost2.
Basic GRN inference example using Arboreto.
This script demonstrates the standard workflow for gene regulatory network inference:
1. Load expression data
2. Optionally load transcription factor names
3. Run GRNBoost2 inference
4. Save results
This script demonstrates the standard workflow for inferring gene regulatory
networks from expression data using GRNBoost2.
Usage:
python basic_grn_inference.py <expression_file> [options]
python basic_grn_inference.py <expression_file> <output_file> [--tf-file TF_FILE] [--seed SEED]
Example:
python basic_grn_inference.py expression_data.tsv -t tf_names.txt -o network.tsv
Arguments:
expression_file: Path to expression matrix (TSV format, genes as columns)
output_file: Path for output network (TSV format)
--tf-file: Optional path to transcription factors file (one per line)
--seed: Random seed for reproducibility (default: 777)
"""
import argparse
@@ -21,90 +21,77 @@ from arboreto.algo import grnboost2
from arboreto.utils import load_tf_names
def main():
def run_grn_inference(expression_file, output_file, tf_file=None, seed=777):
"""
Run GRN inference using GRNBoost2.
Args:
expression_file: Path to expression matrix TSV file
output_file: Path for output network file
tf_file: Optional path to TF names file
seed: Random seed for reproducibility
"""
print(f"Loading expression data from {expression_file}...")
expression_data = pd.read_csv(expression_file, sep='\t')
print(f"Expression matrix shape: {expression_data.shape}")
print(f"Number of genes: {expression_data.shape[1]}")
print(f"Number of observations: {expression_data.shape[0]}")
# Load TF names if provided
tf_names = 'all'
if tf_file:
print(f"Loading transcription factors from {tf_file}...")
tf_names = load_tf_names(tf_file)
print(f"Number of TFs: {len(tf_names)}")
# Run GRN inference
print(f"Running GRNBoost2 with seed={seed}...")
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
seed=seed,
verbose=True
)
# Save results
print(f"Saving network to {output_file}...")
network.to_csv(output_file, sep='\t', index=False, header=False)
print(f"Done! Network contains {len(network)} regulatory links.")
print(f"\nTop 10 regulatory links:")
print(network.head(10).to_string(index=False))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Infer gene regulatory network using GRNBoost2'
)
parser.add_argument(
'expression_file',
help='Path to expression data file (TSV/CSV format)'
help='Path to expression matrix (TSV format, genes as columns)'
)
parser.add_argument(
'-t', '--tf-file',
help='Path to file containing transcription factor names (one per line)',
'output_file',
help='Path for output network (TSV format)'
)
parser.add_argument(
'--tf-file',
help='Path to transcription factors file (one per line)',
default=None
)
parser.add_argument(
'-o', '--output',
help='Output file path for network results',
default='network_output.tsv'
)
parser.add_argument(
'-s', '--seed',
'--seed',
help='Random seed for reproducibility (default: 777)',
type=int,
help='Random seed for reproducibility',
default=42
)
parser.add_argument(
'--sep',
help='Separator for input file (default: tab)',
default='\t'
)
parser.add_argument(
'--transpose',
action='store_true',
help='Transpose the expression matrix (use if genes are rows)'
default=777
)
args = parser.parse_args()
# Load expression data
print(f"Loading expression data from {args.expression_file}...")
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
# Transpose if needed
if args.transpose:
print("Transposing expression matrix...")
expression_data = expression_data.T
print(f"Expression data shape: {expression_data.shape}")
print(f" Observations (rows): {expression_data.shape[0]}")
print(f" Genes (columns): {expression_data.shape[1]}")
# Load TF names if provided
tf_names = None
if args.tf_file:
print(f"Loading transcription factor names from {args.tf_file}...")
tf_names = load_tf_names(args.tf_file)
print(f" Found {len(tf_names)} transcription factors")
else:
print("No TF file provided. Using all genes as potential regulators.")
# Run GRNBoost2
print("\nRunning GRNBoost2 inference...")
print(" (This may take a while depending on dataset size)")
network = grnboost2(
expression_data=expression_data,
tf_names=tf_names,
run_grn_inference(
expression_file=args.expression_file,
output_file=args.output_file,
tf_file=args.tf_file,
seed=args.seed
)
print(f"\nInference complete!")
print(f" Total regulatory links inferred: {len(network)}")
print(f" Unique TFs: {network['TF'].nunique()}")
print(f" Unique targets: {network['target'].nunique()}")
# Save results
print(f"\nSaving results to {args.output}...")
network.to_csv(args.output, sep='\t', index=False)
# Display top 10 predictions
print("\nTop 10 predicted regulatory relationships:")
print(network.head(10).to_string(index=False))
print("\nDone!")
if __name__ == '__main__':
main()