419 lines
15 KiB
Python
419 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Batch molecular filtering using medchem library.
|
|
|
|
This script provides a production-ready workflow for filtering compound libraries
|
|
using medchem rules, structural alerts, and custom constraints.
|
|
|
|
Usage:
|
|
python filter_molecules.py input.csv --rules rule_of_five,rule_of_cns --alerts nibr --output filtered.csv
|
|
python filter_molecules.py input.sdf --rules rule_of_drug --lilly --complexity 400 --output results.csv
|
|
python filter_molecules.py smiles.txt --nibr --pains --n-jobs -1 --output clean.csv
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Dict, Optional, Tuple
|
|
import json
|
|
|
|
try:
|
|
import pandas as pd
|
|
import datamol as dm
|
|
import medchem as mc
|
|
from rdkit import Chem
|
|
from tqdm import tqdm
|
|
except ImportError as e:
|
|
print(f"Error: Missing required package: {e}")
|
|
print("Install dependencies: pip install medchem datamol pandas tqdm")
|
|
sys.exit(1)
|
|
|
|
|
|
def _parse_smiles_values(values) -> List[Optional[Chem.Mol]]:
    """Convert an iterable of SMILES values to molecules, using None for unusable entries."""
    parsed = []
    for value in tqdm(values, desc="Parsing"):
        # Empty CSV cells are read by pandas as float NaN; only non-blank
        # strings are handed to the SMILES parser, everything else is invalid.
        if isinstance(value, str) and value.strip():
            parsed.append(dm.to_mol(value.strip()))
        else:
            parsed.append(None)
    return parsed


def load_molecules(input_file: Path, smiles_column: str = "smiles") -> Tuple[pd.DataFrame, List[Chem.Mol]]:
    """
    Load molecules from various file formats.

    Supports:
    - CSV/TSV with SMILES column
    - SDF files
    - Plain text files with one SMILES per line

    Args:
        input_file: Path of the file to read; the format is chosen from its suffix.
        smiles_column: Name of the SMILES column (used for CSV/TSV input only;
            SDF and text input always produce a column named "smiles").

    Returns:
        Tuple of (DataFrame with metadata, list of RDKit molecules). Rows whose
        molecule could not be parsed are dropped from both, so the two stay
        aligned row-for-row.

    Exits the process with status 1 when the suffix is unsupported or the
    requested SMILES column is missing.
    """
    suffix = input_file.suffix.lower()

    if suffix == ".sdf":
        print(f"Loading SDF file: {input_file}")
        records = list(Chem.SDMolSupplier(str(input_file)))
        mols = [mol for mol in records if mol is not None]

        # Unreadable records never reach the DataFrame, so report them here;
        # the generic invalid-molecule check below cannot see them.
        n_unreadable = len(records) - len(mols)
        if n_unreadable:
            print(f"Warning: {n_unreadable} unreadable SDF records skipped")

        # Create DataFrame from SDF properties
        data = []
        for mol in mols:
            props = mol.GetPropsAsDict()
            props["smiles"] = Chem.MolToSmiles(mol)
            data.append(props)
        df = pd.DataFrame(data)

    elif suffix in (".csv", ".tsv"):
        print(f"Loading CSV/TSV file: {input_file}")
        sep = "\t" if suffix == ".tsv" else ","
        df = pd.read_csv(input_file, sep=sep)

        if smiles_column not in df.columns:
            print(f"Error: Column '{smiles_column}' not found in file")
            # Column labels are not guaranteed to be strings.
            print(f"Available columns: {', '.join(map(str, df.columns))}")
            sys.exit(1)

        print("Converting SMILES to molecules...")
        mols = _parse_smiles_values(df[smiles_column])

    elif suffix == ".txt":
        print(f"Loading text file: {input_file}")
        with open(input_file, encoding="utf-8") as f:
            smiles_list = [line.strip() for line in f if line.strip()]

        df = pd.DataFrame({"smiles": smiles_list})
        print("Converting SMILES to molecules...")
        mols = _parse_smiles_values(smiles_list)

    else:
        print(f"Error: Unsupported file format: {suffix}")
        print("Supported formats: .csv, .tsv, .sdf, .txt")
        sys.exit(1)

    # Filter out invalid molecules, keeping df and mols aligned
    valid_indices = [i for i, mol in enumerate(mols) if mol is not None]
    if len(valid_indices) < len(mols):
        n_invalid = len(mols) - len(valid_indices)
        print(f"Warning: {n_invalid} invalid molecules removed")
        df = df.iloc[valid_indices].reset_index(drop=True)
        mols = [mols[i] for i in valid_indices]

    print(f"Loaded {len(mols)} valid molecules")
    return df, mols
|
|
|
|
|
|
def apply_rule_filters(mols: List[Chem.Mol], rules: List[str], n_jobs: int) -> pd.DataFrame:
    """
    Run the named medchem rules over the molecules.

    Args:
        mols: Molecules to evaluate.
        rules: Rule names handed to ``mc.rules.RuleFilters``.
        n_jobs: Parallelism setting forwarded to the filter call.

    Returns:
        The filter output as a DataFrame, plus a ``passes_all_rules`` column
        that is the row-wise ``all()`` over every column of that output.
    """
    print(f"\nApplying rule filters: {', '.join(rules)}")

    rule_engine = mc.rules.RuleFilters(rule_list=rules)
    outcome_table = pd.DataFrame(rule_engine(mols=mols, n_jobs=n_jobs, progress=True))

    # Row-wise conjunction across whatever columns the filter returned
    outcome_table["passes_all_rules"] = outcome_table.all(axis=1)
    return outcome_table
|
|
|
|
|
|
def apply_structural_alerts(mols: List[Chem.Mol], alert_type: str, n_jobs: int) -> pd.DataFrame:
    """
    Run one family of structural-alert filters and tabulate the outcome.

    Args:
        mols: Molecules to screen.
        alert_type: One of "common", "nibr", "lilly" or "pains".
        n_jobs: Parallelism setting forwarded to the medchem filter objects
            (not used by the "pains" branch, which loops molecule by molecule).

    Returns:
        A DataFrame whose columns depend on ``alert_type``:
        common -> has_common_alerts, num_common_alerts, common_alert_details;
        nibr -> passes_nibr; lilly -> lilly_demerits, passes_lilly,
        lilly_patterns; pains -> passes_pains.

    Raises:
        ValueError: If ``alert_type`` is not one of the supported values.
    """
    print(f"\nApplying {alert_type} structural alerts...")

    if alert_type == "common":
        common_engine = mc.structural.CommonAlertsFilters()
        hits = common_engine(mols=mols, n_jobs=n_jobs, progress=True)

        flagged = [hit["has_alerts"] for hit in hits]
        counts = [hit["num_alerts"] for hit in hits]
        # Details are joined into one string; empty/missing details become ""
        details = [
            ", ".join(hit["alert_details"]) if hit["alert_details"] else ""
            for hit in hits
        ]
        return pd.DataFrame({
            "has_common_alerts": flagged,
            "num_common_alerts": counts,
            "common_alert_details": details,
        })

    if alert_type == "nibr":
        nibr_engine = mc.structural.NIBRFilters()
        verdicts = nibr_engine(mols=mols, n_jobs=n_jobs, progress=True)
        return pd.DataFrame({"passes_nibr": verdicts})

    if alert_type == "lilly":
        lilly_engine = mc.structural.LillyDemeritsFilters()
        scored = lilly_engine(mols=mols, n_jobs=n_jobs, progress=True)

        demerits = [entry["demerits"] for entry in scored]
        verdicts = [entry["passes"] for entry in scored]
        pattern_names = [
            ", ".join([match["pattern"] for match in entry["matched_patterns"]])
            for entry in scored
        ]
        return pd.DataFrame({
            "lilly_demerits": demerits,
            "passes_lilly": verdicts,
            "lilly_patterns": pattern_names,
        })

    if alert_type == "pains":
        verdicts = [
            mc.rules.basic_rules.pains_filter(mol)
            for mol in tqdm(mols, desc="PAINS")
        ]
        return pd.DataFrame({"passes_pains": verdicts})

    raise ValueError(f"Unknown alert type: {alert_type}")
|
|
|
|
|
|
def apply_complexity_filter(mols: List[Chem.Mol], max_complexity: float, method: str = "bertz") -> pd.DataFrame:
    """
    Score each molecule's complexity and compare it with a threshold.

    Args:
        mols: Molecules to score.
        max_complexity: Inclusive upper bound; a score equal to it still passes.
        method: Method name forwarded to ``mc.complexity.calculate_complexity``.

    Returns:
        DataFrame with ``complexity_score`` and the boolean ``passes_complexity``.
    """
    print(f"\nCalculating molecular complexity (method={method}, max={max_complexity})...")

    scores = []
    for mol in tqdm(mols, desc="Complexity"):
        scores.append(mc.complexity.calculate_complexity(mol, method=method))

    within_limit = [value <= max_complexity for value in scores]

    return pd.DataFrame({
        "complexity_score": scores,
        "passes_complexity": within_limit,
    })
|
|
|
|
|
|
def apply_constraints(mols: List[Chem.Mol], constraints: Dict, n_jobs: int) -> pd.DataFrame:
    """
    Evaluate user-supplied property constraints for every molecule.

    Args:
        mols: Molecules to check.
        constraints: Keyword arguments expanded into ``mc.constraints.Constraints``.
        n_jobs: Parallelism setting forwarded to the constraint call.

    Returns:
        DataFrame with ``passes_constraints`` and ``constraint_violations``
        (violations joined with ", ", or "" when there are none).
    """
    print(f"\nApplying constraints: {constraints}")

    checker = mc.constraints.Constraints(**constraints)
    outcomes = checker(mols=mols, n_jobs=n_jobs, progress=True)

    verdicts = [outcome["passes"] for outcome in outcomes]
    violation_text = [
        ", ".join(outcome["violations"]) if outcome["violations"] else ""
        for outcome in outcomes
    ]

    return pd.DataFrame({
        "passes_constraints": verdicts,
        "constraint_violations": violation_text,
    })
|
|
|
|
|
|
def apply_chemical_groups(mols: List[Chem.Mol], groups: List[str]) -> pd.DataFrame:
    """
    Flag which of the requested chemical groups each molecule contains.

    Args:
        mols: Molecules to inspect.
        groups: Group names handed to ``mc.groups.ChemicalGroup``.

    Returns:
        DataFrame with one boolean ``has_<group>`` column per requested group.
    """
    print(f"\nDetecting chemical groups: {', '.join(groups)}")

    detector = mc.groups.ChemicalGroup(groups=groups)
    per_molecule = detector.get_all_matches(mols)

    presence = pd.DataFrame()
    for name in groups:
        # A missing or empty match entry counts as "group not present"
        flags = [bool(matches.get(name)) for matches in per_molecule]
        presence[f"has_{name}"] = flags

    return presence
|
|
|
|
|
|
def _pass_stats(flags, total):
    """Return (number of truthy flags, percentage of total); 0.0% for an empty table."""
    n_pass = flags.sum()
    pct = 100 * n_pass / total if total else 0.0
    return n_pass, pct


def _format_mean(values):
    """Format a column mean with one decimal place, or 'n/a' when there are no rows."""
    return f"{values.mean():.1f}" if len(values) else "n/a"


def generate_summary(df: pd.DataFrame, output_file: Path):
    """
    Generate filtering summary report.

    The report is written next to ``output_file`` as ``<stem>_summary.txt`` and
    contains one section per family of result columns found in ``df`` (rule
    filters, structural alerts, complexity, constraints) followed by the
    overall pass rate across every ``passes_*`` column.

    Args:
        df: Combined results table. It is only read; no columns are added.
        output_file: Path of the main results file; only its parent directory
            and stem are used to derive the summary file name.
    """
    summary_file = output_file.parent / f"{output_file.stem}_summary.txt"
    total = len(df)
    bar = "=" * 80
    rule_line = "-" * 40

    lines = [bar, "MEDCHEM FILTERING SUMMARY", bar, "",
             f"Total molecules processed: {total}", ""]

    # Rule results
    rule_cols = [col for col in df.columns
                 if col.startswith("rule_") or col == "passes_all_rules"]
    if rule_cols:
        lines += ["RULE FILTERS:", rule_line]
        for col in rule_cols:
            # Only boolean columns can be reported as pass counts
            if df[col].dtype == bool:
                n_pass, pct = _pass_stats(df[col], total)
                lines.append(f" {col}: {n_pass} passed ({pct:.1f}%)")
        lines.append("")

    # Structural alerts
    alert_markers = ("alert", "nibr", "lilly", "pains")
    if any(marker in col.lower() for col in df.columns for marker in alert_markers):
        lines += ["STRUCTURAL ALERTS:", rule_line]
        if "has_common_alerts" in df.columns:
            # "Clean" molecules are those WITHOUT an alert, hence the negation
            n_clean, pct = _pass_stats(~df["has_common_alerts"], total)
            lines.append(f" No common alerts: {n_clean} ({pct:.1f}%)")
        if "passes_nibr" in df.columns:
            n_pass, pct = _pass_stats(df["passes_nibr"], total)
            lines.append(f" Passes NIBR: {n_pass} ({pct:.1f}%)")
        if "passes_lilly" in df.columns:
            n_pass, pct = _pass_stats(df["passes_lilly"], total)
            lines.append(f" Passes Lilly: {n_pass} ({pct:.1f}%)")
            # The demerits column may be absent if the table was trimmed upstream
            if "lilly_demerits" in df.columns:
                lines.append(f" Average Lilly demerits: {_format_mean(df['lilly_demerits'])}")
        if "passes_pains" in df.columns:
            n_pass, pct = _pass_stats(df["passes_pains"], total)
            lines.append(f" Passes PAINS: {n_pass} ({pct:.1f}%)")
        lines.append("")

    # Complexity
    if "complexity_score" in df.columns:
        lines += ["COMPLEXITY:", rule_line]
        lines.append(f" Average complexity: {_format_mean(df['complexity_score'])}")
        if "passes_complexity" in df.columns:
            n_pass, pct = _pass_stats(df["passes_complexity"], total)
            lines.append(f" Within threshold: {n_pass} ({pct:.1f}%)")
        lines.append("")

    # Constraints
    if "passes_constraints" in df.columns:
        lines += ["CONSTRAINTS:", rule_line]
        n_pass, pct = _pass_stats(df["passes_constraints"], total)
        lines.append(f" Passes all constraints: {n_pass} ({pct:.1f}%)")
        lines.append("")

    # Overall pass rate, computed locally so the caller's table is not modified
    pass_cols = [col for col in df.columns if col.startswith("passes_")]
    if pass_cols:
        n_pass, pct = _pass_stats(df[pass_cols].all(axis=1), total)
        lines += ["OVERALL:", rule_line,
                  f" Molecules passing all filters: {n_pass} ({pct:.1f}%)"]

    lines += ["", bar]

    with open(summary_file, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")

    print(f"\nSummary report saved to: {summary_file}")
|
|
|
|
|
|
def _split_csv_option(raw):
    """Split a comma-separated option value into stripped, non-empty items."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def _parse_range(parser, option, raw):
    """Parse a 'min,max' option value into a float pair, exiting via parser.error if malformed."""
    try:
        low, high = (float(part) for part in raw.split(","))
    except ValueError:
        # Covers both non-numeric parts and a wrong number of parts
        parser.error(f"{option} expects two comma-separated numbers (min,max), got '{raw}'")
    if low > high:
        parser.error(f"{option} minimum ({low}) is greater than maximum ({high})")
    return low, high


def main():
    """
    Command-line entry point.

    Parses the arguments, loads the input molecules, runs every requested
    filter, writes the combined results to the ``--output`` CSV and, unless
    ``--no-summary`` is given, writes a summary report next to it.

    Invalid command-line values (missing input file, malformed ranges) are
    reported through ``parser.error`` before any molecule is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Batch molecular filtering using medchem",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Input/Output
    parser.add_argument("input", type=Path, help="Input file (CSV, TSV, SDF, or TXT)")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output CSV file")
    parser.add_argument("--smiles-column", default="smiles", help="Name of SMILES column (default: smiles)")

    # Rule filters
    parser.add_argument("--rules", help="Comma-separated list of rules (e.g., rule_of_five,rule_of_cns)")

    # Structural alerts
    parser.add_argument("--common-alerts", action="store_true", help="Apply common structural alerts")
    parser.add_argument("--nibr", action="store_true", help="Apply NIBR filters")
    parser.add_argument("--lilly", action="store_true", help="Apply Lilly demerits filter")
    parser.add_argument("--pains", action="store_true", help="Apply PAINS filter")

    # Complexity
    parser.add_argument("--complexity", type=float, help="Maximum complexity threshold")
    parser.add_argument("--complexity-method", default="bertz", choices=["bertz", "whitlock", "barone"],
                        help="Complexity calculation method")

    # Constraints
    parser.add_argument("--mw-range", help="Molecular weight range (e.g., 200,500)")
    # A value such as "-2,5" looks like an option to argparse, so it must be
    # attached with '='; the help text says so explicitly.
    parser.add_argument("--logp-range",
                        help="LogP range; attach negative bounds with '=' (e.g., --logp-range=-2,5)")
    parser.add_argument("--tpsa-max", type=float, help="Maximum TPSA")
    parser.add_argument("--hbd-max", type=int, help="Maximum H-bond donors")
    parser.add_argument("--hba-max", type=int, help="Maximum H-bond acceptors")
    parser.add_argument("--rotatable-bonds-max", type=int, help="Maximum rotatable bonds")

    # Chemical groups
    parser.add_argument("--groups", help="Comma-separated chemical groups to detect")

    # Processing options
    parser.add_argument("--n-jobs", type=int, default=-1, help="Number of parallel jobs (-1 = all cores)")
    parser.add_argument("--no-summary", action="store_true", help="Don't generate summary report")
    parser.add_argument("--filter-output", action="store_true", help="Only output molecules passing all filters")

    args = parser.parse_args()

    # Validate everything cheap up front so bad arguments fail before the
    # (potentially slow) molecule loading and filtering.
    if not args.input.is_file():
        parser.error(f"Input file not found: {args.input}")

    constraints = {}
    if args.mw_range is not None:
        constraints["mw_range"] = _parse_range(parser, "--mw-range", args.mw_range)
    if args.logp_range is not None:
        constraints["logp_range"] = _parse_range(parser, "--logp-range", args.logp_range)
    # Compare against None: zero is a legitimate limit (e.g. --hbd-max 0)
    for key in ("tpsa_max", "hbd_max", "hba_max", "rotatable_bonds_max"):
        value = getattr(args, key)
        if value is not None:
            constraints[key] = value

    # Load molecules
    df, mols = load_molecules(args.input, args.smiles_column)

    # Apply filters
    result_dfs = [df]

    # Rules
    rule_list = _split_csv_option(args.rules) if args.rules else []
    if rule_list:
        result_dfs.append(apply_rule_filters(mols, rule_list, args.n_jobs))

    # Structural alerts
    requested_alerts = (
        (args.common_alerts, "common"),
        (args.nibr, "nibr"),
        (args.lilly, "lilly"),
        (args.pains, "pains"),
    )
    for enabled, alert_type in requested_alerts:
        if enabled:
            result_dfs.append(apply_structural_alerts(mols, alert_type, args.n_jobs))

    # Complexity (a threshold of 0 is still a threshold)
    if args.complexity is not None:
        result_dfs.append(apply_complexity_filter(mols, args.complexity, args.complexity_method))

    # Constraints
    if constraints:
        result_dfs.append(apply_constraints(mols, constraints, args.n_jobs))

    # Chemical groups
    group_list = _split_csv_option(args.groups) if args.groups else []
    if group_list:
        result_dfs.append(apply_chemical_groups(mols, group_list))

    # Combine results
    df_final = pd.concat(result_dfs, axis=1)

    # Filter output if requested; df_final keeps every processed molecule so
    # the summary below reports real pass rates rather than 100% everywhere.
    df_output = df_final
    if args.filter_output:
        pass_cols = [col for col in df_final.columns if col.startswith("passes_")]
        if pass_cols:
            df_final["passes_all"] = df_final[pass_cols].all(axis=1)
            df_output = df_final[df_final["passes_all"]].copy()
            print(f"\nFiltered to {len(df_output)} molecules passing all filters")

    # Save results
    args.output.parent.mkdir(parents=True, exist_ok=True)
    df_output.to_csv(args.output, index=False)
    print(f"\nResults saved to: {args.output}")

    # Generate summary over all processed molecules
    if not args.no_summary:
        generate_summary(df_final, args.output)

    print("\nDone!")
|
|
|
|
|
|
# Run the command-line workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|