diff --git a/.claude-plugin/marketplace.json b/.claude-plugin/marketplace.json index 4f76c55..64534ef 100644 --- a/.claude-plugin/marketplace.json +++ b/.claude-plugin/marketplace.json @@ -17,7 +17,41 @@ "strict": false, "skills": [ "./scientific-packages/anndata", - "./scientific-packages/arboreto" + "./scientific-packages/arboreto", + "./scientific-packages/astropy", + "./scientific-packages/biomni", + "./scientific-packages/biopython", + "./scientific-packages/bioservices", + "./scientific-packages/cellxgene-census", + "./scientific-packages/cobrapy", + "./scientific-packages/datamol", + "./scientific-packages/deepchem", + "./scientific-packages/deeptools", + "./scientific-packages/diffdock", + "./scientific-packages/etetoolkit", + "./scientific-packages/flowio", + "./scientific-packages/gget", + "./scientific-packages/matplotlib", + "./scientific-packages/medchem", + "./scientific-packages/molfeat", + "./scientific-packages/polars", + "./scientific-packages/pubchem-database", + "./scientific-packages/pydeseq2", + "./scientific-packages/pymatgen", + "./scientific-packages/pymc", + "./scientific-packages/pymoo", + "./scientific-packages/pytdc", + "./scientific-packages/pytorch-lightning", + "./scientific-packages/rdkit", + "./scientific-packages/reportlab", + "./scientific-packages/scanpy", + "./scientific-packages/scikit-bio", + "./scientific-packages/scikit-learn", + "./scientific-packages/seaborn", + "./scientific-packages/torch_geometric", + "./scientific-packages/transformers", + "./scientific-packages/umap-learn", + "./scientific-packages/zarr-python" ] }, { diff --git a/scientific-packages/astropy/SKILL.md b/scientific-packages/astropy/SKILL.md new file mode 100644 index 0000000..07a87ed --- /dev/null +++ b/scientific-packages/astropy/SKILL.md @@ -0,0 +1,790 @@ +--- +name: astropy +description: Comprehensive toolkit for astronomical data analysis and computation using the astropy Python library. This skill should be used when working with astronomical data including FITS files, coordinate transformations, cosmological calculations, time systems, physical units, data tables, model fitting, WCS transformations, and visualization. Use this skill for tasks involving celestial coordinates, astronomical file formats, photometry, spectroscopy, or any astronomy-specific Python computations. +--- + +# Astropy + +## Overview + +Astropy is the community standard Python library for astronomy, providing core functionality for astronomical data analysis and computation. This skill provides comprehensive guidance and tools for working with astropy's extensive capabilities across coordinate systems, file I/O, units and quantities, time systems, cosmology, modeling, and more. + +## When to Use This Skill + +Use this skill when: +- Working with FITS files (reading, writing, inspecting, modifying) +- Performing coordinate transformations between astronomical reference frames +- Calculating cosmological distances, ages, or other quantities +- Handling astronomical time systems and conversions +- Working with physical units and dimensional analysis +- Processing astronomical data tables with specialized column types +- Fitting models to astronomical data +- Converting between pixel and world coordinates (WCS) +- Performing robust statistical analysis on astronomical data +- Visualizing astronomical images with proper scaling and stretching + +## Core Capabilities + +### 1. FITS File Operations + +FITS (Flexible Image Transport System) is the standard file format in astronomy. 
Astropy provides comprehensive FITS support. + +**Quick FITS Inspection**: +Use the included `scripts/fits_info.py` script for rapid file inspection: +```bash +python scripts/fits_info.py observation.fits +python scripts/fits_info.py observation.fits --detailed +python scripts/fits_info.py observation.fits --ext 1 +``` + +**Common FITS workflows**: +```python +from astropy.io import fits + +# Read FITS file +with fits.open('image.fits') as hdul: + hdul.info() # Display structure + data = hdul[0].data + header = hdul[0].header + +# Write FITS file +fits.writeto('output.fits', data, header, overwrite=True) + +# Quick access (less efficient for multiple operations) +data = fits.getdata('image.fits', ext=0) +header = fits.getheader('image.fits', ext=0) + +# Update specific header keyword +fits.setval('image.fits', 'OBJECT', value='M31') +``` + +**Multi-extension FITS**: +```python +from astropy.io import fits + +# Create multi-extension FITS +primary = fits.PrimaryHDU(primary_data) +image_ext = fits.ImageHDU(science_data, name='SCI') +error_ext = fits.ImageHDU(error_data, name='ERR') + +hdul = fits.HDUList([primary, image_ext, error_ext]) +hdul.writeto('multi_ext.fits', overwrite=True) +``` + +**Binary tables**: +```python +from astropy.io import fits + +# Read binary table +with fits.open('catalog.fits') as hdul: + table_data = hdul[1].data + ra = table_data['RA'] + dec = table_data['DEC'] + +# Better: use astropy.table for table operations (see section 5) +``` + +### 2. Coordinate Systems and Transformations + +Astropy supports ~25 coordinate frames with seamless transformations. + +**Quick Coordinate Conversion**: +Use the included `scripts/coord_convert.py` script: +```bash +python scripts/coord_convert.py 10.68 41.27 --from icrs --to galactic +python scripts/coord_convert.py --file coords.txt --from icrs --to galactic --output sexagesimal +``` + +**Basic coordinate operations**: +```python +from astropy.coordinates import SkyCoord +import astropy.units as u + +# Create coordinate (multiple input formats supported) +c = SkyCoord(ra=10.68*u.degree, dec=41.27*u.degree, frame='icrs') +c = SkyCoord('00:42:44.3 +41:16:09', unit=(u.hourangle, u.deg)) +c = SkyCoord('00h42m44.3s +41d16m09s') + +# Transform between frames +c_galactic = c.galactic +c_fk5 = c.fk5 + +print(f"Galactic: l={c_galactic.l.deg:.3f}, b={c_galactic.b.deg:.3f}") +``` + +**Working with coordinate arrays**: +```python +import numpy as np +from astropy.coordinates import SkyCoord +import astropy.units as u + +# Arrays of coordinates +ra = np.array([10.1, 10.2, 10.3]) * u.degree +dec = np.array([40.1, 40.2, 40.3]) * u.degree +coords = SkyCoord(ra=ra, dec=dec, frame='icrs') + +# Calculate separations +sep = coords[0].separation(coords[1]) +print(f"Separation: {sep.to(u.arcmin)}") + +# Position angle +pa = coords[0].position_angle(coords[1]) +``` + +**Catalog matching**: +```python +from astropy.coordinates import SkyCoord +import astropy.units as u + +catalog1 = SkyCoord(ra=[10, 11, 12]*u.degree, dec=[40, 41, 42]*u.degree) +catalog2 = SkyCoord(ra=[10.01, 11.02, 13]*u.degree, dec=[40.01, 41.01, 43]*u.degree) + +# Find nearest neighbors +idx, sep2d, dist3d = catalog1.match_to_catalog_sky(catalog2) + +# Filter by separation threshold +max_sep = 1 * u.arcsec +matched = sep2d < max_sep +``` + +**Horizontal coordinates (Alt/Az)**: +```python +from astropy.coordinates import SkyCoord, EarthLocation, AltAz +from astropy.time import Time +import astropy.units as u + +location = EarthLocation(lat=40*u.deg, lon=-70*u.deg, height=300*u.m) 
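+# Observation time (ISO strings default to the UTC scale)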
+obstime = Time('2023-01-01 03:00:00') +target = SkyCoord(ra=10*u.degree, dec=40*u.degree, frame='icrs') + +altaz_frame = AltAz(obstime=obstime, location=location) +target_altaz = target.transform_to(altaz_frame) + +print(f"Alt: {target_altaz.alt.deg:.2f}°, Az: {target_altaz.az.deg:.2f}°") +``` + +**Available coordinate frames**: +- `icrs` - International Celestial Reference System (default, preferred) +- `fk5`, `fk4` - Fifth/Fourth Fundamental Katalog +- `galactic` - Galactic coordinates +- `supergalactic` - Supergalactic coordinates +- `altaz` - Horizontal (altitude-azimuth) coordinates +- `gcrs`, `cirs`, `itrs` - Earth-based systems +- Ecliptic frames: `BarycentricMeanEcliptic`, `HeliocentricMeanEcliptic`, `GeocentricMeanEcliptic` + +### 3. Units and Quantities + +Physical units are fundamental to astronomical calculations. Astropy's units system provides dimensional analysis and automatic conversions. + +**Basic unit operations**: +```python +import astropy.units as u + +# Create quantities +distance = 5.2 * u.parsec +velocity = 300 * u.km / u.s +time = 10 * u.year + +# Convert units +distance_ly = distance.to(u.lightyear) +velocity_mps = velocity.to(u.m / u.s) + +# Arithmetic with units +wavelength = 500 * u.nm +frequency = wavelength.to(u.Hz, equivalencies=u.spectral()) +``` + +**Working with arrays**: +```python +import numpy as np +import astropy.units as u + +wavelengths = np.array([400, 500, 600]) * u.nm +frequencies = wavelengths.to(u.THz, equivalencies=u.spectral()) + +fluxes = np.array([1.2, 2.3, 1.8]) * u.Jy +luminosities = 4 * np.pi * (10*u.pc)**2 * fluxes +``` + +**Important equivalencies**: +- `u.spectral()` - Convert wavelength ↔ frequency ↔ energy +- `u.doppler_optical(rest)` - Optical Doppler velocity +- `u.doppler_radio(rest)` - Radio Doppler velocity +- `u.doppler_relativistic(rest)` - Relativistic Doppler +- `u.temperature()` - Temperature unit conversions +- `u.brightness_temperature(freq)` - Brightness temperature + +**Physical constants**: +```python +from astropy import constants as const + +print(const.c) # Speed of light +print(const.G) # Gravitational constant +print(const.M_sun) # Solar mass +print(const.R_sun) # Solar radius +print(const.L_sun) # Solar luminosity +``` + +**Performance tip**: Use the `<<` operator for fast unit assignment to arrays: +```python +# Fast +result = large_array << u.m + +# Slower +result = large_array * u.m +``` + +### 4. Time Systems + +Astronomical time systems require high precision and multiple time scales. 
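+
+The distinction between scales matters in practice: UTC and TT, for instance, differ by roughly a minute (the fixed 32.184 s TT−TAI offset plus accumulated leap seconds). A quick sketch of the offset for a 2023 date:
+
+```python
+from astropy.time import Time
+
+t = Time('2023-01-01T00:00:00', scale='utc')
+
+# The same instant rendered on two scales; TT runs ~69.184 s ahead of UTC in 2023
+print(t.utc.iso)   # 2023-01-01 00:00:00.000
+print(t.tt.iso)    # 2023-01-01 00:01:09.184
+
+offset_s = (t.tt.jd - t.utc.jd) * 86400   # ~69.184 seconds
+print(f"TT - UTC = {offset_s:.3f} s")
+```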
+ +**Creating time objects**: +```python +from astropy.time import Time +import astropy.units as u + +# Various input formats +t1 = Time('2023-01-01T00:00:00', format='isot', scale='utc') +t2 = Time(2459945.5, format='jd', scale='utc') +t3 = Time(['2023-01-01', '2023-06-01'], format='iso') + +# Convert formats +print(t1.jd) # Julian Date +print(t1.mjd) # Modified Julian Date +print(t1.unix) # Unix timestamp +print(t1.iso) # ISO format + +# Convert time scales +print(t1.tai) # International Atomic Time +print(t1.tt) # Terrestrial Time +print(t1.tdb) # Barycentric Dynamical Time +``` + +**Time arithmetic**: +```python +from astropy.time import Time, TimeDelta +import astropy.units as u + +t1 = Time('2023-01-01T00:00:00') +dt = TimeDelta(1*u.day) + +t2 = t1 + dt +diff = t2 - t1 +print(diff.to(u.hour)) + +# Array of times +times = t1 + np.arange(10) * u.day +``` + +**Astronomical time calculations**: +```python +from astropy.time import Time +from astropy.coordinates import SkyCoord, EarthLocation +import astropy.units as u + +location = EarthLocation(lat=40*u.deg, lon=-70*u.deg) +t = Time('2023-01-01T00:00:00') + +# Local sidereal time +lst = t.sidereal_time('apparent', longitude=location.lon) + +# Barycentric correction +target = SkyCoord(ra=10*u.deg, dec=40*u.deg) +ltt = t.light_travel_time(target, location=location) +t_bary = t.tdb + ltt +``` + +**Available time scales**: +- `utc` - Coordinated Universal Time +- `tai` - International Atomic Time +- `tt` - Terrestrial Time +- `tcb`, `tcg` - Barycentric/Geocentric Coordinate Time +- `tdb` - Barycentric Dynamical Time +- `ut1` - Universal Time + +### 5. Data Tables + +Astropy tables provide astronomy-specific enhancements over pandas. + +**Creating and manipulating tables**: +```python +from astropy.table import Table +import astropy.units as u + +# Create table +t = Table() +t['name'] = ['Star1', 'Star2', 'Star3'] +t['ra'] = [10.5, 11.2, 12.3] * u.degree +t['dec'] = [41.2, 42.1, 43.5] * u.degree +t['flux'] = [1.2, 2.3, 0.8] * u.Jy + +# Column metadata +t['flux'].description = 'Flux at 1.4 GHz' +t['flux'].format = '.2f' + +# Add calculated column +t['flux_mJy'] = t['flux'].to(u.mJy) + +# Filter and sort +bright = t[t['flux'] > 1.0 * u.Jy] +t.sort('flux') +``` + +**Table I/O**: +```python +from astropy.table import Table + +# Read (format auto-detected from extension) +t = Table.read('data.fits') +t = Table.read('data.csv', format='ascii.csv') +t = Table.read('data.ecsv', format='ascii.ecsv') # Preserves units! +t = Table.read('data.votable', format='votable') + +# Write +t.write('output.fits', overwrite=True) +t.write('output.ecsv', format='ascii.ecsv', overwrite=True) +``` + +**Advanced operations**: +```python +from astropy.table import Table, join, vstack, hstack + +# Join tables (like SQL) +joined = join(table1, table2, keys='id') + +# Stack tables +combined_rows = vstack([t1, t2]) +combined_cols = hstack([t1, t2]) + +# Grouping and aggregation +t.group_by('category').groups.aggregate(np.mean) +``` + +**Tables with astronomical objects**: +```python +from astropy.table import Table +from astropy.coordinates import SkyCoord +from astropy.time import Time +import astropy.units as u + +coords = SkyCoord(ra=[10, 11, 12]*u.deg, dec=[40, 41, 42]*u.deg) +times = Time(['2023-01-01', '2023-01-02', '2023-01-03']) + +t = Table([coords, times], names=['coords', 'obstime']) +print(t['coords'][0].ra) # Access coordinate properties +``` + +### 6. Cosmological Calculations + +Quick cosmology calculations using standard models. 
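+
+Beyond the distances shown below, the distance modulus (`distmod`) links apparent and absolute magnitudes; a minimal sketch with the built-in Planck18 model (the apparent magnitude is an illustrative value):
+
+```python
+from astropy.cosmology import Planck18
+
+z = 1.5
+mu = Planck18.distmod(z)       # distance modulus, in magnitudes
+
+m_app = 24.0                   # illustrative apparent magnitude
+M_abs = m_app - mu.value       # M = m - mu (ignoring K-correction and extinction)
+print(f"mu(z={z}) = {mu:.2f}, M = {M_abs:.2f}")
+```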
+ +**Using the cosmology calculator**: +```bash +python scripts/cosmo_calc.py 0.5 1.0 1.5 +python scripts/cosmo_calc.py --range 0 3 0.5 --cosmology Planck18 +python scripts/cosmo_calc.py 0.5 --verbose +python scripts/cosmo_calc.py --convert 1000 --from luminosity_distance +``` + +**Programmatic usage**: +```python +from astropy.cosmology import Planck18 +import astropy.units as u +import numpy as np + +cosmo = Planck18 + +# Calculate distances +z = 1.5 +d_L = cosmo.luminosity_distance(z) +d_A = cosmo.angular_diameter_distance(z) +d_C = cosmo.comoving_distance(z) + +# Time calculations +age = cosmo.age(z) +lookback = cosmo.lookback_time(z) + +# Hubble parameter +H_z = cosmo.H(z) + +print(f"At z={z}:") +print(f" Luminosity distance: {d_L:.2f}") +print(f" Age of universe: {age:.2f}") +``` + +**Convert observables**: +```python +from astropy.cosmology import Planck18 +import astropy.units as u + +cosmo = Planck18 +z = 1.5 + +# Angular size to physical size +d_A = cosmo.angular_diameter_distance(z) +angular_size = 1 * u.arcsec +physical_size = (angular_size.to(u.radian) * d_A).to(u.kpc) + +# Flux to luminosity +flux = 1e-17 * u.erg / u.s / u.cm**2 +d_L = cosmo.luminosity_distance(z) +luminosity = flux * 4 * np.pi * d_L**2 + +# Find redshift for given distance +from astropy.cosmology import z_at_value +z = z_at_value(cosmo.luminosity_distance, 1000*u.Mpc) +``` + +**Available cosmologies**: +- `Planck18`, `Planck15`, `Planck13` - Planck satellite parameters +- `WMAP9`, `WMAP7`, `WMAP5` - WMAP satellite parameters +- Custom: `FlatLambdaCDM(H0=70*u.km/u.s/u.Mpc, Om0=0.3)` + +### 7. Model Fitting + +Fit mathematical models to astronomical data. + +**1D fitting example**: +```python +from astropy.modeling import models, fitting +import numpy as np + +# Generate data +x = np.linspace(0, 10, 100) +y_data = 10 * np.exp(-0.5 * ((x - 5) / 1)**2) + np.random.normal(0, 0.5, x.shape) + +# Create and fit model +g_init = models.Gaussian1D(amplitude=8, mean=4.5, stddev=0.8) +fitter = fitting.LevMarLSQFitter() +g_fit = fitter(g_init, x, y_data) + +# Results +print(f"Amplitude: {g_fit.amplitude.value:.3f}") +print(f"Mean: {g_fit.mean.value:.3f}") +print(f"Stddev: {g_fit.stddev.value:.3f}") + +# Evaluate fitted model +y_fit = g_fit(x) +``` + +**Common 1D models**: +- `Gaussian1D` - Gaussian profile +- `Lorentz1D` - Lorentzian profile +- `Voigt1D` - Voigt profile +- `Moffat1D` - Moffat profile (PSF modeling) +- `Polynomial1D` - Polynomial +- `PowerLaw1D` - Power law +- `BlackBody` - Blackbody spectrum + +**Common 2D models**: +- `Gaussian2D` - 2D Gaussian +- `Moffat2D` - 2D Moffat (stellar PSF) +- `AiryDisk2D` - Airy disk (diffraction pattern) +- `Disk2D` - Circular disk + +**Fitting with constraints**: +```python +from astropy.modeling import models, fitting + +g = models.Gaussian1D(amplitude=10, mean=5, stddev=1) + +# Set bounds +g.amplitude.bounds = (0, None) # Positive only +g.mean.bounds = (4, 6) # Constrain center + +# Fix parameters +g.stddev.fixed = True + +# Compound models +model = models.Gaussian1D() + models.Polynomial1D(degree=1) +``` + +**Available fitters**: +- `LinearLSQFitter` - Linear least squares (fast, for linear models) +- `LevMarLSQFitter` - Levenberg-Marquardt (most common) +- `SimplexLSQFitter` - Downhill simplex +- `SLSQPLSQFitter` - Sequential Least Squares with constraints + +### 8. World Coordinate System (WCS) + +Transform between pixel and world coordinates in images. 
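+
+Beyond the point-by-point conversions shown below, a WCS can travel with a sub-image: `Cutout2D` from `astropy.nddata` extracts a postage stamp and adjusts the WCS to the new pixel origin. A minimal sketch (the file name and target position are placeholders):
+
+```python
+from astropy.io import fits
+from astropy.wcs import WCS
+from astropy.nddata import Cutout2D
+from astropy.coordinates import SkyCoord
+import astropy.units as u
+
+hdu = fits.open('image.fits')[0]
+wcs = WCS(hdu.header)
+
+target = SkyCoord(ra=150.0*u.deg, dec=-30.0*u.deg)
+cutout = Cutout2D(hdu.data, position=target, size=2*u.arcmin, wcs=wcs)
+
+# cutout.data holds the sub-array; cutout.wcs maps its pixels to the sky
+fits.writeto('cutout.fits', cutout.data, cutout.wcs.to_header(), overwrite=True)
+```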
+ +**Basic WCS usage**: +```python +from astropy.io import fits +from astropy.wcs import WCS + +# Read FITS with WCS +hdu = fits.open('image.fits')[0] +wcs = WCS(hdu.header) + +# Pixel to world +ra, dec = wcs.pixel_to_world_values(100, 200) + +# World to pixel +x, y = wcs.world_to_pixel_values(ra, dec) + +# Using SkyCoord (more powerful) +from astropy.coordinates import SkyCoord +import astropy.units as u + +coord = SkyCoord(ra=150*u.deg, dec=-30*u.deg) +x, y = wcs.world_to_pixel(coord) +``` + +**Plotting with WCS**: +```python +from astropy.io import fits +from astropy.wcs import WCS +from astropy.visualization import ImageNormalize, LogStretch, PercentileInterval +import matplotlib.pyplot as plt + +hdu = fits.open('image.fits')[0] +wcs = WCS(hdu.header) +data = hdu.data + +# Create figure with WCS projection +fig = plt.figure() +ax = fig.add_subplot(111, projection=wcs) + +# Plot with coordinate grid +norm = ImageNormalize(data, interval=PercentileInterval(99.5), + stretch=LogStretch()) +ax.imshow(data, norm=norm, origin='lower', cmap='viridis') + +# Coordinate labels and grid +ax.set_xlabel('RA') +ax.set_ylabel('Dec') +ax.coords.grid(color='white', alpha=0.5) +``` + +### 9. Statistics and Data Processing + +Robust statistical tools for astronomical data. + +**Sigma clipping** (remove outliers): +```python +from astropy.stats import sigma_clip, sigma_clipped_stats + +# Remove outliers +clipped = sigma_clip(data, sigma=3, maxiters=5) + +# Get statistics on cleaned data +mean, median, std = sigma_clipped_stats(data, sigma=3) + +# Use clipped data +background = median +signal = data - background +snr = signal / std +``` + +**Other statistical functions**: +```python +from astropy.stats import mad_std, biweight_location, biweight_scale + +# Robust standard deviation +std_robust = mad_std(data) + +# Robust central location +center = biweight_location(data) + +# Robust scale +scale = biweight_scale(data) +``` + +### 10. Visualization + +Display astronomical images with proper scaling. 
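+
+For three-band color composites, in addition to the single-band scaling described below, `make_lupton_rgb` applies an asinh stretch across bands. A short sketch assuming three aligned images `r`, `g`, `b` (random placeholders here):
+
+```python
+import numpy as np
+from astropy.visualization import make_lupton_rgb
+import matplotlib.pyplot as plt
+
+# r, g, b must be aligned 2D arrays of the same shape (placeholders here)
+rng = np.random.default_rng(0)
+r = rng.random((100, 100))
+g = rng.random((100, 100))
+b = rng.random((100, 100))
+
+rgb = make_lupton_rgb(r, g, b, stretch=0.5, Q=8)
+plt.imshow(rgb, origin='lower')
+plt.show()
+```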
+ +**Image normalization and stretching**: +```python +from astropy.visualization import (ImageNormalize, MinMaxInterval, + PercentileInterval, ZScaleInterval, + SqrtStretch, LogStretch, PowerStretch, + AsinhStretch) +import matplotlib.pyplot as plt + +# Common combination: percentile interval + sqrt stretch +norm = ImageNormalize(data, + interval=PercentileInterval(99), + stretch=SqrtStretch()) + +plt.imshow(data, norm=norm, origin='lower', cmap='gray') +plt.colorbar() +``` + +**Available intervals** (determine min/max): +- `MinMaxInterval()` - Use actual min/max +- `PercentileInterval(percentile)` - Clip to percentile (e.g., 99%) +- `ZScaleInterval()` - IRAF's zscale algorithm +- `ManualInterval(vmin, vmax)` - Specify manually + +**Available stretches** (nonlinear scaling): +- `LinearStretch()` - Linear (default) +- `SqrtStretch()` - Square root (common for images) +- `LogStretch()` - Logarithmic (for high dynamic range) +- `PowerStretch(power)` - Power law +- `AsinhStretch()` - Arcsinh (good for wide range) + +## Bundled Resources + +### scripts/ + +**`fits_info.py`** - Comprehensive FITS file inspection tool +```bash +python scripts/fits_info.py observation.fits +python scripts/fits_info.py observation.fits --detailed +python scripts/fits_info.py observation.fits --ext 1 +``` + +**`coord_convert.py`** - Batch coordinate transformation utility +```bash +python scripts/coord_convert.py 10.68 41.27 --from icrs --to galactic +python scripts/coord_convert.py --file coords.txt --from icrs --to galactic +``` + +**`cosmo_calc.py`** - Cosmological calculator +```bash +python scripts/cosmo_calc.py 0.5 1.0 1.5 +python scripts/cosmo_calc.py --range 0 3 0.5 --cosmology Planck18 +``` + +### references/ + +**`module_overview.md`** - Comprehensive reference of all astropy subpackages, classes, and methods. Consult this for detailed API information, available functions, and module capabilities. + +**`common_workflows.md`** - Complete working examples for common astronomical data analysis tasks. Contains full code examples for FITS operations, coordinate transformations, cosmology, modeling, and complete analysis pipelines. + +## Best Practices + +1. **Use context managers for FITS files**: + ```python + with fits.open('file.fits') as hdul: + # Work with file + ``` + +2. **Prefer astropy.table over raw FITS tables** for better unit/metadata support + +3. **Use SkyCoord for coordinates** (high-level interface) rather than low-level frame classes + +4. **Always attach units** to quantities when possible for dimensional safety + +5. **Use ECSV format** for saving tables when you want to preserve units and metadata + +6. **Vectorize coordinate operations** rather than looping for performance + +7. **Use memmap=True** when opening large FITS files to save memory + +8. **Install Bottleneck** package for faster statistics operations + +9. **Pre-compute composite units** for repeated operations to improve performance + +10. 
**Consult `references/module_overview.md`** for detailed module information and `references/common_workflows.md`** for complete working examples + +## Common Patterns + +### Pattern: FITS → Process → FITS +```python +from astropy.io import fits +from astropy.stats import sigma_clipped_stats + +# Read +with fits.open('input.fits') as hdul: + data = hdul[0].data + header = hdul[0].header + + # Process + mean, median, std = sigma_clipped_stats(data, sigma=3) + processed = (data - median) / std + + # Write + fits.writeto('output.fits', processed, header, overwrite=True) +``` + +### Pattern: Catalog Matching +```python +from astropy.coordinates import SkyCoord +from astropy.table import Table +import astropy.units as u + +# Load catalogs +cat1 = Table.read('catalog1.fits') +cat2 = Table.read('catalog2.fits') + +# Create coordinate objects +coords1 = SkyCoord(ra=cat1['RA'], dec=cat1['DEC'], unit=u.degree) +coords2 = SkyCoord(ra=cat2['RA'], dec=cat2['DEC'], unit=u.degree) + +# Match +idx, sep2d, dist3d = coords1.match_to_catalog_sky(coords2) + +# Filter by separation +max_sep = 1 * u.arcsec +matched_mask = sep2d < max_sep + +# Create matched catalog +matched_cat1 = cat1[matched_mask] +matched_cat2 = cat2[idx[matched_mask]] +``` + +### Pattern: Time Series Analysis +```python +from astropy.time import Time +from astropy.timeseries import TimeSeries +import astropy.units as u + +# Create time series +times = Time(['2023-01-01', '2023-01-02', '2023-01-03']) +flux = [1.2, 2.3, 1.8] * u.Jy + +ts = TimeSeries(time=times) +ts['flux'] = flux + +# Fold on period +from astropy.timeseries import aggregate_downsample +period = 1.5 * u.day +folded = ts.fold(period=period) +``` + +### Pattern: Image Display with WCS +```python +from astropy.io import fits +from astropy.wcs import WCS +from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval +import matplotlib.pyplot as plt + +hdu = fits.open('image.fits')[0] +wcs = WCS(hdu.header) +data = hdu.data + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111, projection=wcs) + +norm = ImageNormalize(data, interval=PercentileInterval(99), + stretch=SqrtStretch()) +im = ax.imshow(data, norm=norm, origin='lower', cmap='viridis') + +ax.set_xlabel('RA') +ax.set_ylabel('Dec') +ax.coords.grid(color='white', alpha=0.5, linestyle='solid') +plt.colorbar(im, ax=ax) +``` + +## Installation Note + +Ensure astropy is installed in the Python environment: +```bash +pip install astropy +``` + +For additional performance and features: +```bash +pip install astropy[all] # Includes optional dependencies +``` + +## Additional Resources + +- Official documentation: https://docs.astropy.org +- Tutorials: https://learn.astropy.org +- API reference: Consult `references/module_overview.md` in this skill +- Working examples: Consult `references/common_workflows.md` in this skill diff --git a/scientific-packages/astropy/references/common_workflows.md b/scientific-packages/astropy/references/common_workflows.md new file mode 100644 index 0000000..aed9ff2 --- /dev/null +++ b/scientific-packages/astropy/references/common_workflows.md @@ -0,0 +1,618 @@ +# Common Astropy Workflows + +This document describes frequently used workflows when working with astronomical data using astropy. + +## 1. 
Working with FITS Files + +### Basic FITS Reading +```python +from astropy.io import fits +import numpy as np + +# Open and examine structure +with fits.open('observation.fits') as hdul: + hdul.info() + + # Access primary HDU + primary_hdr = hdul[0].header + primary_data = hdul[0].data + + # Access extension + ext_data = hdul[1].data + ext_hdr = hdul[1].header + + # Read specific header keywords + object_name = primary_hdr['OBJECT'] + exposure = primary_hdr['EXPTIME'] +``` + +### Writing FITS Files +```python +# Create new FITS file +from astropy.io import fits +import numpy as np + +# Create data +data = np.random.random((100, 100)) + +# Create primary HDU +hdu = fits.PrimaryHDU(data) +hdu.header['OBJECT'] = 'M31' +hdu.header['EXPTIME'] = 300.0 + +# Write to file +hdu.writeto('output.fits', overwrite=True) + +# Multi-extension FITS +hdul = fits.HDUList([ + fits.PrimaryHDU(data1), + fits.ImageHDU(data2, name='SCI'), + fits.ImageHDU(data3, name='ERR') +]) +hdul.writeto('multi_ext.fits', overwrite=True) +``` + +### FITS Table Operations +```python +from astropy.io import fits + +# Read binary table +with fits.open('catalog.fits') as hdul: + table_data = hdul[1].data + + # Access columns + ra = table_data['RA'] + dec = table_data['DEC'] + mag = table_data['MAG'] + + # Filter data + bright = table_data[table_data['MAG'] < 15] + +# Write binary table +from astropy.table import Table +import astropy.units as u + +t = Table([ra, dec, mag], names=['RA', 'DEC', 'MAG']) +t['RA'].unit = u.degree +t['DEC'].unit = u.degree +t.write('output_catalog.fits', format='fits', overwrite=True) +``` + +## 2. Coordinate Transformations + +### Basic Coordinate Creation and Transformation +```python +from astropy.coordinates import SkyCoord +import astropy.units as u + +# Create from RA/Dec +c = SkyCoord(ra=10.68458*u.degree, dec=41.26917*u.degree, frame='icrs') + +# Alternative creation methods +c = SkyCoord('00:42:44.3 +41:16:09', unit=(u.hourangle, u.deg)) +c = SkyCoord('00h42m44.3s +41d16m09s') + +# Transform to different frames +c_gal = c.galactic +c_fk5 = c.fk5 +print(f"Galactic: l={c_gal.l.deg}, b={c_gal.b.deg}") +``` + +### Coordinate Arrays and Separations +```python +import numpy as np +from astropy.coordinates import SkyCoord +import astropy.units as u + +# Create array of coordinates +ra_array = np.array([10.1, 10.2, 10.3]) * u.degree +dec_array = np.array([40.1, 40.2, 40.3]) * u.degree +coords = SkyCoord(ra=ra_array, dec=dec_array, frame='icrs') + +# Calculate separations +c1 = SkyCoord(ra=10*u.degree, dec=40*u.degree) +c2 = SkyCoord(ra=11*u.degree, dec=41*u.degree) +sep = c1.separation(c2) +print(f"Separation: {sep.to(u.arcmin)}") + +# Position angle +pa = c1.position_angle(c2) +``` + +### Catalog Matching +```python +from astropy.coordinates import SkyCoord, match_coordinates_sky +import astropy.units as u + +# Two catalogs of coordinates +catalog1 = SkyCoord(ra=[10, 11, 12]*u.degree, dec=[40, 41, 42]*u.degree) +catalog2 = SkyCoord(ra=[10.01, 11.02, 13]*u.degree, dec=[40.01, 41.01, 43]*u.degree) + +# Find nearest neighbors +idx, sep2d, dist3d = catalog1.match_to_catalog_sky(catalog2) + +# Filter by separation threshold +max_sep = 1 * u.arcsec +matched = sep2d < max_sep +matching_indices = idx[matched] +``` + +### Horizontal Coordinates (Alt/Az) +```python +from astropy.coordinates import SkyCoord, EarthLocation, AltAz +from astropy.time import Time +import astropy.units as u + +# Observer location +location = EarthLocation(lat=40*u.deg, lon=-70*u.deg, height=300*u.m) + +# Observation time +obstime = 
Time('2023-01-01 03:00:00') + +# Target coordinate +target = SkyCoord(ra=10*u.degree, dec=40*u.degree, frame='icrs') + +# Transform to Alt/Az +altaz_frame = AltAz(obstime=obstime, location=location) +target_altaz = target.transform_to(altaz_frame) + +print(f"Altitude: {target_altaz.alt.deg}") +print(f"Azimuth: {target_altaz.az.deg}") +``` + +## 3. Units and Quantities + +### Basic Unit Operations +```python +import astropy.units as u + +# Create quantities +distance = 5.2 * u.parsec +time = 10 * u.year +velocity = 300 * u.km / u.s + +# Unit conversion +distance_ly = distance.to(u.lightyear) +velocity_mps = velocity.to(u.m / u.s) + +# Arithmetic with units +wavelength = 500 * u.nm +frequency = wavelength.to(u.Hz, equivalencies=u.spectral()) + +# Compose/decompose units +composite = (1 * u.kg * u.m**2 / u.s**2) +print(composite.decompose()) # Base SI units +print(composite.compose()) # Known compound units (Joule) +``` + +### Working with Arrays +```python +import numpy as np +import astropy.units as u + +# Quantity arrays +wavelengths = np.array([400, 500, 600]) * u.nm +frequencies = wavelengths.to(u.THz, equivalencies=u.spectral()) + +# Mathematical operations preserve units +fluxes = np.array([1.2, 2.3, 1.8]) * u.Jy +luminosities = 4 * np.pi * (10*u.pc)**2 * fluxes +``` + +### Custom Units and Equivalencies +```python +import astropy.units as u + +# Define custom unit +beam = u.def_unit('beam', 1.5e-10 * u.steradian) + +# Register for session +u.add_enabled_units([beam]) + +# Use in calculations +flux_per_beam = 1.5 * u.Jy / beam + +# Doppler equivalencies +rest_wavelength = 656.3 * u.nm # H-alpha +observed = 656.5 * u.nm +velocity = observed.to(u.km/u.s, + equivalencies=u.doppler_optical(rest_wavelength)) +``` + +## 4. Time Handling + +### Time Creation and Conversion +```python +from astropy.time import Time +import astropy.units as u + +# Create time objects +t1 = Time('2023-01-01T00:00:00', format='isot', scale='utc') +t2 = Time(2459945.5, format='jd', scale='utc') +t3 = Time(['2023-01-01', '2023-06-01'], format='iso') + +# Convert formats +print(t1.jd) # Julian Date +print(t1.mjd) # Modified Julian Date +print(t1.unix) # Unix timestamp +print(t1.iso) # ISO format + +# Convert time scales +print(t1.tai) # Convert to TAI +print(t1.tt) # Convert to TT +print(t1.tdb) # Convert to TDB +``` + +### Time Arithmetic +```python +from astropy.time import Time, TimeDelta +import astropy.units as u + +t1 = Time('2023-01-01T00:00:00') +dt = TimeDelta(1*u.day) + +# Add time delta +t2 = t1 + dt + +# Difference between times +diff = t2 - t1 +print(diff.to(u.hour)) + +# Array of times +times = t1 + np.arange(10) * u.day +``` + +### Sidereal Time and Astronomical Calculations +```python +from astropy.time import Time +from astropy.coordinates import EarthLocation +import astropy.units as u + +location = EarthLocation(lat=40*u.deg, lon=-70*u.deg) +t = Time('2023-01-01T00:00:00') + +# Local sidereal time +lst = t.sidereal_time('apparent', longitude=location.lon) + +# Light travel time correction +from astropy.coordinates import SkyCoord +target = SkyCoord(ra=10*u.deg, dec=40*u.deg) +ltt_bary = t.light_travel_time(target, location=location) +t_bary = t + ltt_bary +``` + +## 5. 
Tables and Data Management + +### Creating and Manipulating Tables +```python +from astropy.table import Table, Column +import astropy.units as u +import numpy as np + +# Create table +t = Table() +t['name'] = ['Star1', 'Star2', 'Star3'] +t['ra'] = [10.5, 11.2, 12.3] * u.degree +t['dec'] = [41.2, 42.1, 43.5] * u.degree +t['flux'] = [1.2, 2.3, 0.8] * u.Jy + +# Add column metadata +t['flux'].description = 'Flux at 1.4 GHz' +t['flux'].format = '.2f' + +# Add new column +t['flux_mJy'] = t['flux'].to(u.mJy) + +# Filter rows +bright = t[t['flux'] > 1.0 * u.Jy] + +# Sort +t.sort('flux') +``` + +### Table I/O +```python +from astropy.table import Table + +# Read various formats +t = Table.read('data.fits') +t = Table.read('data.csv', format='ascii.csv') +t = Table.read('data.ecsv', format='ascii.ecsv') # Preserves units +t = Table.read('data.votable', format='votable') + +# Write various formats +t.write('output.fits', overwrite=True) +t.write('output.csv', format='ascii.csv', overwrite=True) +t.write('output.ecsv', format='ascii.ecsv', overwrite=True) +t.write('output.votable', format='votable', overwrite=True) +``` + +### Advanced Table Operations +```python +from astropy.table import Table, join, vstack, hstack + +# Join tables +t1 = Table([[1, 2], ['a', 'b']], names=['id', 'val1']) +t2 = Table([[1, 2], ['c', 'd']], names=['id', 'val2']) +joined = join(t1, t2, keys='id') + +# Stack tables vertically +combined = vstack([t1, t2]) + +# Stack horizontally +combined = hstack([t1, t2]) + +# Grouping +t.group_by('category').groups.aggregate(np.mean) +``` + +### Tables with Astronomical Objects +```python +from astropy.table import Table +from astropy.coordinates import SkyCoord +from astropy.time import Time +import astropy.units as u + +# Table with SkyCoord column +coords = SkyCoord(ra=[10, 11, 12]*u.deg, dec=[40, 41, 42]*u.deg) +times = Time(['2023-01-01', '2023-01-02', '2023-01-03']) + +t = Table([coords, times], names=['coords', 'obstime']) + +# Access individual coordinates +print(t['coords'][0].ra) +print(t['coords'][0].dec) +``` + +## 6. 
Cosmological Calculations + +### Distance Calculations +```python +from astropy.cosmology import Planck18, FlatLambdaCDM +import astropy.units as u +import numpy as np + +# Use built-in cosmology +cosmo = Planck18 + +# Redshifts +z = np.linspace(0, 5, 50) + +# Calculate distances +comoving_dist = cosmo.comoving_distance(z) +angular_diam_dist = cosmo.angular_diameter_distance(z) +luminosity_dist = cosmo.luminosity_distance(z) + +# Age of universe +age_at_z = cosmo.age(z) +lookback_time = cosmo.lookback_time(z) + +# Hubble parameter +H_z = cosmo.H(z) +``` + +### Converting Observables +```python +from astropy.cosmology import Planck18 +import astropy.units as u + +cosmo = Planck18 +z = 1.5 + +# Angular diameter distance +d_A = cosmo.angular_diameter_distance(z) + +# Convert angular size to physical size +angular_size = 1 * u.arcsec +physical_size = (angular_size.to(u.radian) * d_A).to(u.kpc) + +# Convert flux to luminosity +flux = 1e-17 * u.erg / u.s / u.cm**2 +d_L = cosmo.luminosity_distance(z) +luminosity = flux * 4 * np.pi * d_L**2 + +# Find redshift for given distance +from astropy.cosmology import z_at_value +z_result = z_at_value(cosmo.luminosity_distance, 1000*u.Mpc) +``` + +### Custom Cosmology +```python +from astropy.cosmology import FlatLambdaCDM +import astropy.units as u + +# Define custom cosmology +my_cosmo = FlatLambdaCDM(H0=70 * u.km/u.s/u.Mpc, + Om0=0.3, + Tcmb0=2.725 * u.K) + +# Use it for calculations +print(my_cosmo.age(0)) +print(my_cosmo.luminosity_distance(1.5)) +``` + +## 7. Model Fitting + +### Fitting 1D Models +```python +from astropy.modeling import models, fitting +import numpy as np +import matplotlib.pyplot as plt + +# Generate data with noise +x = np.linspace(0, 10, 100) +true_model = models.Gaussian1D(amplitude=10, mean=5, stddev=1) +y = true_model(x) + np.random.normal(0, 0.5, x.shape) + +# Create and fit model +g_init = models.Gaussian1D(amplitude=8, mean=4.5, stddev=0.8) +fitter = fitting.LevMarLSQFitter() +g_fit = fitter(g_init, x, y) + +# Plot results +plt.plot(x, y, 'o', label='Data') +plt.plot(x, g_fit(x), label='Fit') +plt.legend() + +# Get fitted parameters +print(f"Amplitude: {g_fit.amplitude.value}") +print(f"Mean: {g_fit.mean.value}") +print(f"Stddev: {g_fit.stddev.value}") +``` + +### Fitting with Constraints +```python +from astropy.modeling import models, fitting + +# Set parameter bounds +g = models.Gaussian1D(amplitude=10, mean=5, stddev=1) +g.amplitude.bounds = (0, None) # Positive only +g.mean.bounds = (4, 6) # Constrain center +g.stddev.fixed = True # Fix width + +# Tie parameters (for multi-component models) +g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=1, name='g1') +g2 = models.Gaussian1D(amplitude=5, mean=6, stddev=1, name='g2') +g2.stddev.tied = lambda model: model.g1.stddev + +# Compound model +model = g1 + g2 +``` + +### 2D Image Fitting +```python +from astropy.modeling import models, fitting +import numpy as np + +# Create 2D data +y, x = np.mgrid[0:100, 0:100] +z = models.Gaussian2D(amplitude=100, x_mean=50, y_mean=50, + x_stddev=5, y_stddev=5)(x, y) +z += np.random.normal(0, 5, z.shape) + +# Fit 2D Gaussian +g_init = models.Gaussian2D(amplitude=90, x_mean=48, y_mean=48, + x_stddev=4, y_stddev=4) +fitter = fitting.LevMarLSQFitter() +g_fit = fitter(g_init, x, y, z) + +# Get parameters +print(f"Center: ({g_fit.x_mean.value}, {g_fit.y_mean.value})") +print(f"Width: ({g_fit.x_stddev.value}, {g_fit.y_stddev.value})") +``` + +## 8. 
Image Processing and Visualization + +### Image Display with Proper Scaling +```python +from astropy.io import fits +from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval +import matplotlib.pyplot as plt + +# Read FITS image +data = fits.getdata('image.fits') + +# Apply normalization +norm = ImageNormalize(data, + interval=PercentileInterval(99), + stretch=SqrtStretch()) + +# Display +plt.imshow(data, norm=norm, origin='lower', cmap='gray') +plt.colorbar() +``` + +### WCS Plotting +```python +from astropy.io import fits +from astropy.wcs import WCS +from astropy.visualization import ImageNormalize, LogStretch, PercentileInterval +import matplotlib.pyplot as plt + +# Read FITS with WCS +hdu = fits.open('image.fits')[0] +wcs = WCS(hdu.header) +data = hdu.data + +# Create figure with WCS projection +fig = plt.figure() +ax = fig.add_subplot(111, projection=wcs) + +# Plot with coordinate grid +norm = ImageNormalize(data, interval=PercentileInterval(99.5), + stretch=LogStretch()) +im = ax.imshow(data, norm=norm, origin='lower', cmap='viridis') + +# Add coordinate labels +ax.set_xlabel('RA') +ax.set_ylabel('Dec') +ax.coords.grid(color='white', alpha=0.5) +plt.colorbar(im) +``` + +### Sigma Clipping and Statistics +```python +from astropy.stats import sigma_clip, sigma_clipped_stats +import numpy as np + +# Data with outliers +data = np.random.normal(100, 15, 1000) +data[0:50] = np.random.normal(200, 10, 50) # Add outliers + +# Sigma clipping +clipped = sigma_clip(data, sigma=3, maxiters=5) + +# Get statistics on clipped data +mean, median, std = sigma_clipped_stats(data, sigma=3) + +print(f"Mean: {mean:.2f}") +print(f"Median: {median:.2f}") +print(f"Std: {std:.2f}") +print(f"Clipped {clipped.mask.sum()} values") +``` + +## 9. Complete Analysis Example + +### Photometry Pipeline +```python +from astropy.io import fits +from astropy.wcs import WCS +from astropy.coordinates import SkyCoord +from astropy.stats import sigma_clipped_stats +from astropy.visualization import ImageNormalize, LogStretch +import astropy.units as u +import numpy as np + +# Read FITS file +hdu = fits.open('observation.fits')[0] +data = hdu.data +header = hdu.header +wcs = WCS(header) + +# Calculate background statistics +mean, median, std = sigma_clipped_stats(data, sigma=3.0) +print(f"Background: {median:.2f} +/- {std:.2f}") + +# Subtract background +data_sub = data - median + +# Known source coordinates +source_coord = SkyCoord(ra='10:42:30', dec='+41:16:09', unit=(u.hourangle, u.deg)) + +# Convert to pixel coordinates +x_pix, y_pix = wcs.world_to_pixel(source_coord) + +# Simple aperture photometry +aperture_radius = 10 # pixels +y, x = np.ogrid[:data.shape[0], :data.shape[1]] +mask = (x - x_pix)**2 + (y - y_pix)**2 <= aperture_radius**2 + +aperture_sum = np.sum(data_sub[mask]) +npix = np.sum(mask) + +print(f"Source position: ({x_pix:.1f}, {y_pix:.1f})") +print(f"Aperture sum: {aperture_sum:.2f}") +print(f"S/N: {aperture_sum / (std * np.sqrt(npix)):.2f}") +``` + +This workflow document provides practical examples for common astronomical data analysis tasks using astropy. diff --git a/scientific-packages/astropy/references/module_overview.md b/scientific-packages/astropy/references/module_overview.md new file mode 100644 index 0000000..6de0bd3 --- /dev/null +++ b/scientific-packages/astropy/references/module_overview.md @@ -0,0 +1,340 @@ +# Astropy Module Overview + +This document provides a comprehensive reference of all major astropy subpackages and their capabilities. 
+ +## Core Data Structures + +### astropy.units +**Purpose**: Handle physical units and dimensional analysis in computations. + +**Key Classes**: +- `Quantity` - Combines numerical values with units +- `Unit` - Represents physical units + +**Common Operations**: +```python +import astropy.units as u +distance = 5 * u.meter +time = 2 * u.second +velocity = distance / time # Returns Quantity in m/s +wavelength = 500 * u.nm +frequency = wavelength.to(u.Hz, equivalencies=u.spectral()) +``` + +**Equivalencies**: +- `u.spectral()` - Convert wavelength ↔ frequency +- `u.doppler_optical()`, `u.doppler_radio()` - Velocity conversions +- `u.temperature()` - Temperature unit conversions +- `u.pixel_scale()` - Pixel to physical units + +### astropy.constants +**Purpose**: Provide physical and astronomical constants. + +**Common Constants**: +- `c` - Speed of light +- `G` - Gravitational constant +- `h` - Planck constant +- `M_sun`, `R_sun`, `L_sun` - Solar mass, radius, luminosity +- `M_earth`, `R_earth` - Earth mass, radius +- `pc`, `au` - Parsec, astronomical unit + +### astropy.time +**Purpose**: Represent and manipulate times and dates with astronomical precision. + +**Time Scales**: +- `UTC` - Coordinated Universal Time +- `TAI` - International Atomic Time +- `TT` - Terrestrial Time +- `TCB`, `TCG` - Barycentric/Geocentric Coordinate Time +- `TDB` - Barycentric Dynamical Time +- `UT1` - Universal Time + +**Common Formats**: +- `iso`, `isot` - ISO 8601 strings +- `jd`, `mjd` - Julian/Modified Julian Date +- `unix`, `gps` - Unix/GPS timestamps +- `datetime` - Python datetime objects + +**Example**: +```python +from astropy.time import Time +t = Time('2023-01-01T00:00:00', format='isot', scale='utc') +print(t.mjd) # Modified Julian Date +print(t.jd) # Julian Date +print(t.tt) # Convert to TT scale +``` + +### astropy.table +**Purpose**: Work with tabular data optimized for astronomical applications. + +**Key Features**: +- Native support for astropy Quantity, Time, and SkyCoord columns +- Multi-dimensional columns +- Metadata preservation (units, descriptions, formats) +- Advanced operations: joins, grouping, binning +- File I/O via unified interface + +**Example**: +```python +from astropy.table import Table +import astropy.units as u + +t = Table() +t['name'] = ['Star1', 'Star2', 'Star3'] +t['ra'] = [10.5, 11.2, 12.3] * u.degree +t['dec'] = [41.2, 42.1, 43.5] * u.degree +t['flux'] = [1.2, 2.3, 0.8] * u.Jy +``` + +## Coordinates and World Coordinate Systems + +### astropy.coordinates +**Purpose**: Represent and transform celestial coordinates. 
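+
+In addition to the on-sky operations listed below, coordinates can carry a distance, which enables true 3D separations; a brief sketch:
+
+```python
+from astropy.coordinates import SkyCoord
+import astropy.units as u
+
+c1 = SkyCoord(ra=10*u.deg, dec=40*u.deg, distance=150*u.pc)
+c2 = SkyCoord(ra=11*u.deg, dec=41*u.deg, distance=210*u.pc)
+
+print(c1.separation(c2))      # angular separation on the sky
+print(c1.separation_3d(c2))   # physical separation, returned as a Distance
+```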
+ +**Primary Interface**: `SkyCoord` - High-level class for sky positions + +**Coordinate Frames**: +- `ICRS` - International Celestial Reference System (default) +- `FK5`, `FK4` - Fifth/Fourth Fundamental Katalog +- `Galactic`, `Supergalactic` - Galactic coordinates +- `AltAz` - Horizontal (altitude-azimuth) coordinates +- `GCRS`, `CIRS`, `ITRS` - Earth-based systems +- `BarycentricMeanEcliptic`, `HeliocentricMeanEcliptic`, `GeocentricMeanEcliptic` - Ecliptic coordinates + +**Common Operations**: +```python +from astropy.coordinates import SkyCoord +import astropy.units as u + +# Create coordinate +c = SkyCoord(ra=10.625*u.degree, dec=41.2*u.degree, frame='icrs') + +# Transform to galactic +c_gal = c.galactic + +# Calculate separation +c2 = SkyCoord(ra=11*u.degree, dec=42*u.degree, frame='icrs') +sep = c.separation(c2) + +# Match catalogs +idx, sep2d, dist3d = c.match_to_catalog_sky(catalog_coords) +``` + +### astropy.wcs +**Purpose**: Handle World Coordinate System transformations for astronomical images. + +**Key Class**: `WCS` - Maps between pixel and world coordinates + +**Common Use Cases**: +- Convert pixel coordinates to RA/Dec +- Convert RA/Dec to pixel coordinates +- Handle distortion corrections (SIP, lookup tables) + +**Example**: +```python +from astropy.wcs import WCS +from astropy.io import fits + +hdu = fits.open('image.fits')[0] +wcs = WCS(hdu.header) + +# Pixel to world +ra, dec = wcs.pixel_to_world_values(100, 200) + +# World to pixel +x, y = wcs.world_to_pixel_values(ra, dec) +``` + +## File I/O + +### astropy.io.fits +**Purpose**: Read and write FITS (Flexible Image Transport System) files. + +**Key Classes**: +- `HDUList` - Container for all HDUs in a file +- `PrimaryHDU` - Primary header data unit +- `ImageHDU` - Image extension +- `BinTableHDU` - Binary table extension +- `Header` - FITS header keywords + +**Common Operations**: +```python +from astropy.io import fits + +# Read FITS file +with fits.open('file.fits') as hdul: + hdul.info() # Display structure + header = hdul[0].header + data = hdul[0].data + +# Write FITS file +fits.writeto('output.fits', data, header) + +# Update header keyword +fits.setval('file.fits', 'OBJECT', value='M31') +``` + +### astropy.io.ascii +**Purpose**: Read and write ASCII tables in various formats. + +**Supported Formats**: +- Basic, CSV, tab-delimited +- CDS/MRT (Machine Readable Tables) +- IPAC, Daophot, SExtractor +- LaTeX tables +- HTML tables + +### astropy.io.votable +**Purpose**: Handle Virtual Observatory (VO) table format. + +### astropy.io.misc +**Purpose**: Additional formats including HDF5, Parquet, and YAML. + +## Scientific Calculations + +### astropy.cosmology +**Purpose**: Perform cosmological calculations. 
+ +**Common Models**: +- `FlatLambdaCDM` - Flat universe with cosmological constant (most common) +- `LambdaCDM` - Universe with cosmological constant +- `Planck18`, `Planck15`, `Planck13` - Pre-defined Planck parameters +- `WMAP9`, `WMAP7`, `WMAP5` - Pre-defined WMAP parameters + +**Common Methods**: +```python +from astropy.cosmology import FlatLambdaCDM, Planck18 +import astropy.units as u + +cosmo = FlatLambdaCDM(H0=70, Om0=0.3) +# Or use built-in: cosmo = Planck18 + +z = 1.5 +print(cosmo.age(z)) # Age of universe at z +print(cosmo.luminosity_distance(z)) # Luminosity distance +print(cosmo.angular_diameter_distance(z)) # Angular diameter distance +print(cosmo.comoving_distance(z)) # Comoving distance +print(cosmo.H(z)) # Hubble parameter at z +``` + +### astropy.modeling +**Purpose**: Framework for model evaluation and fitting. + +**Model Categories**: +- 1D models: Gaussian1D, Lorentz1D, Voigt1D, Polynomial1D +- 2D models: Gaussian2D, Disk2D, Moffat2D +- Physical models: BlackBody, Drude1D, NFW +- Polynomial models: Chebyshev, Legendre + +**Common Fitters**: +- `LinearLSQFitter` - Linear least squares +- `LevMarLSQFitter` - Levenberg-Marquardt +- `SimplexLSQFitter` - Downhill simplex + +**Example**: +```python +from astropy.modeling import models, fitting + +# Create model +g = models.Gaussian1D(amplitude=10, mean=5, stddev=1) + +# Fit to data +fitter = fitting.LevMarLSQFitter() +fitted_model = fitter(g, x_data, y_data) +``` + +### astropy.convolution +**Purpose**: Convolve and filter astronomical data. + +**Common Kernels**: +- `Gaussian2DKernel` - 2D Gaussian smoothing +- `Box2DKernel` - 2D boxcar smoothing +- `Tophat2DKernel` - 2D tophat filter +- Custom kernels via arrays + +### astropy.stats +**Purpose**: Statistical tools for astronomical data analysis. + +**Key Functions**: +- `sigma_clip()` - Remove outliers via sigma clipping +- `sigma_clipped_stats()` - Compute mean, median, std with clipping +- `mad_std()` - Median Absolute Deviation +- `biweight_location()`, `biweight_scale()` - Robust statistics +- `circmean()`, `circstd()` - Circular statistics + +**Example**: +```python +from astropy.stats import sigma_clip, sigma_clipped_stats + +# Remove outliers +filtered_data = sigma_clip(data, sigma=3, maxiters=5) + +# Get statistics +mean, median, std = sigma_clipped_stats(data, sigma=3) +``` + +## Data Processing + +### astropy.nddata +**Purpose**: Handle N-dimensional datasets with metadata. + +**Key Class**: `NDData` - Container for array data with units, uncertainty, mask, and WCS + +### astropy.timeseries +**Purpose**: Work with time series data. + +**Key Classes**: +- `TimeSeries` - Time-indexed data table +- `BinnedTimeSeries` - Time-binned data + +**Common Operations**: +- Period finding (Lomb-Scargle) +- Folding time series +- Binning data + +### astropy.visualization +**Purpose**: Display astronomical data effectively. + +**Key Features**: +- Image normalization (LogStretch, PowerStretch, SqrtStretch, etc.) 
+- Interval scaling (MinMaxInterval, PercentileInterval, ZScaleInterval) +- WCSAxes for plotting with coordinate overlays +- RGB image creation with stretching +- Astronomical colormaps + +**Example**: +```python +from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval +import matplotlib.pyplot as plt + +norm = ImageNormalize(data, interval=PercentileInterval(99), + stretch=SqrtStretch()) +plt.imshow(data, norm=norm, origin='lower') +``` + +## Utilities + +### astropy.samp +**Purpose**: Simple Application Messaging Protocol for inter-application communication. + +**Use Case**: Connect Python scripts with other astronomical tools (e.g., DS9, TOPCAT). + +## Module Import Patterns + +**Standard imports**: +```python +import astropy.units as u +from astropy.coordinates import SkyCoord +from astropy.time import Time +from astropy.io import fits +from astropy.table import Table +from astropy import constants as const +``` + +## Performance Tips + +1. **Pre-compute composite units** for repeated operations +2. **Use `<<` operator** for fast unit assignments: `array << u.m` instead of `array * u.m` +3. **Vectorize operations** rather than looping over coordinates/times +4. **Use memmap=True** when opening large FITS files +5. **Install Bottleneck** for faster stats operations diff --git a/scientific-packages/astropy/scripts/coord_convert.py b/scientific-packages/astropy/scripts/coord_convert.py new file mode 100644 index 0000000..0341b8a --- /dev/null +++ b/scientific-packages/astropy/scripts/coord_convert.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +""" +Coordinate conversion utility for astronomical coordinates. + +This script provides batch coordinate transformations between different +astronomical coordinate systems using astropy. +""" + +import sys +import argparse +from astropy.coordinates import SkyCoord +import astropy.units as u + + +def convert_coordinates(coords_input, input_frame='icrs', output_frame='galactic', + input_format='decimal', output_format='decimal'): + """ + Convert astronomical coordinates between different frames. + + Parameters + ---------- + coords_input : list of tuples or str + Input coordinates as (lon, lat) pairs or strings + input_frame : str + Input coordinate frame (icrs, fk5, galactic, etc.) 
+ output_frame : str + Output coordinate frame + input_format : str + Format of input coordinates ('decimal', 'sexagesimal', 'hourangle') + output_format : str + Format for output display ('decimal', 'sexagesimal', 'hourangle') + + Returns + ------- + list + Converted coordinates + """ + results = [] + + for coord in coords_input: + try: + # Parse input coordinate + if input_format == 'decimal': + if isinstance(coord, str): + parts = coord.split() + lon, lat = float(parts[0]), float(parts[1]) + else: + lon, lat = coord + c = SkyCoord(lon*u.degree, lat*u.degree, frame=input_frame) + + elif input_format == 'sexagesimal': + c = SkyCoord(coord, frame=input_frame, unit=(u.hourangle, u.deg)) + + elif input_format == 'hourangle': + if isinstance(coord, str): + parts = coord.split() + lon, lat = parts[0], parts[1] + else: + lon, lat = coord + c = SkyCoord(lon, lat, frame=input_frame, unit=(u.hourangle, u.deg)) + + # Transform to output frame + if output_frame == 'icrs': + c_out = c.icrs + elif output_frame == 'fk5': + c_out = c.fk5 + elif output_frame == 'fk4': + c_out = c.fk4 + elif output_frame == 'galactic': + c_out = c.galactic + elif output_frame == 'supergalactic': + c_out = c.supergalactic + else: + c_out = c.transform_to(output_frame) + + results.append(c_out) + + except Exception as e: + print(f"Error converting coordinate {coord}: {e}", file=sys.stderr) + results.append(None) + + return results + + +def format_output(coords, frame, output_format='decimal'): + """Format coordinates for display.""" + output = [] + + for c in coords: + if c is None: + output.append("ERROR") + continue + + if frame in ['icrs', 'fk5', 'fk4']: + lon_name, lat_name = 'RA', 'Dec' + lon = c.ra + lat = c.dec + elif frame == 'galactic': + lon_name, lat_name = 'l', 'b' + lon = c.l + lat = c.b + elif frame == 'supergalactic': + lon_name, lat_name = 'sgl', 'sgb' + lon = c.sgl + lat = c.sgb + else: + lon_name, lat_name = 'lon', 'lat' + lon = c.spherical.lon + lat = c.spherical.lat + + if output_format == 'decimal': + out_str = f"{lon.degree:12.8f} {lat.degree:+12.8f}" + elif output_format == 'sexagesimal': + if frame in ['icrs', 'fk5', 'fk4']: + out_str = f"{lon.to_string(unit=u.hourangle, sep=':', pad=True)} " + out_str += f"{lat.to_string(unit=u.degree, sep=':', pad=True)}" + else: + out_str = f"{lon.to_string(unit=u.degree, sep=':', pad=True)} " + out_str += f"{lat.to_string(unit=u.degree, sep=':', pad=True)}" + elif output_format == 'hourangle': + out_str = f"{lon.to_string(unit=u.hourangle, sep=' ', pad=True)} " + out_str += f"{lat.to_string(unit=u.degree, sep=' ', pad=True)}" + + output.append(out_str) + + return output + + +def main(): + """Main function for command-line usage.""" + parser = argparse.ArgumentParser( + description='Convert astronomical coordinates between different frames', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Supported frames: icrs, fk5, fk4, galactic, supergalactic + +Input formats: + decimal : Degrees (e.g., "10.68 41.27") + sexagesimal : HMS/DMS (e.g., "00:42:44.3 +41:16:09") + hourangle : Hours and degrees (e.g., "10.5h 41.5d") + +Examples: + %(prog)s --from icrs --to galactic "10.68 41.27" + %(prog)s --from icrs --to galactic --input decimal --output sexagesimal "150.5 -30.2" + %(prog)s --from galactic --to icrs "120.5 45.3" + %(prog)s --file coords.txt --from icrs --to galactic + """ + ) + + parser.add_argument('coordinates', nargs='*', + help='Coordinates to convert (lon lat pairs)') + parser.add_argument('-f', '--from', dest='input_frame', default='icrs', 
+ help='Input coordinate frame (default: icrs)') + parser.add_argument('-t', '--to', dest='output_frame', default='galactic', + help='Output coordinate frame (default: galactic)') + parser.add_argument('-i', '--input', dest='input_format', default='decimal', + choices=['decimal', 'sexagesimal', 'hourangle'], + help='Input format (default: decimal)') + parser.add_argument('-o', '--output', dest='output_format', default='decimal', + choices=['decimal', 'sexagesimal', 'hourangle'], + help='Output format (default: decimal)') + parser.add_argument('--file', dest='input_file', + help='Read coordinates from file (one per line)') + parser.add_argument('--header', action='store_true', + help='Print header line with coordinate names') + + args = parser.parse_args() + + # Get coordinates from file or command line + if args.input_file: + try: + with open(args.input_file, 'r') as f: + coords = [line.strip() for line in f if line.strip()] + except FileNotFoundError: + print(f"Error: File '{args.input_file}' not found.", file=sys.stderr) + sys.exit(1) + else: + if not args.coordinates: + print("Error: No coordinates provided.", file=sys.stderr) + parser.print_help() + sys.exit(1) + + # Combine pairs of arguments + if args.input_format == 'decimal': + coords = [] + i = 0 + while i < len(args.coordinates): + if i + 1 < len(args.coordinates): + coords.append(f"{args.coordinates[i]} {args.coordinates[i+1]}") + i += 2 + else: + print(f"Warning: Odd number of coordinates, skipping last value", + file=sys.stderr) + break + else: + coords = args.coordinates + + # Convert coordinates + converted = convert_coordinates(coords, + input_frame=args.input_frame, + output_frame=args.output_frame, + input_format=args.input_format, + output_format=args.output_format) + + # Format and print output + formatted = format_output(converted, args.output_frame, args.output_format) + + # Print header if requested + if args.header: + if args.output_frame in ['icrs', 'fk5', 'fk4']: + if args.output_format == 'decimal': + print(f"{'RA (deg)':>12s} {'Dec (deg)':>13s}") + else: + print(f"{'RA':>25s} {'Dec':>26s}") + elif args.output_frame == 'galactic': + if args.output_format == 'decimal': + print(f"{'l (deg)':>12s} {'b (deg)':>13s}") + else: + print(f"{'l':>25s} {'b':>26s}") + + for line in formatted: + print(line) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/astropy/scripts/cosmo_calc.py b/scientific-packages/astropy/scripts/cosmo_calc.py new file mode 100644 index 0000000..57c3b9c --- /dev/null +++ b/scientific-packages/astropy/scripts/cosmo_calc.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +Cosmological calculator using astropy.cosmology. + +This script provides quick calculations of cosmological distances, +ages, and other quantities for given redshifts. +""" + +import sys +import argparse +import numpy as np +from astropy.cosmology import FlatLambdaCDM, Planck18, Planck15, WMAP9 +import astropy.units as u + + +def calculate_cosmology(redshifts, cosmology='Planck18', H0=None, Om0=None): + """ + Calculate cosmological quantities for given redshifts. 
+ + Parameters + ---------- + redshifts : array-like + Redshift values + cosmology : str + Cosmology to use ('Planck18', 'Planck15', 'WMAP9', 'custom') + H0 : float, optional + Hubble constant for custom cosmology (km/s/Mpc) + Om0 : float, optional + Matter density parameter for custom cosmology + + Returns + ------- + dict + Dictionary containing calculated quantities + """ + # Select cosmology + if cosmology == 'Planck18': + cosmo = Planck18 + elif cosmology == 'Planck15': + cosmo = Planck15 + elif cosmology == 'WMAP9': + cosmo = WMAP9 + elif cosmology == 'custom': + if H0 is None or Om0 is None: + raise ValueError("Must provide H0 and Om0 for custom cosmology") + cosmo = FlatLambdaCDM(H0=H0 * u.km/u.s/u.Mpc, Om0=Om0) + else: + raise ValueError(f"Unknown cosmology: {cosmology}") + + z = np.atleast_1d(redshifts) + + results = { + 'redshift': z, + 'cosmology': str(cosmo), + 'luminosity_distance': cosmo.luminosity_distance(z), + 'angular_diameter_distance': cosmo.angular_diameter_distance(z), + 'comoving_distance': cosmo.comoving_distance(z), + 'comoving_volume': cosmo.comoving_volume(z), + 'age': cosmo.age(z), + 'lookback_time': cosmo.lookback_time(z), + 'H': cosmo.H(z), + 'scale_factor': 1.0 / (1.0 + z) + } + + return results, cosmo + + +def print_results(results, verbose=False, csv=False): + """Print calculation results.""" + + z = results['redshift'] + + if csv: + # CSV output + print("z,D_L(Mpc),D_A(Mpc),D_C(Mpc),Age(Gyr),t_lookback(Gyr),H(km/s/Mpc)") + for i in range(len(z)): + print(f"{z[i]:.6f}," + f"{results['luminosity_distance'][i].value:.6f}," + f"{results['angular_diameter_distance'][i].value:.6f}," + f"{results['comoving_distance'][i].value:.6f}," + f"{results['age'][i].value:.6f}," + f"{results['lookback_time'][i].value:.6f}," + f"{results['H'][i].value:.6f}") + else: + # Formatted table output + if verbose: + print(f"\nCosmology: {results['cosmology']}") + print("-" * 80) + + print(f"\n{'z':>8s} {'D_L':>12s} {'D_A':>12s} {'D_C':>12s} " + f"{'Age':>10s} {'t_lb':>10s} {'H(z)':>10s}") + print(f"{'':>8s} {'(Mpc)':>12s} {'(Mpc)':>12s} {'(Mpc)':>12s} " + f"{'(Gyr)':>10s} {'(Gyr)':>10s} {'(km/s/Mpc)':>10s}") + print("-" * 80) + + for i in range(len(z)): + print(f"{z[i]:8.4f} " + f"{results['luminosity_distance'][i].value:12.3f} " + f"{results['angular_diameter_distance'][i].value:12.3f} " + f"{results['comoving_distance'][i].value:12.3f} " + f"{results['age'][i].value:10.4f} " + f"{results['lookback_time'][i].value:10.4f} " + f"{results['H'][i].value:10.4f}") + + if verbose: + print("\nLegend:") + print(" z : Redshift") + print(" D_L : Luminosity distance") + print(" D_A : Angular diameter distance") + print(" D_C : Comoving distance") + print(" Age : Age of universe at z") + print(" t_lb : Lookback time to z") + print(" H(z) : Hubble parameter at z") + + +def convert_quantity(value, quantity_type, cosmo, to_redshift=False): + """ + Convert between redshift and cosmological quantity. + + Parameters + ---------- + value : float + Value to convert + quantity_type : str + Type of quantity ('luminosity_distance', 'age', etc.) 
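+        Supported values: 'luminosity_distance', 'comoving_distance' (Mpc),
+        'age', 'lookback_time' (Gyr).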
+ cosmo : Cosmology + Cosmology object + to_redshift : bool + If True, convert quantity to redshift; else convert z to quantity + """ + from astropy.cosmology import z_at_value + + if to_redshift: + # Convert quantity to redshift + if quantity_type == 'luminosity_distance': + z = z_at_value(cosmo.luminosity_distance, value * u.Mpc) + elif quantity_type == 'age': + z = z_at_value(cosmo.age, value * u.Gyr) + elif quantity_type == 'lookback_time': + z = z_at_value(cosmo.lookback_time, value * u.Gyr) + elif quantity_type == 'comoving_distance': + z = z_at_value(cosmo.comoving_distance, value * u.Mpc) + else: + raise ValueError(f"Unknown quantity type: {quantity_type}") + return z + else: + # Convert redshift to quantity + if quantity_type == 'luminosity_distance': + return cosmo.luminosity_distance(value) + elif quantity_type == 'age': + return cosmo.age(value) + elif quantity_type == 'lookback_time': + return cosmo.lookback_time(value) + elif quantity_type == 'comoving_distance': + return cosmo.comoving_distance(value) + else: + raise ValueError(f"Unknown quantity type: {quantity_type}") + + +def main(): + """Main function for command-line usage.""" + parser = argparse.ArgumentParser( + description='Calculate cosmological quantities for given redshifts', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Available cosmologies: Planck18, Planck15, WMAP9, custom + +Examples: + %(prog)s 0.5 1.0 1.5 + %(prog)s 0.5 --cosmology Planck15 + %(prog)s 0.5 --cosmology custom --H0 70 --Om0 0.3 + %(prog)s --range 0 3 0.5 + %(prog)s 0.5 --verbose + %(prog)s 0.5 1.0 --csv + %(prog)s --convert 1000 --from luminosity_distance --cosmology Planck18 + """ + ) + + parser.add_argument('redshifts', nargs='*', type=float, + help='Redshift values to calculate') + parser.add_argument('-c', '--cosmology', default='Planck18', + choices=['Planck18', 'Planck15', 'WMAP9', 'custom'], + help='Cosmology to use (default: Planck18)') + parser.add_argument('--H0', type=float, + help='Hubble constant for custom cosmology (km/s/Mpc)') + parser.add_argument('--Om0', type=float, + help='Matter density parameter for custom cosmology') + parser.add_argument('-r', '--range', nargs=3, type=float, metavar=('START', 'STOP', 'STEP'), + help='Generate redshift range (start stop step)') + parser.add_argument('-v', '--verbose', action='store_true', + help='Print verbose output with cosmology details') + parser.add_argument('--csv', action='store_true', + help='Output in CSV format') + parser.add_argument('--convert', type=float, + help='Convert a quantity to redshift') + parser.add_argument('--from', dest='from_quantity', + choices=['luminosity_distance', 'age', 'lookback_time', 'comoving_distance'], + help='Type of quantity to convert from') + + args = parser.parse_args() + + # Handle conversion mode + if args.convert is not None: + if args.from_quantity is None: + print("Error: Must specify --from when using --convert", file=sys.stderr) + sys.exit(1) + + # Get cosmology + if args.cosmology == 'Planck18': + cosmo = Planck18 + elif args.cosmology == 'Planck15': + cosmo = Planck15 + elif args.cosmology == 'WMAP9': + cosmo = WMAP9 + elif args.cosmology == 'custom': + if args.H0 is None or args.Om0 is None: + print("Error: Must provide --H0 and --Om0 for custom cosmology", + file=sys.stderr) + sys.exit(1) + cosmo = FlatLambdaCDM(H0=args.H0 * u.km/u.s/u.Mpc, Om0=args.Om0) + + z = convert_quantity(args.convert, args.from_quantity, cosmo, to_redshift=True) + print(f"\n{args.from_quantity.replace('_', ' ').title()} = 
{args.convert}") + print(f"Redshift z = {z:.6f}") + print(f"(using {args.cosmology} cosmology)") + return + + # Get redshifts + if args.range: + start, stop, step = args.range + redshifts = np.arange(start, stop + step/2, step) + elif args.redshifts: + redshifts = np.array(args.redshifts) + else: + print("Error: No redshifts provided.", file=sys.stderr) + parser.print_help() + sys.exit(1) + + # Calculate + try: + results, cosmo = calculate_cosmology(redshifts, args.cosmology, + H0=args.H0, Om0=args.Om0) + print_results(results, verbose=args.verbose, csv=args.csv) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/astropy/scripts/fits_info.py b/scientific-packages/astropy/scripts/fits_info.py new file mode 100644 index 0000000..233a89c --- /dev/null +++ b/scientific-packages/astropy/scripts/fits_info.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Quick FITS file inspection tool. + +This script provides a convenient way to inspect FITS file structure, +headers, and basic statistics without writing custom code each time. +""" + +import sys +from pathlib import Path +from astropy.io import fits +import numpy as np + + +def print_fits_info(filename, detailed=False, ext=None): + """ + Print comprehensive information about a FITS file. + + Parameters + ---------- + filename : str + Path to FITS file + detailed : bool + If True, print detailed statistics for each HDU + ext : int or str, optional + Specific extension to examine in detail + """ + print(f"\n{'='*70}") + print(f"FITS File: {filename}") + print(f"{'='*70}\n") + + try: + with fits.open(filename) as hdul: + # Print file structure + print("File Structure:") + print("-" * 70) + hdul.info() + print() + + # If specific extension requested + if ext is not None: + print(f"\nDetailed view of extension: {ext}") + print("-" * 70) + hdu = hdul[ext] + print_hdu_details(hdu, detailed=True) + return + + # Print header and data info for each HDU + for i, hdu in enumerate(hdul): + print(f"\n{'='*70}") + print(f"HDU {i}: {hdu.name}") + print(f"{'='*70}") + print_hdu_details(hdu, detailed=detailed) + + except FileNotFoundError: + print(f"Error: File '{filename}' not found.") + sys.exit(1) + except Exception as e: + print(f"Error reading FITS file: {e}") + sys.exit(1) + + +def print_hdu_details(hdu, detailed=False): + """Print details for a single HDU.""" + + # Header information + print("\nHeader Information:") + print("-" * 70) + + # Key header keywords + important_keywords = ['SIMPLE', 'BITPIX', 'NAXIS', 'EXTEND', + 'OBJECT', 'TELESCOP', 'INSTRUME', 'OBSERVER', + 'DATE-OBS', 'EXPTIME', 'FILTER', 'AIRMASS', + 'RA', 'DEC', 'EQUINOX', 'CTYPE1', 'CTYPE2'] + + header = hdu.header + for key in important_keywords: + if key in header: + value = header[key] + comment = header.comments[key] + print(f" {key:12s} = {str(value):20s} / {comment}") + + # NAXIS keywords + if 'NAXIS' in header: + naxis = header['NAXIS'] + for i in range(1, naxis + 1): + key = f'NAXIS{i}' + if key in header: + print(f" {key:12s} = {str(header[key]):20s} / {header.comments[key]}") + + # Data information + if hdu.data is not None: + print("\nData Information:") + print("-" * 70) + + data = hdu.data + print(f" Data type: {data.dtype}") + print(f" Shape: {data.shape}") + + # For image data + if hasattr(data, 'ndim') and data.ndim >= 1: + try: + # Calculate statistics + finite_data = data[np.isfinite(data)] + if len(finite_data) > 0: + print(f" Min: {np.min(finite_data):.6g}") + print(f" 
Max: {np.max(finite_data):.6g}") + print(f" Mean: {np.mean(finite_data):.6g}") + print(f" Median: {np.median(finite_data):.6g}") + print(f" Std: {np.std(finite_data):.6g}") + + # Count special values + n_nan = np.sum(np.isnan(data)) + n_inf = np.sum(np.isinf(data)) + if n_nan > 0: + print(f" NaN values: {n_nan}") + if n_inf > 0: + print(f" Inf values: {n_inf}") + except Exception as e: + print(f" Could not calculate statistics: {e}") + + # For table data + if hasattr(data, 'columns'): + print(f"\n Table Columns ({len(data.columns)}):") + for col in data.columns: + print(f" {col.name:20s} {col.format:10s} {col.unit or ''}") + + if detailed: + print(f"\n First few rows:") + print(data[:min(5, len(data))]) + else: + print("\n No data in this HDU") + + # WCS information if present + try: + from astropy.wcs import WCS + wcs = WCS(hdu.header) + if wcs.has_celestial: + print("\nWCS Information:") + print("-" * 70) + print(f" Has celestial WCS: Yes") + print(f" CTYPE: {wcs.wcs.ctype}") + if wcs.wcs.crval is not None: + print(f" CRVAL: {wcs.wcs.crval}") + if wcs.wcs.crpix is not None: + print(f" CRPIX: {wcs.wcs.crpix}") + if wcs.wcs.cdelt is not None: + print(f" CDELT: {wcs.wcs.cdelt}") + except Exception: + pass + + +def main(): + """Main function for command-line usage.""" + import argparse + + parser = argparse.ArgumentParser( + description='Inspect FITS file structure and contents', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s image.fits + %(prog)s image.fits --detailed + %(prog)s image.fits --ext 1 + %(prog)s image.fits --ext SCI + """ + ) + + parser.add_argument('filename', help='FITS file to inspect') + parser.add_argument('-d', '--detailed', action='store_true', + help='Show detailed statistics for each HDU') + parser.add_argument('-e', '--ext', type=str, default=None, + help='Show details for specific extension only (number or name)') + + args = parser.parse_args() + + # Convert extension to int if numeric + ext = args.ext + if ext is not None: + try: + ext = int(ext) + except ValueError: + pass # Keep as string for extension name + + print_fits_info(args.filename, detailed=args.detailed, ext=ext) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/biomni/SKILL.md b/scientific-packages/biomni/SKILL.md new file mode 100644 index 0000000..3ba134b --- /dev/null +++ b/scientific-packages/biomni/SKILL.md @@ -0,0 +1,375 @@ +--- +name: biomni +description: General-purpose biomedical AI agent for autonomously executing research tasks across diverse biomedical domains. Use this skill when working with biomedical data analysis, CRISPR screening, single-cell RNA-seq, molecular property prediction, genomics, proteomics, drug discovery, or any computational biology task requiring LLM-powered code generation and retrieval-augmented planning. +--- + +# Biomni + +## Overview + +Biomni is a general-purpose biomedical AI agent that autonomously executes research tasks across diverse biomedical subfields. It combines large language model reasoning with retrieval-augmented planning and code-based execution to enhance scientific productivity and hypothesis generation. The system operates with an ~11GB biomedical knowledge base covering molecular, genomic, and clinical domains. 
+ +## Quick Start + +Initialize and use the Biomni agent with these basic steps: + +```python +from biomni.agent import A1 + +# Initialize agent with data path and LLM model +agent = A1(path='./data', llm='claude-sonnet-4-20250514') + +# Execute a biomedical research task +agent.go("Your biomedical task description") +``` + +The agent will autonomously decompose the task, retrieve relevant biomedical knowledge, generate and execute code, and provide results. + +## Installation and Setup + +### Environment Preparation + +1. **Set up the conda environment:** + - Follow instructions in `biomni_env/README.md` from the repository + - Activate the environment: `conda activate biomni_e1` + +2. **Install the package:** + ```bash + pip install biomni --upgrade + ``` + + Or install from source: + ```bash + git clone https://github.com/snap-stanford/biomni.git + cd biomni + pip install -e . + ``` + +3. **Configure API keys:** + + Set up credentials via environment variables or `.env` file: + ```bash + export ANTHROPIC_API_KEY="your-key-here" + export OPENAI_API_KEY="your-key-here" # Optional + ``` + +4. **Data initialization:** + + On first use, the agent will automatically download the ~11GB biomedical knowledge base. + +### LLM Provider Configuration + +Biomni supports multiple LLM providers. Configure the default provider using: + +```python +from biomni.config import default_config + +# Set the default LLM model +default_config.llm = "claude-sonnet-4-20250514" # Anthropic +# default_config.llm = "gpt-4" # OpenAI +# default_config.llm = "azure/gpt-4" # Azure OpenAI +# default_config.llm = "gemini/gemini-pro" # Google Gemini + +# Set timeout (optional) +default_config.timeout_seconds = 1200 + +# Set data path (optional) +default_config.data_path = "./custom/data/path" +``` + +Refer to `references/llm_providers.md` for detailed configuration options for each provider. + +## Core Biomedical Research Tasks + +### 1. CRISPR Screening and Design + +Execute CRISPR screening tasks including guide RNA design, off-target analysis, and screening experiment planning: + +```python +agent.go("Design a CRISPR screening experiment to identify genes involved in cancer cell resistance to drug X") +``` + +The agent will: +- Retrieve relevant gene databases +- Design guide RNAs with specificity analysis +- Plan experimental controls and readout strategies +- Generate analysis code for screening results + +### 2. Single-Cell RNA-seq Analysis + +Perform comprehensive scRNA-seq analysis workflows: + +```python +agent.go("Analyze this 10X Genomics scRNA-seq dataset, identify cell types, and find differentially expressed genes between clusters") +``` + +Capabilities include: +- Quality control and preprocessing +- Dimensionality reduction and clustering +- Cell type annotation using marker databases +- Differential expression analysis +- Pathway enrichment analysis + +### 3. Molecular Property Prediction (ADMET) + +Predict absorption, distribution, metabolism, excretion, and toxicity properties: + +```python +agent.go("Predict ADMET properties for these drug candidates: [SMILES strings]") +``` + +The agent handles: +- Molecular descriptor calculation +- Property prediction using integrated models +- Toxicity screening +- Drug-likeness assessment + +### 4. 
Genomic Analysis + +Execute genomic data analysis tasks: + +```python +agent.go("Perform GWAS analysis to identify SNPs associated with disease phenotype in this cohort") +``` + +Supports: +- Genome-wide association studies (GWAS) +- Variant calling and annotation +- Population genetics analysis +- Functional genomics integration + +### 5. Protein Structure and Function + +Analyze protein sequences and structures: + +```python +agent.go("Predict the structure of this protein sequence and identify potential binding sites") +``` + +Capabilities: +- Sequence analysis and domain identification +- Structure prediction integration +- Binding site prediction +- Protein-protein interaction analysis + +### 6. Disease Diagnosis and Classification + +Perform disease classification from multi-omics data: + +```python +agent.go("Build a classifier to diagnose disease X from patient RNA-seq and clinical data") +``` + +### 7. Systems Biology and Pathway Analysis + +Analyze biological pathways and networks: + +```python +agent.go("Identify dysregulated pathways in this differential expression dataset") +``` + +### 8. Drug Discovery and Repurposing + +Support drug discovery workflows: + +```python +agent.go("Identify FDA-approved drugs that could be repurposed for treating disease Y based on mechanism of action") +``` + +## Advanced Features + +### Custom Configuration per Agent + +Override global configuration for specific agent instances: + +```python +agent = A1( + path='./project_data', + llm='gpt-4o', + timeout=1800 +) +``` + +### Conversation History and Reporting + +Save execution traces as formatted PDF reports: + +```python +# After executing tasks +agent.save_conversation_history( + output_path='./reports/experiment_log.pdf', + format='pdf' +) +``` + +Requires one of: WeasyPrint, markdown2pdf, or Pandoc. + +### Model Context Protocol (MCP) Integration + +Extend agent capabilities with external tools: + +```python +# Add MCP-compatible tools +agent.add_mcp(config_path='./mcp_config.json') +``` + +MCP enables integration with: +- Laboratory information management systems (LIMS) +- Specialized bioinformatics databases +- Custom analysis pipelines +- External computational resources + +### Using Biomni-R0 (Specialized Reasoning Model) + +Deploy the 32B parameter Biomni-R0 model for enhanced biological reasoning: + +```bash +# Install SGLang +pip install "sglang[all]" + +# Deploy Biomni-R0 +python -m sglang.launch_server \ + --model-path snap-stanford/biomni-r0 \ + --port 30000 \ + --trust-remote-code +``` + +Then configure the agent: + +```python +from biomni.config import default_config + +default_config.llm = "openai/biomni-r0" +default_config.api_base = "http://localhost:30000/v1" +``` + +Biomni-R0 provides specialized reasoning for: +- Complex multi-step biological workflows +- Hypothesis generation and evaluation +- Experimental design optimization +- Literature-informed analysis + +## Best Practices + +### Task Specification + +Provide clear, specific task descriptions: + +✅ **Good:** "Analyze this scRNA-seq dataset (file: data.h5ad) to identify T cell subtypes, then perform differential expression analysis comparing activated vs. 
resting T cells" + +❌ **Vague:** "Analyze my RNA-seq data" + +### Data Organization + +Structure data directories for efficient retrieval: + +``` +project/ +├── data/ # Biomni knowledge base +├── raw_data/ # Your experimental data +├── results/ # Analysis outputs +└── reports/ # Generated reports +``` + +### Iterative Refinement + +Use iterative task execution for complex analyses: + +```python +# Step 1: Exploratory analysis +agent.go("Load and perform initial QC on the proteomics dataset") + +# Step 2: Based on results, refine analysis +agent.go("Based on the QC results, remove low-quality samples and normalize using method X") + +# Step 3: Downstream analysis +agent.go("Perform differential abundance analysis with adjusted parameters") +``` + +### Security Considerations + +**CRITICAL:** Biomni executes LLM-generated code with full system privileges. For production use: + +1. **Use sandboxed environments:** Deploy in Docker containers or VMs with restricted permissions +2. **Validate sensitive operations:** Review code before execution for file access, network calls, or credential usage +3. **Limit data access:** Restrict agent access to only necessary data directories +4. **Monitor execution:** Log all executed code for audit trails + +Never run Biomni with: +- Unrestricted file system access +- Direct access to sensitive credentials +- Network access to production systems +- Elevated system privileges + +### Model Selection Guidelines + +Choose models based on task complexity: + +- **Claude Sonnet 4:** Recommended for most biomedical tasks, excellent biological reasoning +- **GPT-4/GPT-4o:** Strong general capabilities, good for diverse tasks +- **Biomni-R0:** Specialized for complex biological reasoning, multi-step workflows +- **Smaller models:** Use for simple, well-defined tasks to reduce cost + +## Evaluation and Benchmarking + +Biomni-Eval1 benchmark contains 433 evaluation instances across 10 biological tasks: + +- GWAS analysis +- Disease diagnosis +- Gene detection and classification +- Molecular property prediction +- Pathway analysis +- Protein function prediction +- Drug response prediction +- Variant interpretation +- Cell type annotation +- Biomarker discovery + +Use the benchmark to: +- Evaluate custom agent configurations +- Compare LLM providers for specific tasks +- Validate analysis pipelines + +## Troubleshooting + +### Common Issues + +**Issue:** Data download fails or times out +**Solution:** Manually download the knowledge base or increase timeout settings + +**Issue:** Package dependency conflicts +**Solution:** Some optional dependencies cannot be installed by default due to conflicts. 
Install specific packages manually and uncomment relevant code sections as documented in the repository + +**Issue:** LLM API errors +**Solution:** Verify API key configuration, check rate limits, ensure sufficient credits + +**Issue:** Memory errors with large datasets +**Solution:** Process data in chunks, use data subsampling, or deploy on higher-memory instances + +### Getting Help + +For detailed troubleshooting: +- Review the Biomni GitHub repository issues +- Check `references/api_reference.md` for detailed API documentation +- Consult `references/task_examples.md` for comprehensive task patterns + +## Resources + +### references/ +Detailed reference documentation for advanced usage: + +- **api_reference.md:** Complete API documentation for A1 agent, configuration objects, and utility functions +- **llm_providers.md:** Comprehensive guide for configuring all supported LLM providers (Anthropic, OpenAI, Azure, Gemini, Groq, Ollama, AWS Bedrock) +- **task_examples.md:** Extensive collection of biomedical task examples with code patterns + +### scripts/ +Helper scripts for common operations: + +- **setup_environment.py:** Automated environment setup and validation +- **generate_report.py:** Enhanced PDF report generation with custom formatting + +Load reference documentation as needed: +```python +# Claude can read reference files when needed for detailed information +# Example: "Check references/llm_providers.md for Azure OpenAI configuration" +``` diff --git a/scientific-packages/biomni/references/api_reference.md b/scientific-packages/biomni/references/api_reference.md new file mode 100644 index 0000000..91aa46a --- /dev/null +++ b/scientific-packages/biomni/references/api_reference.md @@ -0,0 +1,635 @@ +# Biomni API Reference + +This document provides comprehensive API documentation for the Biomni biomedical AI agent system. + +## Core Classes + +### A1 Agent + +The primary agent class for executing biomedical research tasks. + +#### Initialization + +```python +from biomni.agent import A1 + +agent = A1( + path='./data', # Path to biomedical knowledge base + llm='claude-sonnet-4-20250514', # LLM model identifier + timeout=None, # Optional timeout in seconds + verbose=True # Enable detailed logging +) +``` + +**Parameters:** + +- `path` (str, required): Directory path where the biomedical knowledge base is stored or will be downloaded. First-time initialization will download ~11GB of data. +- `llm` (str, optional): LLM model identifier. Defaults to the value in `default_config.llm`. Supports multiple providers (see LLM Providers section). +- `timeout` (int, optional): Maximum execution time in seconds for agent operations. Overrides `default_config.timeout_seconds`. +- `verbose` (bool, optional): Enable verbose logging for debugging. Default: True. + +**Returns:** A1 agent instance ready for task execution. + +#### Methods + +##### `go(task_description: str) -> None` + +Execute a biomedical research task autonomously. + +```python +agent.go("Analyze this scRNA-seq dataset and identify cell types") +``` + +**Parameters:** +- `task_description` (str, required): Natural language description of the biomedical task to execute. Be specific about: + - Data location and format + - Desired analysis or output + - Any specific methods or parameters + - Expected results format + +**Behavior:** +1. Decomposes the task into executable steps +2. Retrieves relevant biomedical knowledge from the data lake +3. Generates and executes Python/R code +4. Provides results and visualizations +5. 
Handles errors and retries with refinement + +**Notes:** +- Executes code with system privileges - use in sandboxed environments +- Long-running tasks may require timeout adjustments +- Intermediate results are displayed during execution + +##### `save_conversation_history(output_path: str, format: str = 'pdf') -> None` + +Export conversation history and execution trace as a formatted report. + +```python +agent.save_conversation_history( + output_path='./reports/analysis_log.pdf', + format='pdf' +) +``` + +**Parameters:** +- `output_path` (str, required): File path for the output report +- `format` (str, optional): Output format. Options: 'pdf', 'markdown'. Default: 'pdf' + +**Requirements:** +- For PDF: Install one of: WeasyPrint, markdown2pdf, or Pandoc + ```bash + pip install weasyprint # Recommended + # or + pip install markdown2pdf + # or install Pandoc system-wide + ``` + +**Report Contents:** +- Task description and parameters +- Retrieved biomedical knowledge +- Generated code with execution traces +- Results, visualizations, and outputs +- Timestamps and execution metadata + +##### `add_mcp(config_path: str) -> None` + +Add Model Context Protocol (MCP) tools to extend agent capabilities. + +```python +agent.add_mcp(config_path='./mcp_tools_config.json') +``` + +**Parameters:** +- `config_path` (str, required): Path to MCP configuration JSON file + +**MCP Configuration Format:** +```json +{ + "tools": [ + { + "name": "tool_name", + "endpoint": "http://localhost:8000/tool", + "description": "Tool description for LLM", + "parameters": { + "param1": "string", + "param2": "integer" + } + } + ] +} +``` + +**Use Cases:** +- Connect to laboratory information systems +- Integrate proprietary databases +- Access specialized computational resources +- Link to institutional data repositories + +## Configuration + +### default_config + +Global configuration object for Biomni settings. + +```python +from biomni.config import default_config +``` + +#### Attributes + +##### `llm: str` + +Default LLM model identifier for all agent instances. + +```python +default_config.llm = "claude-sonnet-4-20250514" +``` + +**Supported Models:** + +**Anthropic:** +- `claude-sonnet-4-20250514` (Recommended) +- `claude-opus-4-20250514` +- `claude-3-5-sonnet-20241022` +- `claude-3-opus-20240229` + +**OpenAI:** +- `gpt-4o` +- `gpt-4` +- `gpt-4-turbo` +- `gpt-3.5-turbo` + +**Azure OpenAI:** +- `azure/gpt-4` +- `azure/` + +**Google Gemini:** +- `gemini/gemini-pro` +- `gemini/gemini-1.5-pro` + +**Groq:** +- `groq/llama-3.1-70b-versatile` +- `groq/mixtral-8x7b-32768` + +**Ollama (Local):** +- `ollama/llama3` +- `ollama/mistral` +- `ollama/` + +**AWS Bedrock:** +- `bedrock/anthropic.claude-v2` +- `bedrock/anthropic.claude-3-sonnet` + +**Custom/Biomni-R0:** +- `openai/biomni-r0` (requires local SGLang deployment) + +##### `timeout_seconds: int` + +Default timeout for agent operations in seconds. + +```python +default_config.timeout_seconds = 1200 # 20 minutes +``` + +**Recommended Values:** +- Simple tasks (QC, basic analysis): 300-600 seconds +- Medium tasks (differential expression, clustering): 600-1200 seconds +- Complex tasks (full pipelines, ML models): 1200-3600 seconds +- Very complex tasks: 3600+ seconds + +##### `data_path: str` + +Default path to biomedical knowledge base. 
+ +```python +default_config.data_path = "/path/to/biomni/data" +``` + +**Storage Requirements:** +- Initial download: ~11GB +- Extracted size: ~15GB +- Additional working space: ~5-10GB recommended + +##### `api_base: str` + +Custom API endpoint for LLM providers (advanced usage). + +```python +# For local Biomni-R0 deployment +default_config.api_base = "http://localhost:30000/v1" + +# For custom OpenAI-compatible endpoints +default_config.api_base = "https://your-endpoint.com/v1" +``` + +##### `max_retries: int` + +Number of retry attempts for failed operations. + +```python +default_config.max_retries = 3 +``` + +#### Methods + +##### `reset() -> None` + +Reset all configuration values to system defaults. + +```python +default_config.reset() +``` + +## Database Query System + +Biomni includes a retrieval-augmented generation (RAG) system for querying the biomedical knowledge base. + +### Query Functions + +#### `query_genes(query: str, top_k: int = 10) -> List[Dict]` + +Query gene information from integrated databases. + +```python +from biomni.database import query_genes + +results = query_genes( + query="genes involved in p53 pathway", + top_k=20 +) +``` + +**Parameters:** +- `query` (str): Natural language or gene identifier query +- `top_k` (int): Number of results to return + +**Returns:** List of dictionaries containing: +- `gene_symbol`: Official gene symbol +- `gene_name`: Full gene name +- `description`: Functional description +- `pathways`: Associated biological pathways +- `go_terms`: Gene Ontology annotations +- `diseases`: Associated diseases +- `similarity_score`: Relevance score (0-1) + +#### `query_proteins(query: str, top_k: int = 10) -> List[Dict]` + +Query protein information from UniProt and other sources. + +```python +from biomni.database import query_proteins + +results = query_proteins( + query="kinase proteins in cell cycle", + top_k=15 +) +``` + +**Returns:** List of dictionaries with protein metadata: +- `uniprot_id`: UniProt accession +- `protein_name`: Protein name +- `function`: Functional annotation +- `domains`: Protein domains +- `subcellular_location`: Cellular localization +- `similarity_score`: Relevance score + +#### `query_drugs(query: str, top_k: int = 10) -> List[Dict]` + +Query drug and compound information. + +```python +from biomni.database import query_drugs + +results = query_drugs( + query="FDA approved cancer drugs targeting EGFR", + top_k=10 +) +``` + +**Returns:** Drug information including: +- `drug_name`: Common name +- `drugbank_id`: DrugBank identifier +- `indication`: Therapeutic indication +- `mechanism`: Mechanism of action +- `targets`: Molecular targets +- `approval_status`: Regulatory status +- `smiles`: Chemical structure (SMILES notation) + +#### `query_diseases(query: str, top_k: int = 10) -> List[Dict]` + +Query disease information from clinical databases. + +```python +from biomni.database import query_diseases + +results = query_diseases( + query="autoimmune diseases affecting joints", + top_k=10 +) +``` + +**Returns:** Disease data: +- `disease_name`: Standard disease name +- `disease_id`: Ontology identifier +- `symptoms`: Clinical manifestations +- `associated_genes`: Genetic associations +- `prevalence`: Epidemiological data + +#### `query_pathways(query: str, top_k: int = 10) -> List[Dict]` + +Query biological pathways from KEGG, Reactome, and other sources. 
+ +```python +from biomni.database import query_pathways + +results = query_pathways( + query="immune response signaling pathways", + top_k=15 +) +``` + +**Returns:** Pathway information: +- `pathway_name`: Pathway name +- `pathway_id`: Database identifier +- `genes`: Genes in pathway +- `description`: Functional description +- `source`: Database source (KEGG, Reactome, etc.) + +## Data Structures + +### TaskResult + +Result object returned by complex agent operations. + +```python +class TaskResult: + success: bool # Whether task completed successfully + output: Any # Task output (varies by task) + code: str # Generated code + execution_time: float # Execution time in seconds + error: Optional[str] # Error message if failed + metadata: Dict # Additional metadata +``` + +### BiomedicalEntity + +Base class for biomedical entities in the knowledge base. + +```python +class BiomedicalEntity: + entity_id: str # Unique identifier + entity_type: str # Type (gene, protein, drug, etc.) + name: str # Entity name + description: str # Description + attributes: Dict # Additional attributes + references: List[str] # Literature references +``` + +## Utility Functions + +### `download_data(path: str, force: bool = False) -> None` + +Manually download or update the biomedical knowledge base. + +```python +from biomni.utils import download_data + +download_data( + path='./data', + force=True # Force re-download +) +``` + +### `validate_environment() -> Dict[str, bool]` + +Check if the environment is properly configured. + +```python +from biomni.utils import validate_environment + +status = validate_environment() +# Returns: { +# 'conda_env': True, +# 'api_keys': True, +# 'data_available': True, +# 'dependencies': True +# } +``` + +### `list_available_models() -> List[str]` + +Get a list of available LLM models based on configured API keys. + +```python +from biomni.utils import list_available_models + +models = list_available_models() +# Returns: ['claude-sonnet-4-20250514', 'gpt-4o', ...] +``` + +## Error Handling + +### Common Exceptions + +#### `BiomniConfigError` + +Raised when configuration is invalid or incomplete. + +```python +from biomni.exceptions import BiomniConfigError + +try: + agent = A1(path='./data') +except BiomniConfigError as e: + print(f"Configuration error: {e}") +``` + +#### `BiomniExecutionError` + +Raised when code generation or execution fails. + +```python +from biomni.exceptions import BiomniExecutionError + +try: + agent.go("invalid task") +except BiomniExecutionError as e: + print(f"Execution failed: {e}") + # Access failed code: e.code + # Access error details: e.details +``` + +#### `BiomniDataError` + +Raised when knowledge base or data access fails. + +```python +from biomni.exceptions import BiomniDataError + +try: + results = query_genes("unknown query format") +except BiomniDataError as e: + print(f"Data access error: {e}") +``` + +#### `BiomniTimeoutError` + +Raised when operations exceed timeout limit. 
+ +```python +from biomni.exceptions import BiomniTimeoutError + +try: + agent.go("very complex long-running task") +except BiomniTimeoutError as e: + print(f"Task timed out after {e.duration} seconds") + # Partial results may be available: e.partial_results +``` + +## Best Practices + +### Efficient Knowledge Retrieval + +Pre-query databases for relevant context before complex tasks: + +```python +from biomni.database import query_genes, query_pathways + +# Gather relevant biological context first +genes = query_genes("cell cycle genes", top_k=50) +pathways = query_pathways("cell cycle regulation", top_k=20) + +# Then execute task with enriched context +agent.go(f""" +Analyze the cell cycle progression in this dataset. +Focus on these genes: {[g['gene_symbol'] for g in genes]} +Consider these pathways: {[p['pathway_name'] for p in pathways]} +""") +``` + +### Error Recovery + +Implement robust error handling for production workflows: + +```python +from biomni.exceptions import BiomniExecutionError, BiomniTimeoutError + +max_attempts = 3 +for attempt in range(max_attempts): + try: + agent.go("complex biomedical task") + break + except BiomniTimeoutError: + # Increase timeout and retry + default_config.timeout_seconds *= 2 + print(f"Timeout, retrying with {default_config.timeout_seconds}s timeout") + except BiomniExecutionError as e: + # Refine task based on error + print(f"Execution failed: {e}, refining task...") + # Optionally modify task description + else: + print("Task failed after max attempts") +``` + +### Memory Management + +For large-scale analyses, manage memory explicitly: + +```python +import gc + +# Process datasets in chunks +for chunk_id in range(num_chunks): + agent.go(f"Process data chunk {chunk_id} located at data/chunk_{chunk_id}.h5ad") + + # Force garbage collection between chunks + gc.collect() + + # Save intermediate results + agent.save_conversation_history(f"./reports/chunk_{chunk_id}.pdf") +``` + +### Reproducibility + +Ensure reproducible analyses by: + +1. **Fixing random seeds:** +```python +agent.go("Set random seed to 42 for all analyses, then perform clustering...") +``` + +2. **Logging configuration:** +```python +import json +config_log = { + 'llm': default_config.llm, + 'timeout': default_config.timeout_seconds, + 'data_path': default_config.data_path, + 'timestamp': datetime.now().isoformat() +} +with open('config_log.json', 'w') as f: + json.dump(config_log, f, indent=2) +``` + +3. 
**Saving execution traces:** +```python +# Always save detailed reports +agent.save_conversation_history('./reports/full_analysis.pdf') +``` + +## Performance Optimization + +### Model Selection Strategy + +Choose models based on task characteristics: + +```python +# For exploratory, simple tasks +default_config.llm = "gpt-3.5-turbo" # Fast, cost-effective + +# For standard biomedical analyses +default_config.llm = "claude-sonnet-4-20250514" # Recommended + +# For complex reasoning and hypothesis generation +default_config.llm = "claude-opus-4-20250514" # Highest quality + +# For specialized biological reasoning +default_config.llm = "openai/biomni-r0" # Requires local deployment +``` + +### Timeout Tuning + +Set appropriate timeouts based on task complexity: + +```python +# Quick queries and simple analyses +agent = A1(path='./data', timeout=300) + +# Standard workflows +agent = A1(path='./data', timeout=1200) + +# Full pipelines with ML training +agent = A1(path='./data', timeout=3600) +``` + +### Caching and Reuse + +Reuse agent instances for multiple related tasks: + +```python +# Create agent once +agent = A1(path='./data', llm='claude-sonnet-4-20250514') + +# Execute multiple related tasks +tasks = [ + "Load and QC the scRNA-seq dataset", + "Perform clustering with resolution 0.5", + "Identify marker genes for each cluster", + "Annotate cell types based on markers" +] + +for task in tasks: + agent.go(task) + +# Save complete workflow +agent.save_conversation_history('./reports/full_workflow.pdf') +``` diff --git a/scientific-packages/biomni/references/llm_providers.md b/scientific-packages/biomni/references/llm_providers.md new file mode 100644 index 0000000..d4dfd2e --- /dev/null +++ b/scientific-packages/biomni/references/llm_providers.md @@ -0,0 +1,649 @@ +# LLM Provider Configuration Guide + +This document provides comprehensive configuration instructions for all LLM providers supported by Biomni. + +## Overview + +Biomni supports multiple LLM providers through a unified interface. Configure providers using: +- Environment variables +- `.env` files +- Runtime configuration via `default_config` + +## Quick Reference Table + +| Provider | Recommended For | API Key Required | Cost | Setup Complexity | +|----------|----------------|------------------|------|------------------| +| Anthropic Claude | Most biomedical tasks | Yes | Medium | Easy | +| OpenAI | General tasks | Yes | Medium-High | Easy | +| Azure OpenAI | Enterprise deployment | Yes | Varies | Medium | +| Google Gemini | Multimodal tasks | Yes | Medium | Easy | +| Groq | Fast inference | Yes | Low | Easy | +| Ollama | Local/offline use | No | Free | Medium | +| AWS Bedrock | AWS ecosystem | Yes | Varies | Hard | +| Biomni-R0 | Complex biological reasoning | No | Free | Hard | + +## Anthropic Claude (Recommended) + +### Overview + +Claude models from Anthropic provide excellent biological reasoning capabilities and are the recommended choice for most Biomni tasks. + +### Setup + +1. **Obtain API Key:** + - Sign up at https://console.anthropic.com/ + - Navigate to API Keys section + - Generate a new key + +2. **Configure Environment:** + + **Option A: Environment Variable** + ```bash + export ANTHROPIC_API_KEY="sk-ant-api03-..." + ``` + + **Option B: .env File** + ```bash + # .env file in project root + ANTHROPIC_API_KEY=sk-ant-api03-... + ``` + +3. 
**Set Model in Code:** + ```python + from biomni.config import default_config + + # Claude Sonnet 4 (Recommended) + default_config.llm = "claude-sonnet-4-20250514" + + # Claude Opus 4 (Most capable) + default_config.llm = "claude-opus-4-20250514" + + # Claude 3.5 Sonnet (Previous version) + default_config.llm = "claude-3-5-sonnet-20241022" + ``` + +### Available Models + +| Model | Context Window | Strengths | Best For | +|-------|---------------|-----------|----------| +| `claude-sonnet-4-20250514` | 200K tokens | Balanced performance, cost-effective | Most biomedical tasks | +| `claude-opus-4-20250514` | 200K tokens | Highest capability, complex reasoning | Difficult multi-step analyses | +| `claude-3-5-sonnet-20241022` | 200K tokens | Fast, reliable | Standard workflows | +| `claude-3-opus-20240229` | 200K tokens | Strong reasoning | Legacy support | + +### Advanced Configuration + +```python +from biomni.config import default_config + +# Use Claude with custom parameters +default_config.llm = "claude-sonnet-4-20250514" +default_config.timeout_seconds = 1800 + +# Optional: Custom API endpoint (for proxy/enterprise) +default_config.api_base = "https://your-proxy.com/v1" +``` + +### Cost Estimation + +Approximate costs per 1M tokens (as of January 2025): +- Input: $3-15 depending on model +- Output: $15-75 depending on model + +For a typical biomedical analysis (~50K tokens total): $0.50-$2.00 + +## OpenAI + +### Overview + +OpenAI's GPT models provide strong general capabilities suitable for diverse biomedical tasks. + +### Setup + +1. **Obtain API Key:** + - Sign up at https://platform.openai.com/ + - Navigate to API Keys + - Create new secret key + +2. **Configure Environment:** + + ```bash + export OPENAI_API_KEY="sk-proj-..." + ``` + + Or in `.env`: + ``` + OPENAI_API_KEY=sk-proj-... + ``` + +3. **Set Model:** + ```python + from biomni.config import default_config + + default_config.llm = "gpt-4o" # Recommended + # default_config.llm = "gpt-4" # Previous flagship + # default_config.llm = "gpt-4-turbo" # Fast variant + # default_config.llm = "gpt-3.5-turbo" # Budget option + ``` + +### Available Models + +| Model | Context Window | Strengths | Cost | +|-------|---------------|-----------|------| +| `gpt-4o` | 128K tokens | Fast, multimodal | Medium | +| `gpt-4-turbo` | 128K tokens | Fast inference | Medium | +| `gpt-4` | 8K tokens | Reliable | High | +| `gpt-3.5-turbo` | 16K tokens | Fast, cheap | Low | + +### Cost Optimization + +```python +# For exploratory analysis (budget-conscious) +default_config.llm = "gpt-3.5-turbo" + +# For production analysis (quality-focused) +default_config.llm = "gpt-4o" +``` + +## Azure OpenAI + +### Overview + +Azure-hosted OpenAI models for enterprise users requiring data residency and compliance. + +### Setup + +1. **Azure Prerequisites:** + - Active Azure subscription + - Azure OpenAI resource created + - Model deployment configured + +2. **Environment Variables:** + ```bash + export AZURE_OPENAI_API_KEY="your-key" + export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com/" + export AZURE_OPENAI_API_VERSION="2024-02-15-preview" + ``` + +3. **Configuration:** + ```python + from biomni.config import default_config + + # Option 1: Use deployment name + default_config.llm = "azure/your-deployment-name" + + # Option 2: Specify endpoint explicitly + default_config.llm = "azure/gpt-4" + default_config.api_base = "https://your-resource.openai.azure.com/" + ``` + +### Deployment Setup + +Azure OpenAI requires explicit model deployments: + +1. 
Navigate to Azure OpenAI Studio +2. Create deployment for desired model (e.g., GPT-4) +3. Note the deployment name +4. Use deployment name in Biomni configuration + +### Example Configuration + +```python +from biomni.config import default_config +import os + +# Set Azure credentials +os.environ['AZURE_OPENAI_API_KEY'] = 'your-key' +os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://your-resource.openai.azure.com/' + +# Configure Biomni to use Azure deployment +default_config.llm = "azure/gpt-4-biomni" # Your deployment name +default_config.api_base = os.environ['AZURE_OPENAI_ENDPOINT'] +``` + +## Google Gemini + +### Overview + +Google's Gemini models offer multimodal capabilities and competitive performance. + +### Setup + +1. **Obtain API Key:** + - Visit https://makersuite.google.com/app/apikey + - Create new API key + +2. **Environment Configuration:** + ```bash + export GEMINI_API_KEY="your-key" + ``` + +3. **Set Model:** + ```python + from biomni.config import default_config + + default_config.llm = "gemini/gemini-1.5-pro" + # Or: default_config.llm = "gemini/gemini-pro" + ``` + +### Available Models + +| Model | Context Window | Strengths | +|-------|---------------|-----------| +| `gemini/gemini-1.5-pro` | 1M tokens | Very large context, multimodal | +| `gemini/gemini-pro` | 32K tokens | Balanced performance | + +### Use Cases + +Gemini excels at: +- Tasks requiring very large context windows +- Multimodal analysis (when incorporating images) +- Cost-effective alternative to GPT-4 + +```python +# For tasks with large context requirements +default_config.llm = "gemini/gemini-1.5-pro" +default_config.timeout_seconds = 2400 # May need longer timeout +``` + +## Groq + +### Overview + +Groq provides ultra-fast inference with open-source models, ideal for rapid iteration. + +### Setup + +1. **Get API Key:** + - Sign up at https://console.groq.com/ + - Generate API key + +2. **Configure:** + ```bash + export GROQ_API_KEY="gsk_..." + ``` + +3. **Set Model:** + ```python + from biomni.config import default_config + + default_config.llm = "groq/llama-3.1-70b-versatile" + # Or: default_config.llm = "groq/mixtral-8x7b-32768" + ``` + +### Available Models + +| Model | Context Window | Speed | Quality | +|-------|---------------|-------|---------| +| `groq/llama-3.1-70b-versatile` | 32K tokens | Very Fast | Good | +| `groq/mixtral-8x7b-32768` | 32K tokens | Very Fast | Good | +| `groq/llama-3-70b-8192` | 8K tokens | Ultra Fast | Moderate | + +### Best Practices + +```python +# For rapid prototyping and testing +default_config.llm = "groq/llama-3.1-70b-versatile" +default_config.timeout_seconds = 600 # Groq is fast + +# Note: Quality may be lower than GPT-4/Claude for complex tasks +# Recommended for: QC, simple analyses, testing workflows +``` + +## Ollama (Local Deployment) + +### Overview + +Run LLMs entirely locally for offline use, data privacy, or cost savings. + +### Setup + +1. **Install Ollama:** + ```bash + # macOS/Linux + curl -fsSL https://ollama.com/install.sh | sh + + # Or download from https://ollama.com/download + ``` + +2. **Pull Models:** + ```bash + ollama pull llama3 # Meta Llama 3 (8B) + ollama pull mixtral # Mixtral (47B) + ollama pull codellama # Code-specialized + ollama pull medllama # Medical domain (if available) + ``` + +3. **Start Ollama Server:** + ```bash + ollama serve # Runs on http://localhost:11434 + ``` + +4. 
**Configure Biomni:** + ```python + from biomni.config import default_config + + default_config.llm = "ollama/llama3" + default_config.api_base = "http://localhost:11434" + ``` + +### Hardware Requirements + +Minimum recommendations: +- **8B models:** 16GB RAM, CPU inference acceptable +- **70B models:** 64GB RAM, GPU highly recommended +- **Storage:** 5-50GB per model + +### Model Selection + +```python +# Fast, local, good for testing +default_config.llm = "ollama/llama3" + +# Better quality (requires more resources) +default_config.llm = "ollama/mixtral" + +# Code generation tasks +default_config.llm = "ollama/codellama" +``` + +### Advantages & Limitations + +**Advantages:** +- Complete data privacy +- No API costs +- Offline operation +- Unlimited usage + +**Limitations:** +- Lower quality than GPT-4/Claude for complex tasks +- Requires significant hardware +- Slower inference (especially on CPU) +- May struggle with specialized biomedical knowledge + +## AWS Bedrock + +### Overview + +AWS-managed LLM service offering multiple model providers. + +### Setup + +1. **AWS Prerequisites:** + - AWS account with Bedrock access + - Model access enabled in Bedrock console + - AWS credentials configured + +2. **Configure AWS Credentials:** + ```bash + # Option 1: AWS CLI + aws configure + + # Option 2: Environment variables + export AWS_ACCESS_KEY_ID="your-key" + export AWS_SECRET_ACCESS_KEY="your-secret" + export AWS_REGION="us-east-1" + ``` + +3. **Enable Model Access:** + - Navigate to AWS Bedrock console + - Request access to desired models + - Wait for approval (may take hours/days) + +4. **Configure Biomni:** + ```python + from biomni.config import default_config + + default_config.llm = "bedrock/anthropic.claude-3-sonnet" + # Or: default_config.llm = "bedrock/anthropic.claude-v2" + ``` + +### Available Models + +Bedrock provides access to: +- Anthropic Claude models +- Amazon Titan models +- AI21 Jurassic models +- Cohere Command models +- Meta Llama models + +### IAM Permissions + +Required IAM policy: +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream" + ], + "Resource": "arn:aws:bedrock:*::foundation-model/*" + } + ] +} +``` + +### Example Configuration + +```python +from biomni.config import default_config +import boto3 + +# Verify AWS credentials +session = boto3.Session() +credentials = session.get_credentials() +print(f"AWS Access Key: {credentials.access_key[:8]}...") + +# Configure Biomni +default_config.llm = "bedrock/anthropic.claude-3-sonnet" +default_config.timeout_seconds = 1800 +``` + +## Biomni-R0 (Local Specialized Model) + +### Overview + +Biomni-R0 is a 32B parameter reasoning model specifically trained for biological problem-solving. Provides the highest quality for complex biomedical reasoning but requires local deployment. + +### Setup + +1. **Hardware Requirements:** + - GPU with 48GB+ VRAM (e.g., A100, H100) + - Or multi-GPU setup (2x 24GB) + - 100GB+ storage for model weights + +2. **Install Dependencies:** + ```bash + pip install "sglang[all]" + pip install flashinfer # Optional but recommended + ``` + +3. 
**Deploy Model:** + ```bash + python -m sglang.launch_server \ + --model-path snap-stanford/biomni-r0 \ + --host 0.0.0.0 \ + --port 30000 \ + --trust-remote-code \ + --mem-fraction-static 0.8 + ``` + + For multi-GPU: + ```bash + python -m sglang.launch_server \ + --model-path snap-stanford/biomni-r0 \ + --host 0.0.0.0 \ + --port 30000 \ + --trust-remote-code \ + --tp 2 # Tensor parallelism across 2 GPUs + ``` + +4. **Configure Biomni:** + ```python + from biomni.config import default_config + + default_config.llm = "openai/biomni-r0" + default_config.api_base = "http://localhost:30000/v1" + default_config.timeout_seconds = 2400 # Longer for complex reasoning + ``` + +### When to Use Biomni-R0 + +Biomni-R0 excels at: +- Multi-step biological reasoning +- Complex experimental design +- Hypothesis generation and evaluation +- Literature-informed analysis +- Tasks requiring deep biological knowledge + +```python +# For complex biological reasoning tasks +default_config.llm = "openai/biomni-r0" + +agent.go(""" +Design a comprehensive CRISPR screening experiment to identify synthetic +lethal interactions with TP53 mutations in cancer cells, including: +1. Rationale and hypothesis +2. Guide RNA library design strategy +3. Experimental controls +4. Statistical analysis plan +5. Expected outcomes and validation approach +""") +``` + +### Performance Comparison + +| Model | Speed | Biological Reasoning | Code Quality | Cost | +|-------|-------|---------------------|--------------|------| +| GPT-4 | Fast | Good | Excellent | Medium | +| Claude Sonnet 4 | Fast | Excellent | Excellent | Medium | +| Biomni-R0 | Moderate | Outstanding | Good | Free (local) | + +## Multi-Provider Strategy + +### Intelligent Model Selection + +Use different models for different task types: + +```python +from biomni.agent import A1 +from biomni.config import default_config + +# Strategy 1: Task-based selection +def get_agent_for_task(task_complexity): + if task_complexity == "simple": + default_config.llm = "gpt-3.5-turbo" + default_config.timeout_seconds = 300 + elif task_complexity == "medium": + default_config.llm = "claude-sonnet-4-20250514" + default_config.timeout_seconds = 1200 + else: # complex + default_config.llm = "openai/biomni-r0" + default_config.timeout_seconds = 2400 + + return A1(path='./data') + +# Strategy 2: Fallback on failure +def execute_with_fallback(task): + models = [ + "claude-sonnet-4-20250514", + "gpt-4o", + "claude-opus-4-20250514" + ] + + for model in models: + try: + default_config.llm = model + agent = A1(path='./data') + agent.go(task) + return + except Exception as e: + print(f"Failed with {model}: {e}, trying next...") + + raise Exception("All models failed") +``` + +### Cost Optimization Strategy + +```python +# Phase 1: Rapid prototyping with cheap models +default_config.llm = "gpt-3.5-turbo" +agent.go("Quick exploratory analysis of dataset structure") + +# Phase 2: Detailed analysis with high-quality models +default_config.llm = "claude-sonnet-4-20250514" +agent.go("Comprehensive differential expression analysis with pathway enrichment") + +# Phase 3: Complex reasoning with specialized models +default_config.llm = "openai/biomni-r0" +agent.go("Generate biological hypotheses based on multi-omics integration") +``` + +## Troubleshooting + +### Common Issues + +**Issue: "API key not found"** +- Verify environment variable is set: `echo $ANTHROPIC_API_KEY` +- Check `.env` file exists and is in correct location +- Try setting key programmatically: `os.environ['ANTHROPIC_API_KEY'] = 'key'` + 
+**Issue: "Rate limit exceeded"** +- Implement exponential backoff and retry +- Upgrade API tier if available +- Switch to alternative provider temporarily + +**Issue: "Model not found"** +- Verify model identifier is correct +- Check API key has access to requested model +- For Azure: ensure deployment exists with exact name + +**Issue: "Timeout errors"** +- Increase `default_config.timeout_seconds` +- Break complex tasks into smaller steps +- Consider using faster model for initial phases + +**Issue: "Connection refused (Ollama/Biomni-R0)"** +- Verify local server is running +- Check port is not blocked by firewall +- Confirm `api_base` URL is correct + +### Testing Configuration + +```python +from biomni.utils import list_available_models, validate_environment + +# Check environment setup +status = validate_environment() +print("Environment Status:", status) + +# List available models based on configured keys +models = list_available_models() +print("Available Models:", models) + +# Test specific model +try: + from biomni.agent import A1 + agent = A1(path='./data', llm='claude-sonnet-4-20250514') + agent.go("Print 'Configuration successful!'") +except Exception as e: + print(f"Configuration test failed: {e}") +``` + +## Best Practices Summary + +1. **For most users:** Start with Claude Sonnet 4 or GPT-4o +2. **For cost sensitivity:** Use GPT-3.5-turbo for exploration, Claude Sonnet 4 for production +3. **For privacy/offline:** Deploy Ollama locally +4. **For complex reasoning:** Use Biomni-R0 if hardware available +5. **For enterprise:** Consider Azure OpenAI or AWS Bedrock +6. **For speed:** Use Groq for rapid iteration + +7. **Always:** + - Set appropriate timeouts + - Implement error handling and retries + - Log model and configuration for reproducibility + - Test configuration before production use diff --git a/scientific-packages/biomni/references/task_examples.md b/scientific-packages/biomni/references/task_examples.md new file mode 100644 index 0000000..4294c7d --- /dev/null +++ b/scientific-packages/biomni/references/task_examples.md @@ -0,0 +1,1472 @@ +# Biomni Task Examples + +Comprehensive collection of biomedical task examples with code patterns and best practices. + +## Table of Contents + +1. [Single-Cell RNA-seq Analysis](#single-cell-rna-seq-analysis) +2. [CRISPR Screening](#crispr-screening) +3. [Genomic Analysis (GWAS, Variant Calling)](#genomic-analysis) +4. [Protein Structure and Function](#protein-structure-and-function) +5. [Drug Discovery and ADMET](#drug-discovery-and-admet) +6. [Pathway and Network Analysis](#pathway-and-network-analysis) +7. [Disease Classification](#disease-classification) +8. [Multi-Omics Integration](#multi-omics-integration) +9. [Proteomics Analysis](#proteomics-analysis) +10. [Biomarker Discovery](#biomarker-discovery) + +--- + +## Single-Cell RNA-seq Analysis + +### Basic scRNA-seq Pipeline + +```python +from biomni.agent import A1 + +agent = A1(path='./data', llm='claude-sonnet-4-20250514') + +agent.go(""" +Analyze the 10X Genomics scRNA-seq dataset located at 'data/pbmc_10k.h5ad'. + +Workflow: +1. Load the data and perform QC: + - Filter cells with <200 genes or >5000 genes + - Filter cells with >10% mitochondrial reads + - Filter genes expressed in <3 cells + +2. Normalize and identify highly variable genes: + - Use SCTransform or standard log-normalization + - Identify top 2000 HVGs + +3. Dimensionality reduction: + - PCA (50 components) + - UMAP for visualization + +4. 
Clustering: + - Find neighbors (k=10) + - Leiden clustering with resolution 0.5 + +5. Visualization: + - UMAP colored by cluster + - QC metrics on UMAP + +Save processed data as 'results/pbmc_processed.h5ad' +""") +``` + +### Cell Type Annotation + +```python +agent.go(""" +Using the processed PBMC data at 'results/pbmc_processed.h5ad': + +1. Find marker genes for each cluster: + - Wilcoxon rank-sum test + - Log fold change > 0.5 + - Adjusted p-value < 0.01 + - Present in >25% of cluster cells + +2. Annotate cell types using markers: + - T cells: CD3D, CD3E, CD3G + - B cells: CD19, MS4A1 (CD20) + - NK cells: GNLY, NKG7, NCAM1 + - Monocytes: CD14, LYZ, CD68 + - Dendritic cells: FCER1A, CD1C + +3. Create visualization: + - UMAP with cell type labels + - Dotplot of marker genes by cell type + - Proportion of cell types (bar plot) + +4. Save annotated data with cell types +""") +``` + +### Differential Expression Between Conditions + +```python +agent.go(""" +Compare gene expression between stimulated and control conditions: + +Data: 'data/immune_stim_experiment.h5ad' (contains 'condition' metadata) + +Analysis: +1. Subset to T cells only (cell_type == 'T cell') + +2. Differential expression between stim vs control: + - Use pseudobulk approach (aggregate by donor + condition) + - DESeq2 or edgeR for statistical testing + - Filter: |log2FC| > 1, padj < 0.05 + +3. Pathway enrichment on DEGs: + - Use GO biological processes + - Use KEGG pathways + - Run enrichment analysis with gprofiler or enrichr + +4. Visualization: + - Volcano plot of DEGs + - Heatmap of top 50 DEGs + - Bar plot of top enriched pathways + +5. Export results table with gene symbols, log2FC, p-values, and pathway annotations +""") +``` + +### Trajectory Analysis + +```python +agent.go(""" +Perform pseudotime trajectory analysis on hematopoietic differentiation data: + +Data: 'data/hematopoiesis.h5ad' + +Steps: +1. Subset to progenitor and mature cell types: + - HSC, MPP, GMP, Monocytes, Neutrophils + +2. Run trajectory inference: + - Use PAGA or Monocle3 + - Set HSC as root cell type + +3. Calculate pseudotime for all cells + +4. Identify trajectory-associated genes: + - Genes that change along pseudotime + - Statistical test with FDR < 0.05 + - Cluster genes by expression pattern (early, middle, late) + +5. Visualization: + - UMAP colored by pseudotime + - Heatmap of trajectory genes + - Gene expression along pseudotime for key TFs + +6. Functional analysis: + - GO enrichment for early/middle/late gene clusters +""") +``` + +### Integration of Multiple Datasets + +```python +agent.go(""" +Integrate three scRNA-seq datasets from different batches: + +Data files: +- 'data/batch1_pbmc.h5ad' +- 'data/batch2_pbmc.h5ad' +- 'data/batch3_pbmc.h5ad' + +Integration workflow: +1. Load all three datasets + +2. Perform individual QC on each batch: + - Same filters as standard QC + - Note batch-specific statistics + +3. Integration using Harmony or Scanorama: + - Concatenate datasets + - Identify HVGs on combined data + - Run batch correction + - Verify batch mixing with LISI score + +4. Re-cluster integrated data: + - Use corrected embeddings + - Leiden clustering + +5. Cell type annotation on integrated data + +6. Visualization: + - UMAP split by batch (before/after correction) + - UMAP colored by cell type + - Batch mixing statistics + +7. 
Save integrated dataset +""") +``` + +--- + +## CRISPR Screening + +### Guide RNA Design + +```python +agent.go(""" +Design guide RNAs for CRISPR knockout screening of cell cycle genes: + +Target genes: +- CDK1, CDK2, CDK4, CDK6 +- CCNA2, CCNB1, CCND1, CCNE1 +- TP53, RB1, MYC + +Requirements: +1. Design 4-6 guides per gene targeting early exons + +2. For each guide, evaluate: + - On-target efficiency score (Doench 2016) + - Off-target potential (CFD score < 0.3) + - Avoid common SNPs (1000 Genomes) + +3. Add control guides: + - 100 non-targeting controls + - 20 positive controls (essential genes) + +4. Output: + - Table with: gene, guide_sequence, PAM, position, on_target_score, off_target_count + - Sequences in format for oligonucleotide ordering + - Visual summary of guide distribution per gene + +Reference genome: hg38 +""") +``` + +### CRISPR Screen Analysis + +```python +agent.go(""" +Analyze data from a genome-wide CRISPR knockout screen: + +Data: 'data/crispr_screen_counts.csv' +- Columns: guide_id, gene, sample_T0, sample_T15, replicate +- ~80,000 guides targeting ~18,000 genes + +Analysis: +1. Quality control: + - Guide representation (reads per guide) + - Sample correlation + - Remove guides with <30 reads in T0 + +2. Normalize counts: + - Reads per million (RPM) + - Log2 fold change (T15 vs T0) + +3. Statistical analysis using MAGeCK: + - Identify significantly depleted/enriched genes + - FDR < 0.05 + - Rank genes by robust rank aggregation (RRA) + +4. Functional analysis: + - Pathway enrichment of hit genes + - Known vs novel essential genes + - Correlation with Cancer Dependency Map + +5. Visualization: + - Scatterplot: log2FC vs -log10(FDR) + - Heatmap: top 50 depleted genes across replicates + - Network: PPI network of hit genes + +6. Export: + - Ranked gene list with statistics + - Enriched pathways table +""") +``` + +### Pooled Optical Screening Analysis + +```python +agent.go(""" +Analyze pooled CRISPR screen with imaging readout (e.g., Cell Painting): + +Data structure: +- 'data/guide_assignments.csv': cell_id, guide_id, gene +- 'data/morphology_features.csv': cell_id, feature_1...feature_500 + +Analysis: +1. Feature preprocessing: + - Remove low-variance features + - Normalize features (z-score per plate) + - PCA for dimensionality reduction + +2. Associate phenotypes with perturbations: + - Aggregate cells by guide (mean/median) + - Calculate morphological distance from controls + - Statistical test for phenotype change + +3. Identify phenotype-altering genes: + - Mahalanobis distance from control distribution + - Bonferroni correction for multiple testing + - Effect size threshold + +4. Cluster genes by phenotype similarity: + - Hierarchical clustering of gene profiles + - Identify phenotype classes + +5. Validation and interpretation: + - Compare to known gene functions + - Pathway enrichment per phenotype cluster + +6. Visualization: + - UMAP of all perturbations + - Heatmap of gene clusters × morphology features + - Representative images for each cluster +""") +``` + +--- + +## Genomic Analysis + +### GWAS Analysis + +```python +agent.go(""" +Perform genome-wide association study for Type 2 Diabetes: + +Data: +- 'data/genotypes.bed' (PLINK format, 500K SNPs, 5000 cases, 5000 controls) +- 'data/phenotypes.txt' (sample_id, T2D_status, age, sex, BMI, ancestry_PCs) + +Workflow: +1. 
Quality control: + - SNP QC: MAF > 0.01, HWE p > 1e-6, genotyping rate > 0.95 + - Sample QC: genotyping rate > 0.95, heterozygosity check + - Remove related individuals (kinship > 0.125) + +2. Association testing: + - Logistic regression: T2D ~ SNP + age + sex + BMI + PC1-10 + - Genome-wide significance threshold: p < 5e-8 + - Suggestive threshold: p < 1e-5 + +3. Post-GWAS analysis: + - LD clumping (r² > 0.1, 500kb window) + - Annotate lead SNPs with nearby genes (±100kb) + - Query GWAS Catalog for known associations + +4. Functional annotation: + - Overlap with regulatory elements (ENCODE) + - eQTL colocalization (GTEx) + - GWAS prioritization scores (PoPS, ABC) + +5. Visualization: + - Manhattan plot + - QQ plot + - Regional association plots for top loci + - Locus zoom plots + +6. Heritability and genetic correlation: + - SNP heritability (LDSC) + - Genetic correlation with related traits + +Export summary statistics for meta-analysis +""") +``` + +### Whole Exome Sequencing Analysis + +```python +agent.go(""" +Analyze whole exome sequencing data for rare disease diagnosis: + +Data: Family trio (proband, mother, father) +- 'data/proband.bam' +- 'data/mother.bam' +- 'data/father.bam' + +Phenotype: Developmental delay, seizures, intellectual disability + +Pipeline: +1. Variant calling: + - GATK HaplotypeCaller on each sample + - Joint genotyping across trio + - VQSR filtering (SNPs and indels separately) + +2. Variant annotation: + - Functional consequence (VEP or ANNOVAR) + - Population frequencies (gnomAD) + - Pathogenicity predictions (CADD, REVEL, SpliceAI) + - Disease databases (ClinVar, OMIM) + +3. Inheritance analysis: + - De novo variants (absent in both parents) + - Compound heterozygous variants + - Rare homozygous variants (autozygosity) + - X-linked variants (if proband is male) + +4. Filtering strategy: + - Population AF < 0.001 (gnomAD) + - High-quality variants (GQ > 20, DP > 10) + - Loss-of-function or missense with CADD > 20 + - Match phenotype to gene function (HPO terms) + +5. Prioritization: + - Known disease genes for phenotype + - De novo in intolerant genes (pLI > 0.9) + - Protein-truncating variants + +6. Report: + - Top candidate variants with evidence + - Gene function and disease association + - Segregation analysis + - Recommended validation (Sanger sequencing) + - ACMG variant classification + +Save VCF with annotations and prioritized candidate list +""") +``` + +### Variant Calling from RNA-seq + +```python +agent.go(""" +Identify expressed variants from RNA-seq data: + +Data: Tumor RNA-seq BAM file +- 'data/tumor_RNAseq.bam' +- Reference: hg38 + +Purpose: Identify expressed somatic mutations for neoantigen prediction + +Steps: +1. Pre-processing: + - Mark duplicates (Picard) + - Split reads at junctions (GATK SplitNCigarReads) + - Base quality recalibration + +2. Variant calling: + - GATK HaplotypeCaller (RNA-seq mode) + - Filter: DP > 10, AF > 0.05 + +3. Filtering artifacts: + - Remove common SNPs (gnomAD AF > 0.001) + - Filter intronic/intergenic variants + - Remove known RNA editing sites (RADAR database) + - Panel of normals (if available) + +4. Annotation: + - Functional impact (VEP) + - Identify non-synonymous variants + - Predict MHC binding (NetMHCpan for patient HLA type) + +5. Prioritize neoantigens: + - Strong MHC binding (IC50 < 500nM) + - High expression (TPM > 5) + - High variant allele frequency + +6. 
Output: + - Annotated VCF + - Neoantigen candidates table + - Peptide sequences for validation + +This requires patient HLA typing data +""") +``` + +--- + +## Protein Structure and Function + +### Protein Structure Prediction and Analysis + +```python +agent.go(""" +Predict and analyze structure for novel protein sequence: + +Sequence (FASTA format): +>Novel_Kinase_Domain +MKLLVVDDDGVADYSKRDGAFMVAYCIEPGDG... + +Tasks: +1. Structure prediction: + - Use AlphaFold2 or ESMFold + - Generate 5 models, rank by confidence + +2. Quality assessment: + - pLDDT scores (per-residue confidence) + - pTM score (global confidence) + - Identify low-confidence regions + +3. Domain identification: + - InterProScan for domain architecture + - Pfam domain search + - Identify catalytic residues + +4. Functional site prediction: + - Active site prediction + - Substrate binding pocket identification + - Post-translational modification sites + +5. Structural alignment: + - Search for similar structures (PDB) + - Align to close homologs + - Identify conserved structural motifs + +6. Mutation analysis: + - Known disease mutations in homologs + - Predict impact on structure (Rosetta ddG) + +7. Visualization and output: + - PyMOL/Chimera visualization scripts + - Structural alignment figures + - Annotated PDB file with functional sites + - Summary report with predictions +""") +``` + +### Protein-Protein Interaction Prediction + +```python +agent.go(""" +Predict and validate protein-protein interactions: + +Target protein: BRCA1 +Species: Human + +Analysis: +1. Literature-based interactions: + - Query BioGRID, STRING, IntAct databases + - Extract high-confidence interactors (score > 0.7) + +2. Structure-based prediction: + - Predict BRCA1 structure (if not available) + - Dock with known interactors (BRCA2, BARD1, etc.) + - Score interfaces (PISA, PDBePISA) + +3. Sequence-based prediction: + - Coevolution analysis (EVcouplings) + - Domain-domain interaction prediction + - Linear motif search (ELM database) + +4. Functional analysis of interactors: + - GO enrichment analysis + - KEGG pathway membership + - Tissue/cell type expression patterns + +5. Network analysis: + - Build PPI network + - Identify network modules + - Central hub proteins + +6. Experimental validation suggestions: + - Prioritize interactions for validation + - Suggest Co-IP or Y2H experiments + - Identify commercially available antibodies + +7. Output: + - Ranked interaction list with evidence + - PPI network visualization + - Structural models of key interactions +""") +``` + +### Protein Engineering Design + +```python +agent.go(""" +Design improved enzyme variant with enhanced thermostability: + +Target enzyme: TEM-1 β-lactamase +Goal: Increase melting temperature by >10°C while maintaining activity + +Strategy: +1. Analyze current structure: + - Load PDB structure (1BTL) + - Identify flexible regions (B-factors) + - Find potential disulfide bond sites + +2. Computational design: + - Rosetta design simulations + - Identify stabilizing mutations (ΔΔG < -1.0 kcal/mol) + - Avoid active site and substrate binding regions + +3. Prioritize mutations: + - Surface entropy reduction (SER) + - Disulfide bond introduction + - Salt bridge formation + - Hydrophobic core packing + +4. Check conservation: + - Multiple sequence alignment of β-lactamases + - Avoid highly conserved positions + - Prefer positions with natural variation + +5. 
Design library: + - Rank top 20 single mutants + - Design 5 combinatorial variants (2-3 mutations) + - Ensure codon optimization for E. coli + +6. Validation plan: + - Expression and purification protocol + - Thermal shift assay (DSF) + - Activity assay (nitrocefin) + - Recommended high-throughput screening + +7. Output: + - Ranked mutation list with predicted ΔΔG + - Structural visualizations + - Codon-optimized sequences + - Cloning primers + - Experimental validation protocol +""") +``` + +--- + +## Drug Discovery and ADMET + +### Virtual Screening + +```python +agent.go(""" +Perform virtual screening for SARS-CoV-2 Mpro inhibitors: + +Target: SARS-CoV-2 Main protease (Mpro) +Crystal structure: PDB 6LU7 + +Compound library: ZINC15 drug-like subset (~100K compounds) +File: 'data/zinc_druglike_100k.smi' (SMILES format) + +Workflow: +1. Protein preparation: + - Remove crystallographic waters (keep catalytic waters) + - Add hydrogens, optimize H-bond network + - Define binding site (residues within 5Å of native ligand) + +2. Ligand preparation: + - Generate 3D coordinates from SMILES + - Enumerate tautomers and protonation states + - Energy minimization + +3. Molecular docking: + - Dock all compounds (AutoDock Vina or Glide) + - Generate top 3 poses per compound + - Score binding affinity + +4. Consensus scoring: + - Combine multiple scoring functions + - Rank compounds by consensus score + +5. ADMET filtering: + - Lipinski's rule of 5 + - BBB permeability (not needed for this target) + - hERG liability (pIC50 > 5) + - CYP450 inhibition prediction + - Toxicity prediction (Tox21) + +6. Visual inspection: + - Top 100 compounds + - Check key interactions (His41, Cys145 catalytic dyad) + - Remove PAINS and frequent hitters + +7. Final selection: + - Top 20 compounds for experimental testing + - Cluster by scaffold diversity + +8. Output: + - Ranked compound list with scores and ADMET properties + - Docking poses (mol2 or PDB format) + - 2D interaction diagrams + - Purchase availability from vendors +""") +``` + +### ADMET Property Prediction + +```python +agent.go(""" +Predict ADMET properties for drug candidate series: + +Input: 'data/compound_series.smi' (25 analogs, SMILES format) +Lead scaffold: Novel kinase inhibitor series + +Properties to predict: +1. Absorption: + - Caco-2 permeability + - Human intestinal absorption (HIA) + - P-glycoprotein substrate + +2. Distribution: + - Plasma protein binding (% bound) + - Volume of distribution (VDss) + - Blood-brain barrier permeability (LogBB) + +3. Metabolism: + - CYP450 substrate (1A2, 2C9, 2C19, 2D6, 3A4) + - CYP450 inhibition (same isoforms) + - Sites of metabolism (SOM prediction) + +4. Excretion: + - Clearance estimation + - Half-life prediction + - Renal excretion likelihood + +5. Toxicity: + - hERG inhibition (cardiotoxicity) + - AMES mutagenicity + - Hepatotoxicity + - Skin sensitization + - Rat acute toxicity (LD50) + +6. 
Drug-likeness: + - Lipinski's Ro5 + - QED score + - Synthetic accessibility + +Analysis: +- Compare all analogs in the series +- Structure-property relationships +- Identify best balanced compound +- Suggest modifications for improvement + +Output: +- Comprehensive ADMET table +- Radar plots for each compound +- SAR analysis for each property +- Recommendations for next design iteration +""") +``` + +### Lead Optimization + +```python +agent.go(""" +Optimize lead compound balancing potency and selectivity: + +Current lead: +- IC50 (target kinase): 50 nM +- IC50 (off-target kinases): 100-500 nM (poor selectivity) +- Microsomal stability: t1/2 = 20 min (too short) +- Solubility: 5 μM (low) + +Goal: Maintain potency, improve selectivity (>100x), improve PK properties + +Strategy: +1. Analyze current binding mode: + - Docking to target and off-targets + - Identify selectivity-determining residues + - Map interaction hotspots + +2. Design focused library: + - Modifications to improve selectivity: + * Target residues unique to on-target + * Avoid conserved kinase regions + - Modifications to improve solubility: + * Add polar groups to solvent-exposed regions + * Replace lipophilic groups + - Modifications to improve metabolic stability: + * Block metabolically labile positions + * Replace metabolically unstable groups + +3. Virtual enumeration: + - Generate ~200 analogs + - Predict binding affinity (docking) + - Predict ADMET properties + +4. Multi-parameter optimization: + - Calculate MPO score (potency + selectivity + ADMET) + - Pareto optimization + - Select top 20 compounds + +5. Clustering and diversity: + - Ensure structural diversity + - Test different modification strategies + +6. Synthetic feasibility: + - Retrosynthetic analysis + - Flag difficult syntheses + - Prioritize 10 compounds for synthesis + +7. Deliverables: + - Ranked compound designs + - Predicted properties table + - Binding mode visualizations + - Synthetic routes + - Recommended testing cascade +""") +``` + +--- + +## Pathway and Network Analysis + +### Pathway Enrichment Analysis + +```python +agent.go(""" +Perform comprehensive pathway enrichment on differentially expressed genes: + +Input: 'data/DEGs.csv' +Columns: gene_symbol, log2FC, padj +Significant DEGs: padj < 0.05, |log2FC| > 1 +Total: 450 upregulated, 380 downregulated genes + +Background: all detected genes in the experiment (~15,000) + +Analysis: +1. GO enrichment (biological processes): + - Test upregulated and downregulated genes separately + - Use hypergeometric test + - FDR correction (Benjamini-Hochberg) + - Filter: padj < 0.05, fold enrichment > 2 + +2. KEGG pathway enrichment: + - Same approach as GO + - Focus on signaling and metabolic pathways + +3. Reactome pathway enrichment: + - More detailed pathway hierarchy + +4. Disease association: + - DisGeNET disease enrichment + - Compare to disease gene signatures (MSigDB) + +5. Transcription factor enrichment: + - Predict upstream regulators (ChEA3) + - ENCODE ChIP-seq enrichment + +6. Drug/compound perturbations: + - L1000 connectivity map + - Identify drugs that reverse/mimic signature + +7. Cross-pathway analysis: + - Pathway crosstalk + - Hierarchical clustering of pathways by gene overlap + - Network visualization of enriched pathways + +8. Visualization: + - Dot plots (GO, KEGG, Reactome) + - Enrichment map network + - Chord diagram (genes-pathways) + - Treemap of hierarchical GO terms + +9. 
Export: + - All enrichment tables + - Pathway gene lists + - Interactive HTML report +""") +``` + +### Protein-Protein Interaction Network + +```python +agent.go(""" +Build and analyze PPI network for Alzheimer's disease genes: + +Seed genes: Known AD risk genes (APP, PSEN1, PSEN2, APOE, MAPT, etc.) +File: 'data/AD_seed_genes.txt' + +Network construction: +1. Build network from seed genes: + - Query STRING database (confidence > 0.7) + - Include direct and second-degree interactors + - Maximum network size: 500 proteins + +2. Network enrichment: + - Add disease associations (DisGeNET) + - Add tissue expression (GTEx - prioritize brain) + - Add functional annotations (GO, Reactome) + +3. Network analysis: + - Calculate centrality measures: + * Degree centrality + * Betweenness centrality + * Eigenvector centrality + - Identify hub proteins + - Community detection (Louvain algorithm) + +4. Module analysis: + - Functional enrichment per community + - Identify disease-relevant modules + - Key bridge proteins between modules + +5. Druggability analysis: + - Identify druggable targets (DGIdb) + - Known drugs targeting network proteins + - Clinical trial status + +6. Network perturbation: + - Simulate gene knockout + - Network robustness analysis + - Identify critical nodes + +7. Visualization: + - Interactive network (Cytoscape format) + - Layout by module membership + - Color by centrality/expression + - Size by degree + +8. Prioritization: + - Rank proteins by: + * Network centrality + * Brain expression + * Druggability + * Genetic evidence (GWAS) + - Top therapeutic targets + +Output: +- Network file (graphML, SIF) +- Module membership table +- Prioritized target list +- Druggable targets with existing compounds +""") +``` + +### Gene Regulatory Network Inference + +```python +agent.go(""" +Infer gene regulatory network from scRNA-seq data: + +Data: 'data/development_timecourse.h5ad' +- Cells from 5 developmental timepoints +- 3000 HVGs quantified + +Goal: Identify TF→target relationships during development + +Methods: +1. Preprocessing: + - Select TFs (from TF census list) + - Select potential target genes (HVGs) + - Normalize expression + +2. GRN inference using multiple methods: + - GENIE3 (random forest) + - SCENIC (motif-based) + - CellOracle (perturbation-based) + - Pearson/Spearman correlation (baseline) + +3. Integrate predictions: + - Combine scores from multiple methods + - Weight by motif evidence (JASPAR) + - Filter low-confidence edges + +4. Network refinement: + - Remove indirect edges (transitive reduction) + - Validate with ChIP-seq data (if available) + - Literature validation (TRRUST database) + +5. Dynamic network analysis: + - TF activity per timepoint/cell state + - Identify stage-specific regulators + - Find regulatory switches + +6. Downstream analysis: + - Master regulators (high out-degree) + - Regulatory cascades + - Feed-forward loops + - Coherent vs incoherent motifs + +7. Experimental validation priorities: + - Rank TF→target edges for validation + - Suggest ChIP-seq or CUT&RUN experiments + - Suggest perturbation experiments (knockout/CRISPRi) + +8. 
Visualization: + - Full GRN network (Cytoscape) + - Key TF subnetworks + - TF activity heatmap across development + - Sankey diagram of regulatory flow + +Output: +- Edge list with confidence scores +- TF activity matrix +- Validated vs novel interactions +- Prioritized validation experiments +""") +``` + +--- + +## Disease Classification + +### Cancer Type Classification from Gene Expression + +```python +agent.go(""" +Build multi-class classifier for cancer type prediction: + +Data: TCGA pan-cancer RNA-seq data +- Training: 8000 samples across 33 cancer types +- Expression: 'data/tcga_expression.csv' (samples × genes) +- Labels: 'data/tcga_labels.csv' (sample_id, cancer_type) + +Task: Classify tumor samples by cancer type + +Pipeline: +1. Data preprocessing: + - Log2(TPM + 1) transformation + - Remove low-variance genes (variance < 0.1) + - Z-score normalization + +2. Feature selection: + - Variance filtering (top 5000 genes) + - Univariate feature selection (ANOVA F-test) + - Select top 500 features + +3. Train-test split: + - 80% train, 20% test + - Stratified by cancer type + +4. Model training (compare multiple algorithms): + - Random Forest + - Gradient Boosting (XGBoost) + - Neural Network (MLP) + - Elastic Net logistic regression + +5. Model evaluation: + - Accuracy, precision, recall per class + - Confusion matrix + - ROC curves (one-vs-rest) + - Feature importance ranking + +6. Model interpretation: + - SHAP values for predictions + - Top predictive genes per cancer type + - Pathway enrichment of predictive features + +7. Clinical validation: + - Test on independent dataset (if available) + - Analyze misclassifications + - Identify hard-to-classify subtypes + +8. Deliverables: + - Trained model (pickle) + - Performance metrics report + - Feature importance table + - Confusion matrix heatmap + - Prediction script for new samples +""") +``` + +### Disease Risk Prediction from Multi-Omics + +```python +agent.go(""" +Develop integrative model predicting cardiovascular disease risk: + +Data sources: +1. Genotypes: 'data/genotypes.csv' (500K SNPs, polygenic risk scores) +2. Clinical: 'data/clinical.csv' (age, sex, BMI, blood pressure, cholesterol) +3. Proteomics: 'data/proteomics.csv' (200 plasma proteins, Olink panel) +4. Metabolomics: 'data/metabolomics.csv' (150 metabolites) + +Outcome: 10-year CVD incidence (binary) +- Cases: 800 +- Controls: 3200 + +Approach: +1. Data preprocessing: + - Impute missing values (missForest) + - Transform skewed features (log/Box-Cox) + - Normalize each omics layer separately + +2. Feature engineering: + - Calculate PRS from SNP data + - Interaction terms (age × metabolites, etc.) + - Metabolite ratios (known CVD markers) + +3. Feature selection per omics: + - Lasso for each data type + - Select informative features + +4. Integration strategies (compare): + - Early integration: concatenate all features + - Late integration: separate models, combine predictions + - Intermediate integration: Multi-omics factor analysis (MOFA) + +5. Model development: + - Logistic regression (interpretable baseline) + - Random Forest + - Elastic Net + - Neural network with omics-specific layers + +6. Cross-validation: + - 5-fold CV, stratified + - Hyperparameter tuning + - Calculate confidence intervals + +7. Model evaluation: + - AUC-ROC, AUC-PR + - Calibration plots + - Net reclassification improvement (NRI) + - Compare to clinical models (Framingham, SCORE) + +8. 
Interpretation: + - Feature importance (permutation importance) + - SHAP values for individuals + - Identify most informative omics layer + +9. Clinical utility: + - Decision curve analysis + - Risk stratification groups + - Biomarker panel selection + +Outputs: +- Model comparison table +- ROC curves all models +- Feature importance per omics +- Reclassification table +- Clinical implementation recommendations +""") +``` + +--- + +## Multi-Omics Integration + +### Multi-Omics Data Integration + +```python +agent.go(""" +Integrate transcriptomics, proteomics, and metabolomics data: + +Study: Drug response in cancer cell lines +Data: +- RNA-seq: 'data/transcriptomics.csv' (15000 genes × 50 cell lines) +- Proteomics: 'data/proteomics.csv' (3000 proteins × 50 cell lines) +- Metabolomics: 'data/metabolomics.csv' (200 metabolites × 50 cell lines) +- Drug response: 'data/drug_response.csv' (cell line, drug, IC50) + +Goal: Identify multi-omics signatures of drug sensitivity + +Analysis: +1. Data preprocessing: + - Match samples across omics layers + - Filter low-variance features per omics + - Normalize each omics separately (z-score) + +2. Integration methods (compare): + + **Method 1: MOFA (Multi-Omics Factor Analysis)** + - Identify latent factors capturing variance across omics + - Determine factor contributions per omics + - Relate factors to drug response + + **Method 2: DIABLO (sparse PLS-DA)** + - Supervised integration + - Maximize covariance between omics and drug response + - Select features from each omics layer + + **Method 3: Similarity Network Fusion (SNF)** + - Build patient similarity networks per omics + - Fuse networks + - Cluster cell lines by integrated similarity + +3. Association with drug response: + - Correlation of factors/components with IC50 + - Identify drug-sensitive vs resistant groups + - Multi-omics biomarkers + +4. Network analysis: + - Build multi-layer network: + * Gene regulatory network (RNA) + * Protein-protein interactions (proteins) + * Gene-metabolite associations + - Integrate layers + - Find dysregulated pathways + +5. Predictive modeling: + - Train model predicting drug response from multi-omics + - Compare: using all omics vs individual omics + - Feature selection across omics + +6. Biological interpretation: + - Map features to pathways + - Identify mechanism of drug action + - Suggest combination therapies + +7. Validation: + - Leave-one-out cross-validation + - Test in independent cell line panel + +Outputs: +- Factor loadings per omics (MOFA) +- Multi-omics biomarker signature +- Integrated network visualization +- Predictive model of drug response +- Mechanistic hypotheses +""") +``` + +--- + +## Proteomics Analysis + +### Label-Free Quantitative Proteomics + +```python +agent.go(""" +Analyze label-free proteomics data from mass spectrometry: + +Study: Comparison of normal vs diseased tissue (n=6 per group) +Data: MaxQuant output +- 'data/proteinGroups.txt' (MaxQuant protein quantification) +- 'data/peptides.txt' (peptide-level data) + +Experimental design: +- 6 normal samples +- 6 disease samples +- TMT-labeled, 3 fractions each + +Analysis: +1. Data loading and QC: + - Load proteinGroups.txt + - Remove contaminants, reverse hits + - Filter: valid values in ≥50% of samples per group + - Check sample correlations and outliers + - PCA for quality assessment + +2. Imputation: + - Impute missing values (MAR vs MNAR approach) + - Use MinProb for low-abundance missing values + - Use kNN for random missing values + +3. 
Normalization: + - Median normalization + - Or: VSN (variance stabilizing normalization) + +4. Differential expression: + - Two-sample t-test (for each protein) + - Moderated t-test (limma) + - Filter: |log2FC| > 0.58 (~1.5-fold), adj.p < 0.05 + +5. Visualization: + - Volcano plot + - Heatmap of significant proteins + - PCA colored by condition + - Intensity distributions (before/after normalization) + +6. Functional enrichment: + - GO enrichment (up and down separately) + - KEGG pathways + - Reactome pathways + - STRING PPI network of DEPs + +7. PTM analysis (if available): + - Phosphorylation site analysis + - Kinase enrichment analysis (KEA3) + +8. Orthogonal validation: + - Compare to RNA-seq data (if available) + - Protein-RNA correlation + - Identify discordant genes + +9. Biomarker candidates: + - Rank proteins by fold-change and significance + - Filter for secreted proteins (potential biomarkers) + - Check if targetable (druggable) + +Outputs: +- Differential abundance table +- QC report with plots +- Enrichment analysis results +- PPI network of DEPs +- Candidate biomarkers list +""") +``` + +--- + +## Biomarker Discovery + +### Diagnostic Biomarker Discovery + +```python +agent.go(""" +Discover diagnostic biomarkers for early cancer detection: + +Study: Plasma proteomics comparing early-stage cancer vs healthy controls +Data: +- 'data/proteomics.csv' (1000 proteins × 200 samples) +- 'data/metadata.csv' (sample_id, group [cancer/healthy], age, sex) + +Groups: +- Early-stage cancer: 100 samples +- Healthy controls: 100 samples + +Goal: Identify protein panel for early detection (target AUC > 0.90) + +Workflow: +1. Exploratory analysis: + - PCA, tSNE to visualize separation + - Univariate differential abundance + - Volcano plot + +2. Feature selection: + - Rank proteins by: + * Fold change + * Statistical significance (t-test, Mann-Whitney) + * AUC (each protein individually) + - Select proteins with AUC > 0.70 + +3. Biomarker panel construction: + - Correlation analysis (remove redundant markers) + - Forward selection: + * Start with best single marker + * Add markers improving panel performance + * Stop when no improvement + - Aim for 5-10 marker panel (practical for assay) + +4. Model building: + - Logistic regression on selected panel + - Calculate combined risk score + - Cross-validation (10-fold) + +5. Performance evaluation: + - AUC-ROC, AUC-PR + - Sensitivity/specificity at different thresholds + - Clinical decision threshold (e.g., 90% sensitivity) + - Calibration plot + +6. Biological validation: + - Literature support for cancer association + - Expression in tumor vs blood + - Mechanism of release/shedding + +7. Clinical utility: + - Compare to existing biomarkers (CEA, CA19-9, etc.) + - Cost-effectiveness consideration + - Assay feasibility (ELISA, MRM, etc.) + +8. Independent validation plan: + - Power calculation for validation cohort + - Suggested sample size + - Pre-analytical variables to control + +Outputs: +- Ranked protein list with individual performance +- Final biomarker panel +- Logistic regression model +- ROC curves (individual + panel) +- Clinical characteristics table +- Validation study protocol +""") +``` + +--- + +## Additional Advanced Examples + +### Spatial Transcriptomics Analysis + +```python +agent.go(""" +Analyze Visium spatial transcriptomics data: + +Data: 'data/visium_brain_tumor.h5ad' +- Contains spatial coordinates and gene expression +- Tissue: Brain tumor biopsy + +Analysis: +1. 
Data QC and normalization: + - Filter low-quality spots (total counts, detected genes) + - Normalize, log-transform + - Calculate spatial statistics + +2. Spatial clustering: + - Graph-based clustering considering spatial proximity + - Identify tumor regions, stroma, necrosis, etc. + +3. Spatially variable genes: + - Test for spatial patterns (Moran's I, SpatialDE) + - Identify genes with spatial gradients + +4. Deconvolution: + - Estimate cell type composition per spot + - Use scRNA-seq reference (if available) + - Methods: Cell2location, RCTD, SPOTlight + +5. Niche analysis: + - Define tissue niches by cell type composition + - Identify tumor-stroma interface + - Analyze cell-cell interactions + +6. Spatial pathway analysis: + - Map pathway activity onto tissue + - Identify spatially localized processes + +7. Visualization: + - Spatial plots colored by cluster, gene expression + - Cell type composition maps + - Pathway activity maps + +Output: +- Annotated spatial data object +- Spatially variable gene list +- Cell type composition per spot +- Niche definitions and cell-cell interactions +""") +``` + +--- + +## Tips for Effective Task Specification + +### 1. Be Specific About Data Formats and Locations + +✅ Good: +```python +agent.go("Analyze scRNA-seq data in AnnData format at 'data/experiment1.h5ad'") +``` + +❌ Vague: +```python +agent.go("Analyze my data") +``` + +### 2. Specify Analysis Parameters + +✅ Good: +```python +agent.go(""" +Cluster cells using Leiden algorithm with resolution 0.5, +k-neighbors=10, using PCA components 1-30 +""") +``` + +❌ Vague: +```python +agent.go("Cluster the cells") +``` + +### 3. Request Specific Outputs + +✅ Good: +```python +agent.go(""" +... and save results as: +- CSV table with statistics +- PNG figures at 300 DPI +- Processed data as AnnData at 'results/processed.h5ad' +""") +``` + +❌ Vague: +```python +agent.go("... and save the results") +``` + +### 4. Provide Biological Context + +✅ Good: +```python +agent.go(""" +This is a drug treatment experiment. Compare vehicle vs treated groups +to identify drug-induced transcriptional changes. Focus on apoptosis and +cell cycle pathways. +""") +``` + +❌ Vague: +```python +agent.go("Compare the two groups") +``` + +### 5. Break Complex Analyses into Steps + +✅ Good: +```python +# Step 1 +agent.go("Load and QC the data, save QC metrics") + +# Step 2 +agent.go("Based on QC, normalize and find HVGs") + +# Step 3 +agent.go("Cluster and annotate cell types") +``` + +❌ Overwhelming: +```python +agent.go("Do a complete scRNA-seq analysis pipeline") +``` diff --git a/scientific-packages/biomni/scripts/generate_report.py b/scientific-packages/biomni/scripts/generate_report.py new file mode 100644 index 0000000..df09085 --- /dev/null +++ b/scientific-packages/biomni/scripts/generate_report.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +""" +Enhanced PDF Report Generation for Biomni + +This script provides advanced PDF report generation with custom formatting, +styling, and metadata for Biomni analysis results. +""" + +import argparse +import sys +from pathlib import Path +from datetime import datetime +from typing import Optional, Dict, Any + + +def generate_markdown_report( + title: str, + sections: list, + metadata: Optional[Dict[str, Any]] = None, + output_path: str = "report.md" +) -> str: + """ + Generate a formatted markdown report. + + Args: + title: Report title + sections: List of dicts with 'heading' and 'content' keys + metadata: Optional metadata dict (author, date, etc.) 
+ output_path: Path to save markdown file + + Returns: + Path to generated markdown file + """ + md_content = [] + + # Title + md_content.append(f"# {title}\n") + + # Metadata + if metadata: + md_content.append("---\n") + for key, value in metadata.items(): + md_content.append(f"**{key}:** {value} \n") + md_content.append("---\n\n") + + # Sections + for section in sections: + heading = section.get('heading', 'Section') + content = section.get('content', '') + level = section.get('level', 2) # Default to h2 + + md_content.append(f"{'#' * level} {heading}\n\n") + md_content.append(f"{content}\n\n") + + # Write to file + output = Path(output_path) + output.write_text('\n'.join(md_content)) + + return str(output) + + +def convert_to_pdf_weasyprint( + markdown_path: str, + output_path: str, + css_style: Optional[str] = None +) -> bool: + """ + Convert markdown to PDF using WeasyPrint. + + Args: + markdown_path: Path to markdown file + output_path: Path for output PDF + css_style: Optional CSS stylesheet path + + Returns: + True if successful, False otherwise + """ + try: + import markdown + from weasyprint import HTML, CSS + + # Read markdown + with open(markdown_path, 'r') as f: + md_content = f.read() + + # Convert to HTML + html_content = markdown.markdown( + md_content, + extensions=['tables', 'fenced_code', 'codehilite'] + ) + + # Wrap in HTML template + html_template = f""" + + + + + Biomni Report + + + + {html_content} + + + """ + + # Generate PDF + pdf = HTML(string=html_template) + + # Add custom CSS if provided + stylesheets = [] + if css_style and Path(css_style).exists(): + stylesheets.append(CSS(filename=css_style)) + + pdf.write_pdf(output_path, stylesheets=stylesheets) + + return True + + except ImportError: + print("Error: WeasyPrint not installed. Install with: pip install weasyprint") + return False + except Exception as e: + print(f"Error generating PDF: {e}") + return False + + +def convert_to_pdf_pandoc(markdown_path: str, output_path: str) -> bool: + """ + Convert markdown to PDF using Pandoc. + + Args: + markdown_path: Path to markdown file + output_path: Path for output PDF + + Returns: + True if successful, False otherwise + """ + try: + import subprocess + + # Check if pandoc is installed + result = subprocess.run( + ['pandoc', '--version'], + capture_output=True, + text=True + ) + + if result.returncode != 0: + print("Error: Pandoc not installed") + return False + + # Convert with pandoc + result = subprocess.run( + [ + 'pandoc', + markdown_path, + '-o', output_path, + '--pdf-engine=pdflatex', + '-V', 'geometry:margin=1in', + '--toc' + ], + capture_output=True, + text=True + ) + + if result.returncode != 0: + print(f"Pandoc error: {result.stderr}") + return False + + return True + + except FileNotFoundError: + print("Error: Pandoc not found. Install from https://pandoc.org/") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def create_biomni_report( + conversation_history: list, + output_path: str = "biomni_report.pdf", + method: str = "weasyprint" +) -> bool: + """ + Create a formatted PDF report from Biomni conversation history. 
+ + Args: + conversation_history: List of conversation turns + output_path: Output PDF path + method: Conversion method ('weasyprint' or 'pandoc') + + Returns: + True if successful + """ + # Prepare report sections + metadata = { + 'Date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'Tool': 'Biomni AI Agent', + 'Report Type': 'Analysis Summary' + } + + sections = [] + + # Executive Summary + sections.append({ + 'heading': 'Executive Summary', + 'level': 2, + 'content': 'This report contains the complete analysis workflow executed by the Biomni biomedical AI agent.' + }) + + # Conversation history + for i, turn in enumerate(conversation_history, 1): + sections.append({ + 'heading': f'Task {i}: {turn.get("task", "Analysis")}', + 'level': 2, + 'content': f'**Input:**\n```\n{turn.get("input", "")}\n```\n\n**Output:**\n{turn.get("output", "")}' + }) + + # Generate markdown + md_path = output_path.replace('.pdf', '.md') + generate_markdown_report( + title="Biomni Analysis Report", + sections=sections, + metadata=metadata, + output_path=md_path + ) + + # Convert to PDF + if method == 'weasyprint': + success = convert_to_pdf_weasyprint(md_path, output_path) + elif method == 'pandoc': + success = convert_to_pdf_pandoc(md_path, output_path) + else: + print(f"Unknown method: {method}") + return False + + if success: + print(f"✓ Report generated: {output_path}") + print(f" Markdown: {md_path}") + else: + print("✗ Failed to generate PDF") + print(f" Markdown available: {md_path}") + + return success + + +def main(): + """CLI for report generation.""" + parser = argparse.ArgumentParser( + description='Generate formatted PDF reports for Biomni analyses' + ) + + parser.add_argument( + 'input', + type=str, + help='Input markdown file or conversation history' + ) + + parser.add_argument( + '-o', '--output', + type=str, + default='biomni_report.pdf', + help='Output PDF path (default: biomni_report.pdf)' + ) + + parser.add_argument( + '-m', '--method', + type=str, + choices=['weasyprint', 'pandoc'], + default='weasyprint', + help='Conversion method (default: weasyprint)' + ) + + parser.add_argument( + '--css', + type=str, + help='Custom CSS stylesheet path' + ) + + args = parser.parse_args() + + # Check if input is markdown or conversation history + input_path = Path(args.input) + + if not input_path.exists(): + print(f"Error: Input file not found: {args.input}") + return 1 + + # If input is markdown, convert directly + if input_path.suffix == '.md': + if args.method == 'weasyprint': + success = convert_to_pdf_weasyprint( + str(input_path), + args.output, + args.css + ) + else: + success = convert_to_pdf_pandoc(str(input_path), args.output) + + return 0 if success else 1 + + # Otherwise, assume it's conversation history (JSON) + try: + import json + with open(input_path) as f: + history = json.load(f) + + success = create_biomni_report( + history, + args.output, + args.method + ) + + return 0 if success else 1 + + except json.JSONDecodeError: + print("Error: Input file is not valid JSON or markdown") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scientific-packages/biomni/scripts/setup_environment.py b/scientific-packages/biomni/scripts/setup_environment.py new file mode 100644 index 0000000..cf3e1f2 --- /dev/null +++ b/scientific-packages/biomni/scripts/setup_environment.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +Biomni Environment Setup and Validation Script + +This script helps users set up and validate their Biomni environment, +including checking dependencies, API 
keys, and data availability. +""" + +import os +import sys +import subprocess +from pathlib import Path +from typing import Dict, List, Tuple + + +def check_python_version() -> Tuple[bool, str]: + """Check if Python version is compatible.""" + version = sys.version_info + if version.major == 3 and version.minor >= 8: + return True, f"Python {version.major}.{version.minor}.{version.micro} ✓" + else: + return False, f"Python {version.major}.{version.minor} - requires Python 3.8+" + + +def check_conda_env() -> Tuple[bool, str]: + """Check if running in biomni conda environment.""" + conda_env = os.environ.get('CONDA_DEFAULT_ENV', None) + if conda_env == 'biomni_e1': + return True, f"Conda environment: {conda_env} ✓" + else: + return False, f"Not in biomni_e1 environment (current: {conda_env})" + + +def check_package_installed(package: str) -> bool: + """Check if a Python package is installed.""" + try: + __import__(package) + return True + except ImportError: + return False + + +def check_dependencies() -> Tuple[bool, List[str]]: + """Check for required and optional dependencies.""" + required = ['biomni'] + optional = ['weasyprint', 'markdown2pdf'] + + missing_required = [pkg for pkg in required if not check_package_installed(pkg)] + missing_optional = [pkg for pkg in optional if not check_package_installed(pkg)] + + messages = [] + success = len(missing_required) == 0 + + if missing_required: + messages.append(f"Missing required packages: {', '.join(missing_required)}") + messages.append("Install with: pip install biomni --upgrade") + else: + messages.append("Required packages: ✓") + + if missing_optional: + messages.append(f"Missing optional packages: {', '.join(missing_optional)}") + messages.append("For PDF reports, install: pip install weasyprint") + + return success, messages + + +def check_api_keys() -> Tuple[bool, Dict[str, bool]]: + """Check which API keys are configured.""" + api_keys = { + 'ANTHROPIC_API_KEY': os.environ.get('ANTHROPIC_API_KEY'), + 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'), + 'GEMINI_API_KEY': os.environ.get('GEMINI_API_KEY'), + 'GROQ_API_KEY': os.environ.get('GROQ_API_KEY'), + } + + configured = {key: bool(value) for key, value in api_keys.items()} + has_any = any(configured.values()) + + return has_any, configured + + +def check_data_directory(data_path: str = './data') -> Tuple[bool, str]: + """Check if Biomni data directory exists and has content.""" + path = Path(data_path) + + if not path.exists(): + return False, f"Data directory not found at {data_path}" + + # Check if directory has files (data has been downloaded) + files = list(path.glob('*')) + if len(files) == 0: + return False, f"Data directory exists but is empty. Run agent once to download." 
+ + # Rough size check (should be ~11GB) + total_size = sum(f.stat().st_size for f in path.rglob('*') if f.is_file()) + size_gb = total_size / (1024**3) + + if size_gb < 1: + return False, f"Data directory exists but seems incomplete ({size_gb:.1f} GB)" + + return True, f"Data directory: {data_path} ({size_gb:.1f} GB) ✓" + + +def check_disk_space(required_gb: float = 20) -> Tuple[bool, str]: + """Check if sufficient disk space is available.""" + try: + import shutil + stat = shutil.disk_usage('.') + free_gb = stat.free / (1024**3) + + if free_gb >= required_gb: + return True, f"Disk space: {free_gb:.1f} GB available ✓" + else: + return False, f"Low disk space: {free_gb:.1f} GB (need {required_gb} GB)" + except Exception as e: + return False, f"Could not check disk space: {e}" + + +def test_biomni_import() -> Tuple[bool, str]: + """Test if Biomni can be imported and initialized.""" + try: + from biomni.agent import A1 + from biomni.config import default_config + return True, "Biomni import successful ✓" + except ImportError as e: + return False, f"Cannot import Biomni: {e}" + except Exception as e: + return False, f"Biomni import error: {e}" + + +def suggest_fixes(results: Dict[str, Tuple[bool, any]]) -> List[str]: + """Generate suggestions for fixing issues.""" + suggestions = [] + + if not results['python'][0]: + suggestions.append("➜ Upgrade Python to 3.8 or higher") + + if not results['conda'][0]: + suggestions.append("➜ Activate biomni environment: conda activate biomni_e1") + + if not results['dependencies'][0]: + suggestions.append("➜ Install Biomni: pip install biomni --upgrade") + + if not results['api_keys'][0]: + suggestions.append("➜ Set API key: export ANTHROPIC_API_KEY='your-key'") + suggestions.append(" Or create .env file with API keys") + + if not results['data'][0]: + suggestions.append("➜ Data will auto-download on first agent.go() call") + + if not results['disk_space'][0]: + suggestions.append("➜ Free up disk space (need ~20GB total)") + + return suggestions + + +def main(): + """Run all environment checks and display results.""" + print("=" * 60) + print("Biomni Environment Validation") + print("=" * 60) + print() + + # Run all checks + results = {} + + print("Checking Python version...") + results['python'] = check_python_version() + print(f" {results['python'][1]}") + print() + + print("Checking conda environment...") + results['conda'] = check_conda_env() + print(f" {results['conda'][1]}") + print() + + print("Checking dependencies...") + results['dependencies'] = check_dependencies() + for msg in results['dependencies'][1]: + print(f" {msg}") + print() + + print("Checking API keys...") + results['api_keys'] = check_api_keys() + has_keys, key_status = results['api_keys'] + for key, configured in key_status.items(): + status = "✓" if configured else "✗" + print(f" {key}: {status}") + print() + + print("Checking Biomni data directory...") + results['data'] = check_data_directory() + print(f" {results['data'][1]}") + print() + + print("Checking disk space...") + results['disk_space'] = check_disk_space() + print(f" {results['disk_space'][1]}") + print() + + print("Testing Biomni import...") + results['biomni_import'] = test_biomni_import() + print(f" {results['biomni_import'][1]}") + print() + + # Summary + print("=" * 60) + all_passed = all(result[0] for result in results.values()) + + if all_passed: + print("✓ All checks passed! 
Environment is ready.") + print() + print("Quick start:") + print(" from biomni.agent import A1") + print(" agent = A1(path='./data', llm='claude-sonnet-4-20250514')") + print(" agent.go('Your biomedical task')") + else: + print("⚠ Some checks failed. See suggestions below:") + print() + suggestions = suggest_fixes(results) + for suggestion in suggestions: + print(suggestion) + + print("=" * 60) + + return 0 if all_passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scientific-packages/biopython/SKILL.md b/scientific-packages/biopython/SKILL.md new file mode 100644 index 0000000..b0c7415 --- /dev/null +++ b/scientific-packages/biopython/SKILL.md @@ -0,0 +1,450 @@ +--- +name: biopython +description: Comprehensive toolkit for computational molecular biology using BioPython. Use this skill when working with biological sequences (DNA, RNA, protein), parsing sequence files (FASTA, GenBank, FASTQ), accessing NCBI databases (Entrez, BLAST), performing sequence alignments, building phylogenetic trees, analyzing protein structures (PDB), or any bioinformatics task requiring BioPython modules. +--- + +# BioPython + +## Overview + +BioPython is a comprehensive Python library for computational molecular biology and bioinformatics. This skill provides guidance on using BioPython's extensive modules for sequence manipulation, file I/O, database access, sequence similarity searches, alignments, phylogenetics, structural biology, and population genetics. + +## When to Use This Skill + +Use this skill when: +- Working with biological sequences (DNA, RNA, protein) +- Reading or writing sequence files (FASTA, GenBank, FASTQ, etc.) +- Accessing NCBI databases (GenBank, PubMed, Protein, Nucleotide) +- Running or parsing BLAST searches +- Performing sequence alignments (pairwise or multiple) +- Building or analyzing phylogenetic trees +- Analyzing protein structures (PDB files) +- Calculating sequence properties (GC content, melting temp, molecular weight) +- Converting between sequence file formats +- Performing population genetics analysis +- Any bioinformatics task requiring BioPython + +## Core Capabilities + +### 1. Sequence Manipulation + +Create and manipulate biological sequences using `Bio.Seq`: + +```python +from Bio.Seq import Seq + +dna_seq = Seq("ATGGTGCATCTGACT") +rna_seq = dna_seq.transcribe() # DNA → RNA +protein = dna_seq.translate() # DNA → Protein +rev_comp = dna_seq.reverse_complement() # Reverse complement +``` + +**Common operations:** +- Transcription and back-transcription +- Translation with custom genetic codes +- Complement and reverse complement +- Sequence slicing and concatenation +- Pattern searching and counting + +**Reference:** See `references/core_modules.md` (section: Bio.Seq) for detailed operations and examples. + +### 2. File Input/Output + +Read and write sequence files in multiple formats using `Bio.SeqIO`: + +```python +from Bio import SeqIO + +# Read sequences +for record in SeqIO.parse("sequences.fasta", "fasta"): + print(record.id, len(record.seq)) + +# Write sequences +SeqIO.write(records, "output.gb", "genbank") + +# Convert formats +SeqIO.convert("input.fasta", "fasta", "output.gb", "genbank") +``` + +**Supported formats:** FASTA, FASTQ, GenBank, EMBL, Swiss-Prot, PDB, Clustal, PHYLIP, NEXUS, Stockholm, and many more. 
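For large files, two patterns come up constantly: streaming filters and random access without loading everything into memory. A minimal sketch, assuming a local `sequences.fasta` (filenames and the record ID are placeholders):

```python
from Bio import SeqIO

# Stream-filter records by length; the generator keeps memory use flat
long_records = (rec for rec in SeqIO.parse("sequences.fasta", "fasta") if len(rec.seq) >= 200)
n_written = SeqIO.write(long_records, "long_sequences.fasta", "fasta")
print(f"Kept {n_written} records of at least 200 bp")

# Random access by ID without reading the whole file
idx = SeqIO.index("sequences.fasta", "fasta")
record = idx["seq_001"]  # placeholder record ID
print(record.id, len(record.seq))
idx.close()
```

Both patterns scale to multi-gigabyte files because only the current record is held in memory at any time.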
+ +**Common workflows:** +- Format conversion (FASTA ↔ GenBank ↔ FASTQ) +- Filtering sequences by length, ID, or content +- Batch processing large files with iterators +- Random access with `SeqIO.index()` for large files + +**Script:** Use `scripts/file_io.py` for file I/O examples and patterns. + +**Reference:** See `references/core_modules.md` (section: Bio.SeqIO) for comprehensive format details and workflows. + +### 3. NCBI Database Access + +Access NCBI databases (GenBank, PubMed, Protein, etc.) using `Bio.Entrez`: + +```python +from Bio import Entrez + +Entrez.email = "your.email@example.com" # Required! + +# Search database +handle = Entrez.esearch(db="nucleotide", term="human kinase", retmax=100) +record = Entrez.read(handle) +id_list = record["IdList"] + +# Fetch sequences +handle = Entrez.efetch(db="nucleotide", id=id_list, rettype="fasta", retmode="text") +records = SeqIO.parse(handle, "fasta") +``` + +**Key Entrez functions:** +- `esearch()`: Search databases, retrieve IDs +- `efetch()`: Download full records +- `esummary()`: Get document summaries +- `elink()`: Find related records across databases +- `einfo()`: Get database information +- `epost()`: Upload ID lists for large queries + +**Important:** Always set `Entrez.email` before using Entrez functions. + +**Script:** Use `scripts/ncbi_entrez.py` for complete Entrez workflows including batch downloads and WebEnv usage. + +**Reference:** See `references/database_tools.md` (section: Bio.Entrez) for detailed function documentation and parameters. + +### 4. BLAST Searches + +Run BLAST searches and parse results using `Bio.Blast`: + +```python +from Bio.Blast import NCBIWWW, NCBIXML + +# Run BLAST online +result_handle = NCBIWWW.qblast("blastn", "nt", sequence) + +# Save results +with open("blast_results.xml", "w") as out: + out.write(result_handle.read()) + +# Parse results +with open("blast_results.xml") as result_handle: + blast_record = NCBIXML.read(result_handle) + + for alignment in blast_record.alignments: + for hsp in alignment.hsps: + if hsp.expect < 0.001: + print(f"Hit: {alignment.title}") + print(f"E-value: {hsp.expect}") + print(f"Identity: {hsp.identities}/{hsp.align_length}") +``` + +**BLAST programs:** blastn, blastp, blastx, tblastn, tblastx + +**Key result attributes:** +- `alignment.title`: Hit description +- `hsp.expect`: E-value +- `hsp.identities`: Number of identical residues +- `hsp.query`, `hsp.match`, `hsp.sbjct`: Aligned sequences + +**Script:** Use `scripts/blast_search.py` for complete BLAST workflows including result filtering and extraction. + +**Reference:** See `references/database_tools.md` (section: Bio.Blast) for detailed parsing and filtering strategies. + +### 5. 
Sequence Alignment + +Perform pairwise and multiple sequence alignments using `Bio.Align`: + +**Pairwise alignment:** +```python +from Bio import Align + +aligner = Align.PairwiseAligner() +aligner.mode = 'global' # or 'local' +aligner.match_score = 2 +aligner.mismatch_score = -1 +aligner.gap_score = -2 + +alignments = aligner.align(seq1, seq2) +print(alignments[0]) +print(f"Score: {alignments.score}") +``` + +**Multiple sequence alignment I/O:** +```python +from Bio import AlignIO + +# Read alignment +alignment = AlignIO.read("alignment.clustal", "clustal") + +# Write alignment +AlignIO.write(alignment, "output.phylip", "phylip") + +# Convert formats +AlignIO.convert("input.clustal", "clustal", "output.fasta", "fasta") +``` + +**Supported formats:** Clustal, PHYLIP, Stockholm, NEXUS, FASTA, MAF + +**Script:** Use `scripts/alignment_phylogeny.py` for alignment examples and workflows. + +**Reference:** See `references/core_modules.md` (sections: Bio.Align, Bio.AlignIO) for detailed alignment capabilities. + +### 6. Phylogenetic Analysis + +Build and analyze phylogenetic trees using `Bio.Phylo`: + +```python +from Bio import Phylo +from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor + +# Read alignment +alignment = AlignIO.read("sequences.fasta", "fasta") + +# Calculate distance matrix +calculator = DistanceCalculator('identity') +dm = calculator.get_distance(alignment) + +# Build tree (UPGMA or Neighbor-Joining) +constructor = DistanceTreeConstructor(calculator) +tree = constructor.upgma(dm) # or constructor.nj(dm) + +# Visualize tree +Phylo.draw_ascii(tree) +Phylo.draw(tree) # matplotlib visualization + +# Save tree +Phylo.write(tree, "tree.nwk", "newick") +``` + +**Tree manipulation:** +- `tree.ladderize()`: Sort branches +- `tree.root_at_midpoint()`: Root at midpoint +- `tree.prune()`: Remove taxa +- `tree.collapse_all()`: Collapse short branches +- `tree.distance()`: Calculate distances between clades + +**Supported formats:** Newick, NEXUS, PhyloXML, NeXML + +**Script:** Use `scripts/alignment_phylogeny.py` for tree construction and manipulation examples. + +**Reference:** See `references/specialized_modules.md` (section: Bio.Phylo) for comprehensive tree analysis capabilities. + +### 7. Structural Bioinformatics + +Analyze protein structures using `Bio.PDB`: + +```python +from Bio.PDB import PDBParser, PDBList + +# Download structure +pdbl = PDBList() +pdbl.retrieve_pdb_file("1ABC", file_format="pdb", pdir=".") + +# Parse structure +parser = PDBParser() +structure = parser.get_structure("protein", "1abc.pdb") + +# Navigate hierarchy: Structure → Model → Chain → Residue → Atom +for model in structure: + for chain in model: + for residue in chain: + for atom in residue: + print(atom.name, atom.coord) + +# Secondary structure with DSSP +from Bio.PDB import DSSP +dssp = DSSP(model, "structure.pdb") + +# Structural alignment +from Bio.PDB import Superimposer +sup = Superimposer() +sup.set_atoms(ref_atoms, alt_atoms) +print(f"RMSD: {sup.rms}") +``` + +**Key capabilities:** +- Parse PDB, mmCIF, MMTF formats +- Secondary structure analysis (DSSP) +- Solvent accessibility calculations +- Structural superimposition +- Distance and angle calculations +- Structure quality validation + +**Reference:** See `references/specialized_modules.md` (section: Bio.PDB) for complete structural analysis capabilities. + +### 8. 
Sequence Analysis Utilities + +Calculate sequence properties using `Bio.SeqUtils`: + +```python +from Bio.SeqUtils import gc_fraction, MeltingTemp as mt +from Bio.SeqUtils.ProtParam import ProteinAnalysis + +# DNA analysis +gc = gc_fraction(dna_seq) * 100 +tm = mt.Tm_NN(dna_seq) # Melting temperature + +# Protein analysis +protein_analysis = ProteinAnalysis(str(protein_seq)) +mw = protein_analysis.molecular_weight() +pi = protein_analysis.isoelectric_point() +aromaticity = protein_analysis.aromaticity() +instability = protein_analysis.instability_index() +``` + +**Available analyses:** +- GC content and GC skew +- Melting temperature (multiple methods) +- Molecular weight +- Isoelectric point +- Aromaticity +- Instability index +- Secondary structure prediction +- Sequence checksums + +**Script:** Use `scripts/sequence_operations.py` for sequence analysis examples. + +**Reference:** See `references/core_modules.md` (section: Bio.SeqUtils) for all available utilities. + +### 9. Specialized Modules + +**Restriction enzymes:** +```python +from Bio import Restriction +enzyme = Restriction.EcoRI +sites = enzyme.search(seq) +``` + +**Motif analysis:** +```python +from Bio import motifs +m = motifs.create([seq1, seq2, seq3]) +pwm = m.counts.normalize(pseudocounts=0.5) +``` + +**Population genetics:** +Use `Bio.PopGen` for allele frequencies, Hardy-Weinberg equilibrium, FST calculations. + +**Clustering:** +Use `Bio.Cluster` for hierarchical clustering, k-means, PCA on biological data. + +**Reference:** See `references/core_modules.md` and `references/specialized_modules.md` for specialized module documentation. + +## Common Workflows + +### Workflow 1: Download and Analyze NCBI Sequences + +1. Search NCBI database with `Entrez.esearch()` +2. Fetch sequences with `Entrez.efetch()` +3. Parse with `SeqIO.parse()` +4. Analyze sequences (GC content, translation, etc.) +5. Save results to file + +**Script:** Use `scripts/ncbi_entrez.py` for complete implementation. + +### Workflow 2: Sequence Similarity Search + +1. Run BLAST with `NCBIWWW.qblast()` or parse existing results +2. Parse XML results with `NCBIXML.read()` +3. Filter hits by E-value, identity, coverage +4. Extract and save significant hits +5. Perform downstream analysis + +**Script:** Use `scripts/blast_search.py` for complete implementation. + +### Workflow 3: Phylogenetic Tree Construction + +1. Read multiple sequence alignment with `AlignIO.read()` +2. Calculate distance matrix with `DistanceCalculator` +3. Build tree with `DistanceTreeConstructor` (UPGMA or NJ) +4. Manipulate tree (ladderize, root, prune) +5. Visualize with `Phylo.draw()` or `Phylo.draw_ascii()` +6. Save tree with `Phylo.write()` + +**Script:** Use `scripts/alignment_phylogeny.py` for complete implementation. + +### Workflow 4: Format Conversion Pipeline + +1. Read sequences in original format with `SeqIO.parse()` +2. Filter or modify sequences as needed +3. Write to new format with `SeqIO.write()` +4. Or use `SeqIO.convert()` for direct conversion + +**Script:** Use `scripts/file_io.py` for format conversion examples. 
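+
+As a minimal sketch of Workflow 4, the example below reads a GenBank file, keeps only records above a length cutoff, and writes the survivors as FASTA (the file names and the 200 bp cutoff are placeholders):
+
+```python
+from Bio import SeqIO
+
+# Read, filter, and write in a single streaming pass
+long_records = (
+    rec for rec in SeqIO.parse("input.gb", "genbank")
+    if len(rec.seq) >= 200  # keep records of at least 200 bp
+)
+count = SeqIO.write(long_records, "filtered.fasta", "fasta")
+print(f"Wrote {count} sequences")
+```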
+
+## Best Practices
+
+### Email Configuration
+Always set `Entrez.email` before using NCBI services:
+```python
+Entrez.email = "your.email@example.com"
+```
+
+### Rate Limiting
+Be polite to NCBI servers:
+- Use `time.sleep()` between requests
+- Use WebEnv for large queries
+- Batch downloads in reasonable chunks (100-500 sequences)
+
+### Memory Management
+For large files:
+- Use iterators (`SeqIO.parse()`) instead of lists
+- Use `SeqIO.index()` for random access without loading entire file
+- Process in batches when possible
+
+### Error Handling
+Always handle potential errors:
+```python
+try:
+    record = SeqIO.read(handle, format)
+except Exception as e:
+    print(f"Error: {e}")
+```
+
+### File Format Selection
+Choose appropriate formats:
+- FASTA: Simple sequences, no annotations
+- GenBank: Rich annotations, features, references
+- FASTQ: Sequences with quality scores
+- PDB: 3D structural data
+
+## Resources
+
+### scripts/
+Executable Python scripts demonstrating common BioPython workflows:
+
+- `sequence_operations.py`: Basic sequence manipulation (transcription, translation, complement, GC content, melting temp)
+- `file_io.py`: Reading, writing, and converting sequence files; filtering; indexing large files
+- `ncbi_entrez.py`: Searching and downloading from NCBI databases; batch processing with WebEnv
+- `blast_search.py`: Running BLAST searches online; parsing and filtering results
+- `alignment_phylogeny.py`: Pairwise and multiple sequence alignment; phylogenetic tree construction and manipulation
+
+Run any script with `python3 scripts/<script_name>.py` to see examples.
+
+### references/
+Comprehensive reference documentation for BioPython modules:
+
+- `core_modules.md`: Core sequence handling (Seq, SeqRecord, SeqIO, AlignIO, Align, SeqUtils, CodonTable, motifs, Restriction)
+- `database_tools.md`: Database access and searches (Entrez, BLAST, SearchIO, BioSQL)
+- `specialized_modules.md`: Advanced analyses (PDB, Phylo, PAML, PopGen, Cluster, Graphics)
+
+Reference these files when:
+- Learning about specific module capabilities
+- Looking up function parameters and options
+- Understanding supported file formats
+- Finding example code patterns
+
+Use `grep` to search references for specific topics:
+```bash
+grep -n "secondary structure" references/specialized_modules.md
+grep -n "efetch" references/database_tools.md
+```
+
+## Additional Resources
+
+**Official Documentation:** https://biopython.org/docs/latest/
+
+**Tutorial:** https://biopython.org/docs/latest/Tutorial/index.html
+
+**API Reference:** https://biopython.org/docs/latest/api/index.html
+
+**Cookbook:** https://biopython.org/wiki/Category:Cookbook
diff --git a/scientific-packages/biopython/references/core_modules.md b/scientific-packages/biopython/references/core_modules.md
new file mode 100644
index 0000000..6c33aa4
--- /dev/null
+++ b/scientific-packages/biopython/references/core_modules.md
@@ -0,0 +1,232 @@
+# BioPython Core Modules Reference
+
+This document provides detailed information about BioPython's core modules and their capabilities.
+
+## Sequence Handling
+
+### Bio.Seq - Sequence Objects
+
+Seq objects are BioPython's fundamental data structure for biological sequences, providing biological methods on top of string-like behavior. 
+ +**Creation:** +```python +from Bio.Seq import Seq +my_seq = Seq("AGTACACTGGT") +``` + +**Key Operations:** +- String methods: `find()`, `count()`, `count_overlap()` (for overlapping patterns) +- Complement/Reverse complement: Returns complementary sequences +- Transcription: DNA → RNA (T → U) +- Back transcription: RNA → DNA +- Translation: DNA/RNA → protein with customizable genetic codes and stop codon handling + +**Use Cases:** +- DNA/RNA sequence manipulation +- Converting between nucleic acid types +- Protein translation from coding sequences +- Sequence searching and pattern counting + +### Bio.SeqRecord - Sequence Metadata + +SeqRecord wraps Seq objects with metadata like ID, description, and features. + +**Attributes:** +- `seq`: The sequence itself (Seq object) +- `id`: Unique identifier +- `name`: Short name +- `description`: Longer description +- `features`: List of SeqFeature objects +- `annotations`: Dictionary of additional information +- `letter_annotations`: Per-letter annotations (e.g., quality scores) + +### Bio.SeqFeature - Sequence Annotations + +Manages sequence annotations and features such as genes, promoters, and coding regions. + +**Common Features:** +- Gene locations +- CDS (coding sequences) +- Promoters and regulatory elements +- Exons and introns +- Protein domains + +## File Input/Output + +### Bio.SeqIO - Sequence File I/O + +Unified interface for reading and writing sequence files in multiple formats. + +**Supported Formats:** +- FASTA/FASTQ: Standard sequence formats +- GenBank/EMBL: Feature-rich annotation formats +- Clustal/Stockholm/PHYLIP: Alignment formats +- ABI/SFF: Trace and flowgram data +- Swiss-Prot/PIR: Protein databases +- PDB: Protein structure files + +**Key Functions:** + +**SeqIO.parse()** - Iterator for reading multiple records: +```python +from Bio import SeqIO +for record in SeqIO.parse("file.fasta", "fasta"): + print(record.id, len(record.seq)) +``` + +**SeqIO.read()** - Read single record: +```python +record = SeqIO.read("file.fasta", "fasta") +``` + +**SeqIO.write()** - Write sequences: +```python +SeqIO.write(sequences, "output.fasta", "fasta") +``` + +**SeqIO.convert()** - Direct format conversion: +```python +count = SeqIO.convert("input.gb", "genbank", "output.fasta", "fasta") +``` + +**SeqIO.index()** - Memory-efficient random access for large files: +```python +record_dict = SeqIO.index("large_file.fasta", "fasta") +sequence = record_dict["seq_id"] +``` + +**SeqIO.to_dict()** - Load all records into dictionary (memory-based): +```python +record_dict = SeqIO.to_dict(SeqIO.parse("file.fasta", "fasta")) +``` + +**Common Patterns:** +- Format conversion between FASTA, GenBank, FASTQ +- Filtering sequences by length, ID, or content +- Extracting subsequences +- Batch processing large files with iterators + +### Bio.AlignIO - Multiple Sequence Alignment I/O + +Handles multiple sequence alignment files. 
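+
+A minimal read-and-convert sketch (the file names are placeholders; the individual functions are summarized below):
+
+```python
+from Bio import AlignIO
+
+# Read a single Clustal alignment and report its dimensions
+alignment = AlignIO.read("example.aln", "clustal")
+print(len(alignment), "sequences,", alignment.get_alignment_length(), "columns")
+
+# Convert the same file to PHYLIP in one call
+AlignIO.convert("example.aln", "clustal", "example.phy", "phylip")
+```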
+ +**Key Functions:** +- `write()`: Save alignments +- `parse()`: Read multiple alignments +- `read()`: Read single alignment +- `convert()`: Convert between formats + +**Supported Formats:** +- Clustal +- PHYLIP (sequential and interleaved) +- Stockholm +- NEXUS +- FASTA (aligned) +- MAF (Multiple Alignment Format) + +## Sequence Alignment + +### Bio.Align - Alignment Tools + +**PairwiseAligner** - High-performance pairwise alignment: +```python +from Bio import Align +aligner = Align.PairwiseAligner() +aligner.mode = 'global' # or 'local' +aligner.match_score = 2 +aligner.mismatch_score = -1 +aligner.gap_score = -2.5 +alignments = aligner.align(seq1, seq2) +``` + +**CodonAligner** - Codon-aware alignment + +**MultipleSeqAlignment** - Container for MSA with column access + +### Bio.pairwise2 (Legacy) + +Legacy pairwise alignment module with functions like `align.globalxx()`, `align.localxx()`. + +## Sequence Analysis Utilities + +### Bio.SeqUtils - Sequence Analysis + +Collection of utility functions: + +**CheckSum** - Calculate sequence checksums (CRC32, CRC64, GCG) + +**MeltingTemp** - DNA melting temperature calculations: +- Nearest-neighbor method +- Wallace rule +- GC content method + +**IsoelectricPoint** - Protein pI calculation + +**ProtParam** - Protein analysis: +- Molecular weight +- Aromaticity +- Instability index +- Secondary structure fractions + +**GC/GC_skew** - Calculate GC content and GC skew for sequence windows + +### Bio.Data.CodonTable - Genetic Codes + +Access to NCBI genetic code tables: +```python +from Bio.Data import CodonTable +standard_table = CodonTable.unambiguous_dna_by_id[1] +print(standard_table.forward_table) # codon to amino acid +print(standard_table.back_table) # amino acid to codons +print(standard_table.start_codons) +print(standard_table.stop_codons) +``` + +**Available codes:** +- Standard code (1) +- Vertebrate mitochondrial (2) +- Yeast mitochondrial (3) +- And many more organism-specific codes + +## Sequence Motifs and Patterns + +### Bio.motifs - Sequence Motif Analysis + +Tools for working with sequence motifs: + +**Position Weight Matrices (PWM):** +- Create PWM from aligned sequences +- Calculate information content +- Search sequences for motif matches +- Generate consensus sequences + +**Position Specific Scoring Matrices (PSSM):** +- Convert PWM to PSSM +- Score sequences against motifs +- Determine significance thresholds + +**Supported Formats:** +- JASPAR +- TRANSFAC +- MEME +- AlignAce + +### Bio.Restriction - Restriction Enzymes + +Comprehensive restriction enzyme database and analysis: + +**Capabilities:** +- Search for restriction sites +- Predict digestion products +- Analyze restriction maps +- Access enzyme properties (recognition site, cut positions, isoschizomers) + +**Example usage:** +```python +from Bio import Restriction +from Bio.Seq import Seq + +seq = Seq("GAATTC...") +enzyme = Restriction.EcoRI +results = enzyme.search(seq) +``` diff --git a/scientific-packages/biopython/references/database_tools.md b/scientific-packages/biopython/references/database_tools.md new file mode 100644 index 0000000..f9e77a4 --- /dev/null +++ b/scientific-packages/biopython/references/database_tools.md @@ -0,0 +1,306 @@ +# BioPython Database Access and Search Tools + +This document covers BioPython's capabilities for accessing biological databases and performing sequence searches. 
+ +## NCBI Database Access + +### Bio.Entrez - NCBI E-utilities Interface + +Provides programmatic access to NCBI databases including PubMed, GenBank, Protein, Nucleotide, and more. + +**Important:** Always set your email before using Entrez: +```python +from Bio import Entrez +Entrez.email = "your.email@example.com" +``` + +#### Core Query Functions + +**esearch** - Search databases and retrieve IDs: +```python +handle = Entrez.esearch(db="nucleotide", term="Homo sapiens[Organism] AND COX1") +record = Entrez.read(handle) +id_list = record["IdList"] +``` + +Parameters: +- `db`: Database to search (nucleotide, protein, pubmed, etc.) +- `term`: Search query +- `retmax`: Maximum number of IDs to return +- `sort`: Sort order (relevance, pub_date, etc.) +- `usehistory`: Store results on server (useful for large queries) + +**efetch** - Retrieve full records: +```python +handle = Entrez.efetch(db="nucleotide", id="123456", rettype="gb", retmode="text") +record = SeqIO.read(handle, "genbank") +``` + +Parameters: +- `db`: Database name +- `id`: Single ID or comma-separated list +- `rettype`: Return type (gb, fasta, gp, xml, etc.) +- `retmode`: Return mode (text, xml, asn.1) +- Automatically uses POST for >200 IDs + +**elink** - Find related records across databases: +```python +handle = Entrez.elink(dbfrom="protein", db="gene", id="15718680") +result = Entrez.read(handle) +``` + +Parameters: +- `dbfrom`: Source database +- `db`: Target database +- `id`: ID(s) to link from +- Returns LinkOut providers and relevancy scores + +**esummary** - Get document summaries: +```python +handle = Entrez.esummary(db="protein", id="15718680") +summary = Entrez.read(handle) +print(summary[0]['Title']) +``` + +Returns quick overviews without full records. + +**einfo** - Get database statistics: +```python +handle = Entrez.einfo(db="nucleotide") +info = Entrez.read(handle) +``` + +Provides field indices, term counts, update dates, and available links. + +**epost** - Upload ID lists to server: +```python +handle = Entrez.epost("nucleotide", id="123456,789012") +result = Entrez.read(handle) +webenv = result["WebEnv"] +query_key = result["QueryKey"] +``` + +Useful for large queries split across multiple requests. 
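+
+A short follow-up sketch showing how the stored history can then be consumed by `efetch` (same placeholder IDs as above; a fuller batched version appears under Common Workflows below):
+
+```python
+from Bio import Entrez
+
+Entrez.email = "your.email@example.com"
+
+# Post the IDs, then fetch them back through the history server
+post = Entrez.read(Entrez.epost("nucleotide", id="123456,789012"))
+handle = Entrez.efetch(
+    db="nucleotide",
+    rettype="fasta",
+    retmode="text",
+    webenv=post["WebEnv"],
+    query_key=post["QueryKey"],
+)
+print(handle.read()[:200])
+handle.close()
+```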
+ +**espell** - Get spelling suggestions: +```python +handle = Entrez.espell(term="brest cancer") +result = Entrez.read(handle) +print(result["CorrectedQuery"]) # "breast cancer" +``` + +**ecitmatch** - Convert citations to PubMed IDs: +```python +citation = "proc natl acad sci u s a|1991|88|3248|mann bj|" +handle = Entrez.ecitmatch(db="pubmed", bdata=citation) +``` + +#### Data Processing Functions + +**Entrez.read()** - Parse XML to Python dictionary: +```python +handle = Entrez.esearch(db="protein", term="insulin") +record = Entrez.read(handle) +``` + +**Entrez.parse()** - Generator for large XML results: +```python +handle = Entrez.efetch(db="protein", id=id_list, rettype="gp", retmode="xml") +for record in Entrez.parse(handle): + process(record) +``` + +#### Common Workflows + +**Download sequences by accession:** +```python +handle = Entrez.efetch(db="nucleotide", id="NM_001301717", rettype="fasta", retmode="text") +record = SeqIO.read(handle, "fasta") +``` + +**Search and download multiple sequences:** +```python +# Search +search_handle = Entrez.esearch(db="nucleotide", term="human kinase", retmax="100") +search_results = Entrez.read(search_handle) + +# Download +fetch_handle = Entrez.efetch(db="nucleotide", id=search_results["IdList"], rettype="gb", retmode="text") +for record in SeqIO.parse(fetch_handle, "genbank"): + print(record.id) +``` + +**Use WebEnv for large queries:** +```python +# Post IDs +post_handle = Entrez.epost(db="nucleotide", id=",".join(large_id_list)) +post_result = Entrez.read(post_handle) + +# Fetch in batches +batch_size = 500 +for start in range(0, count, batch_size): + fetch_handle = Entrez.efetch( + db="nucleotide", + rettype="fasta", + retmode="text", + retstart=start, + retmax=batch_size, + webenv=post_result["WebEnv"], + query_key=post_result["QueryKey"] + ) + # Process batch +``` + +### Bio.GenBank - GenBank Format Parsing + +Low-level GenBank file parser (SeqIO is usually preferred). + +### Bio.SwissProt - Swiss-Prot/UniProt Parsing + +Parse Swiss-Prot and UniProtKB flat file format: +```python +from Bio import SwissProt +with open("uniprot.dat") as handle: + for record in SwissProt.parse(handle): + print(record.entry_name, record.organism) +``` + +## Sequence Similarity Searches + +### Bio.Blast - BLAST Interface + +Tools for running BLAST searches and parsing results. + +#### Running BLAST + +**NCBI QBLAST (online):** +```python +from Bio.Blast import NCBIWWW +result_handle = NCBIWWW.qblast("blastn", "nt", sequence) +``` + +Parameters: +- Program: blastn, blastp, blastx, tblastn, tblastx +- Database: nt, nr, refseq_rna, pdb, etc. +- Sequence: string or Seq object +- Additional parameters: `expect`, `word_size`, `hitlist_size`, `format_type` + +**Local BLAST:** +Run standalone BLAST from command line, then parse results. 
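+
+A minimal sketch of that pattern, assuming the NCBI BLAST+ binaries are installed and a database has already been built with `makeblastdb` (the query path and database name are placeholders); the XML output can then be parsed as described in the next section:
+
+```python
+import subprocess
+
+from Bio.Blast import NCBIXML
+
+# Run command-line blastn, writing XML (-outfmt 5) for easy parsing
+subprocess.run(
+    [
+        "blastn",
+        "-query", "query.fasta",
+        "-db", "local_nt",  # placeholder database name
+        "-out", "local_results.xml",
+        "-outfmt", "5",
+        "-evalue", "1e-5",
+    ],
+    check=True,
+)
+
+with open("local_results.xml") as handle:
+    blast_record = NCBIXML.read(handle)
+print(f"{len(blast_record.alignments)} hits")
+```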
+ +#### Parsing BLAST Results + +**XML format (recommended):** +```python +from Bio.Blast import NCBIXML + +result_handle = open("blast_results.xml") +blast_records = NCBIXML.parse(result_handle) + +for blast_record in blast_records: + for alignment in blast_record.alignments: + for hsp in alignment.hsps: + if hsp.expect < 0.001: + print(f"Hit: {alignment.title}") + print(f"Length: {alignment.length}") + print(f"E-value: {hsp.expect}") + print(f"Identities: {hsp.identities}/{hsp.align_length}") +``` + +**Functions:** +- `NCBIXML.read()`: Single query +- `NCBIXML.parse()`: Multiple queries (generator) + +**Key Record Attributes:** +- `alignments`: List of matching sequences +- `query`: Query sequence ID +- `query_length`: Length of query + +**Alignment Attributes:** +- `title`: Description of hit +- `length`: Length of hit sequence +- `hsps`: High-scoring segment pairs + +**HSP Attributes:** +- `expect`: E-value +- `score`: Bit score +- `identities`: Number of identical residues +- `positives`: Number of positive scoring matches +- `gaps`: Number of gaps +- `align_length`: Length of alignment +- `query`: Aligned query sequence +- `match`: Match indicators +- `sbjct`: Aligned subject sequence +- `query_start`, `query_end`: Query coordinates +- `sbjct_start`, `sbjct_end`: Subject coordinates + +#### Common BLAST Workflows + +**Find homologs:** +```python +result = NCBIWWW.qblast("blastp", "nr", protein_sequence, expect=1e-10) +with open("results.xml", "w") as out: + out.write(result.read()) +``` + +**Filter results by criteria:** +```python +for alignment in blast_record.alignments: + for hsp in alignment.hsps: + if hsp.expect < 1e-5 and hsp.identities/hsp.align_length > 0.5: + # Process high-quality hits + pass +``` + +### Bio.SearchIO - Unified Search Results Parser + +Modern interface for parsing various search tool outputs (BLAST, HMMER, BLAT, etc.). + +**Key Functions:** +- `read()`: Parse single query +- `parse()`: Parse multiple queries (generator) +- `write()`: Write results to file +- `convert()`: Convert between formats + +**Supported Tools:** +- BLAST (XML, tabular, plain text) +- HMMER (hmmscan, hmmsearch, phmmer) +- BLAT +- FASTA +- InterProScan +- Exonerate + +**Example:** +```python +from Bio import SearchIO +results = SearchIO.parse("blast_output.xml", "blast-xml") +for result in results: + for hit in result: + if hit.hsps[0].evalue < 0.001: + print(hit.id, hit.hsps[0].evalue) +``` + +## Local Database Management + +### BioSQL - SQL Database Interface + +Store and manage biological sequences in SQL databases (PostgreSQL, MySQL, SQLite). 
+ +**Features:** +- Store SeqRecord objects with annotations +- Efficient querying and retrieval +- Cross-reference sequences +- Track relationships between sequences + +**Example:** +```python +from BioSQL import BioSeqDatabase +server = BioSeqDatabase.open_database(driver="MySQLdb", user="user", passwd="pass", host="localhost", db="bioseqdb") +db = server["my_db"] + +# Store sequences +db.load(SeqIO.parse("sequences.gb", "genbank")) + +# Query +seq = db.lookup(accession="NC_005816") +``` diff --git a/scientific-packages/biopython/references/specialized_modules.md b/scientific-packages/biopython/references/specialized_modules.md new file mode 100644 index 0000000..f6c6daa --- /dev/null +++ b/scientific-packages/biopython/references/specialized_modules.md @@ -0,0 +1,612 @@ +# BioPython Specialized Analysis Modules + +This document covers BioPython's specialized modules for structural biology, phylogenetics, population genetics, and other advanced analyses. + +## Structural Bioinformatics + +### Bio.PDB - Protein Structure Analysis + +Comprehensive tools for handling macromolecular crystal structures. + +#### Structure Hierarchy + +PDB structures are organized hierarchically: +- **Structure** → Models → Chains → Residues → Atoms + +```python +from Bio.PDB import PDBParser + +parser = PDBParser() +structure = parser.get_structure("protein", "1abc.pdb") + +# Navigate hierarchy +for model in structure: + for chain in model: + for residue in chain: + for atom in residue: + print(atom.coord) # xyz coordinates +``` + +#### Parsing Structure Files + +**PDB format:** +```python +from Bio.PDB import PDBParser +parser = PDBParser(QUIET=True) +structure = parser.get_structure("myprotein", "structure.pdb") +``` + +**mmCIF format:** +```python +from Bio.PDB import MMCIFParser +parser = MMCIFParser(QUIET=True) +structure = parser.get_structure("myprotein", "structure.cif") +``` + +**Fast mmCIF parser:** +```python +from Bio.PDB import FastMMCIFParser +parser = FastMMCIFParser(QUIET=True) +structure = parser.get_structure("myprotein", "structure.cif") +``` + +**MMTF format:** +```python +from Bio.PDB import MMTFParser +parser = MMTFParser() +structure = parser.get_structure("structure.mmtf") +``` + +**Binary CIF:** +```python +from Bio.PDB.binary_cif import BinaryCIFParser +parser = BinaryCIFParser() +structure = parser.get_structure("structure.bcif") +``` + +#### Downloading Structures + +```python +from Bio.PDB import PDBList +pdbl = PDBList() + +# Download specific structure +pdbl.retrieve_pdb_file("1ABC", file_format="pdb", pdir="structures/") + +# Download entire PDB (obsolete entries) +pdbl.download_obsolete_entries(pdir="obsolete/") + +# Update local PDB mirror +pdbl.update_pdb() +``` + +#### Structure Selection and Filtering + +```python +# Select specific chains +chain_A = structure[0]['A'] + +# Select specific residues +residue_10 = chain_A[10] + +# Select specific atoms +ca_atom = residue_10['CA'] + +# Iterate over specific atom types +for atom in structure.get_atoms(): + if atom.name == 'CA': # Alpha carbons only + print(atom.coord) +``` + +**Structure selectors:** +```python +from Bio.PDB.Polypeptide import is_aa + +# Filter by residue type +for residue in structure.get_residues(): + if is_aa(residue): + print(f"Amino acid: {residue.resname}") +``` + +#### Secondary Structure Analysis + +**DSSP integration:** +```python +from Bio.PDB import DSSP + +# Requires DSSP program installed +model = structure[0] +dssp = DSSP(model, "structure.pdb") + +# Access secondary structure +for key in dssp: + 
secondary_structure = dssp[key][2] + accessibility = dssp[key][3] + print(f"Residue {key}: {secondary_structure}, accessible: {accessibility}") +``` + +DSSP codes: +- H: Alpha helix +- B: Beta bridge +- E: Extended strand (beta sheet) +- G: 3-10 helix +- I: Pi helix +- T: Turn +- S: Bend +- -: Coil + +#### Solvent Accessibility + +**Shrake-Rupley algorithm:** +```python +from Bio.PDB import ShrakeRupley + +sr = ShrakeRupley() +sr.compute(structure, level="R") # R=residue, A=atom, C=chain, M=model, S=structure + +for residue in structure.get_residues(): + print(f"{residue.resname} {residue.id[1]}: {residue.sasa} Ų") +``` + +**NACCESS wrapper:** +```python +from Bio.PDB import NACCESS + +# Requires NACCESS program +naccess = NACCESS("structure.pdb") +for residue_id, data in naccess.items(): + print(f"Residue {residue_id}: {data['all_atoms_abs']} Ų") +``` + +**Half-sphere exposure:** +```python +from Bio.PDB import HSExposure + +# Requires DSSP +model = structure[0] +hse = HSExposure() +hse.calc_hs_exposure(model, "structure.pdb") + +for chain in model: + for residue in chain: + if residue.has_id('EXP_HSE_A_U'): + hse_up = residue.xtra['EXP_HSE_A_U'] + hse_down = residue.xtra['EXP_HSE_A_D'] +``` + +#### Structural Alignment and Superimposition + +**Standard superimposition:** +```python +from Bio.PDB import Superimposer + +sup = Superimposer() +sup.set_atoms(ref_atoms, alt_atoms) # Lists of atoms to align +sup.apply(structure2.get_atoms()) # Apply transformation + +print(f"RMSD: {sup.rms}") +print(f"Rotation matrix: {sup.rotran[0]}") +print(f"Translation vector: {sup.rotran[1]}") +``` + +**QCP (Quaternion Characteristic Polynomial) method:** +```python +from Bio.PDB import QCPSuperimposer + +qcp = QCPSuperimposer() +qcp.set(ref_coords, alt_coords) +qcp.run() +print(f"RMSD: {qcp.get_rms()}") +``` + +#### Geometric Calculations + +**Distances and angles:** +```python +# Distance between atoms +from Bio.PDB import Vector +dist = atom1 - atom2 # Returns distance + +# Angle between three atoms +from Bio.PDB import calc_angle +angle = calc_angle(atom1.coord, atom2.coord, atom3.coord) + +# Dihedral angle +from Bio.PDB import calc_dihedral +dihedral = calc_dihedral(atom1.coord, atom2.coord, atom3.coord, atom4.coord) +``` + +**Vector operations:** +```python +from Bio.PDB.Vector import Vector + +v1 = Vector(atom1.coord) +v2 = Vector(atom2.coord) + +# Vector operations +v3 = v1 + v2 +v4 = v1 - v2 +dot_product = v1 * v2 +cross_product = v1 ** v2 +magnitude = v1.norm() +normalized = v1.normalized() +``` + +#### Internal Coordinates + +Advanced residue geometry representation: +```python +from Bio.PDB import internal_coords + +# Enable internal coordinates +structure.atom_to_internal_coordinates() + +# Access phi, psi angles +for residue in structure.get_residues(): + if residue.internal_coord: + print(f"Phi: {residue.internal_coord.get_angle('phi')}") + print(f"Psi: {residue.internal_coord.get_angle('psi')}") +``` + +#### Writing Structures + +```python +from Bio.PDB import PDBIO + +io = PDBIO() +io.set_structure(structure) +io.save("output.pdb") + +# Save specific selection +io.save("chain_A.pdb", select=ChainSelector("A")) +``` + +### Bio.SCOP - SCOP Database + +Access to Structural Classification of Proteins database. 
+ +### Bio.KEGG - Pathway Analysis + +Interface to KEGG (Kyoto Encyclopedia of Genes and Genomes) databases: + +**Capabilities:** +- Access pathway maps +- Retrieve enzyme data +- Get compound information +- Query orthology relationships + +## Phylogenetics + +### Bio.Phylo - Phylogenetic Tree Analysis + +Comprehensive phylogenetic tree manipulation and analysis. + +#### Reading and Writing Trees + +**Supported formats:** +- Newick: Simple, widely-used format +- NEXUS: Rich metadata format +- PhyloXML: XML-based with extensive annotations +- NeXML: Modern XML standard + +```python +from Bio import Phylo + +# Read tree +tree = Phylo.read("tree.nwk", "newick") + +# Read multiple trees +trees = list(Phylo.parse("trees.nex", "nexus")) + +# Write tree +Phylo.write(tree, "output.nwk", "newick") +``` + +#### Tree Visualization + +**ASCII visualization:** +```python +Phylo.draw_ascii(tree) +``` + +**Matplotlib plotting:** +```python +import matplotlib.pyplot as plt +Phylo.draw(tree) +plt.show() + +# With customization +fig, ax = plt.subplots(figsize=(10, 8)) +Phylo.draw(tree, axes=ax, do_show=False) +ax.set_title("My Phylogenetic Tree") +plt.show() +``` + +#### Tree Navigation and Manipulation + +**Find clades:** +```python +# Get all terminal nodes (leaves) +terminals = tree.get_terminals() + +# Get all nonterminal nodes +nonterminals = tree.get_nonterminals() + +# Find specific clade +target = tree.find_any(name="Species_A") + +# Find all matching clades +matches = tree.find_clades(terminal=True) +``` + +**Tree properties:** +```python +# Count terminals +num_species = tree.count_terminals() + +# Get total branch length +total_length = tree.total_branch_length() + +# Check if tree is bifurcating +is_bifurcating = tree.is_bifurcating() + +# Get maximum distance from root +max_dist = tree.distance(tree.root) +``` + +**Tree modification:** +```python +# Prune tree to specific taxa +keep_taxa = ["Species_A", "Species_B", "Species_C"] +tree.prune(keep_taxa) + +# Collapse short branches +tree.collapse_all(lambda c: c.branch_length < 0.01) + +# Ladderize (sort branches) +tree.ladderize() + +# Root tree at midpoint +tree.root_at_midpoint() + +# Root at specific clade +outgroup = tree.find_any(name="Outgroup_species") +tree.root_with_outgroup(outgroup) +``` + +**Calculate distances:** +```python +# Distance between two clades +dist = tree.distance(clade1, clade2) + +# Distance from root +root_dist = tree.distance(tree.root, terminal_clade) +``` + +#### Tree Construction + +**Distance-based methods:** +```python +from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator +from Bio import AlignIO + +# Load alignment +aln = AlignIO.read("alignment.fasta", "fasta") + +# Calculate distance matrix +calculator = DistanceCalculator('identity') +dm = calculator.get_distance(aln) + +# Construct tree using UPGMA +constructor = DistanceTreeConstructor() +tree_upgma = constructor.upgma(dm) + +# Or using Neighbor-Joining +tree_nj = constructor.nj(dm) +``` + +**Parsimony method:** +```python +from Bio.Phylo.TreeConstruction import ParsimonyScorer, NNITreeSearcher + +scorer = ParsimonyScorer() +searcher = NNITreeSearcher(scorer) +tree = searcher.search(starting_tree, alignment) +``` + +**Distance calculators:** +- 'identity': Simple identity scoring +- 'blastn': BLAST nucleotide scoring +- 'blastp': BLAST protein scoring +- 'dnafull': EMBOSS DNA scoring matrix +- 'blosum62': BLOSUM62 protein matrix +- 'pam250': PAM250 protein matrix + +#### Consensus Trees + +```python +from Bio.Phylo.Consensus 
import majority_consensus, strict_consensus + +# Strict consensus +consensus_strict = strict_consensus(trees) + +# Majority rule consensus +consensus_majority = majority_consensus(trees, cutoff=0.5) + +# Bootstrap consensus +from Bio.Phylo.Consensus import bootstrap_consensus +bootstrap_tree = bootstrap_consensus(trees, cutoff=0.7) +``` + +#### External Tool Wrappers + +**PhyML:** +```python +from Bio.Phylo.Applications import PhymlCommandline + +cmd = PhymlCommandline(input="alignment.phy", datatype="nt", model="HKY85", alpha="e", bootstrap=100) +stdout, stderr = cmd() +tree = Phylo.read("alignment.phy_phyml_tree.txt", "newick") +``` + +**RAxML:** +```python +from Bio.Phylo.Applications import RaxmlCommandline + +cmd = RaxmlCommandline( + sequences="alignment.phy", + model="GTRGAMMA", + name="mytree", + parsimony_seed=12345 +) +stdout, stderr = cmd() +``` + +**FastTree:** +```python +from Bio.Phylo.Applications import FastTreeCommandline + +cmd = FastTreeCommandline(input="alignment.fasta", out="tree.nwk", gtr=True, gamma=True) +stdout, stderr = cmd() +``` + +### Bio.Phylo.PAML - Evolutionary Analysis + +Interface to PAML (Phylogenetic Analysis by Maximum Likelihood): + +**CODEML - Codon-based analysis:** +```python +from Bio.Phylo.PAML import codeml + +cml = codeml.Codeml() +cml.alignment = "alignment.phy" +cml.tree = "tree.nwk" +cml.out_file = "results.out" +cml.working_dir = "./paml_wd" + +# Set parameters +cml.set_options( + seqtype=1, # Codon sequences + model=0, # One omega ratio + NSsites=[0, 1, 2], # Test different models + CodonFreq=2, # F3x4 codon frequencies +) + +results = cml.run() +``` + +**BaseML - Nucleotide-based analysis:** +```python +from Bio.Phylo.PAML import baseml + +bml = baseml.Baseml() +bml.alignment = "alignment.phy" +bml.tree = "tree.nwk" +results = bml.run() +``` + +**YN00 - Yang-Nielsen method:** +```python +from Bio.Phylo.PAML import yn00 + +yn = yn00.Yn00() +yn.alignment = "alignment.phy" +results = yn.run() +``` + +## Population Genetics + +### Bio.PopGen - Population Genetics Analysis + +Tools for population-level genetic analysis. + +**Capabilities:** +- Allele frequency calculations +- Hardy-Weinberg equilibrium testing +- Linkage disequilibrium analysis +- F-statistics (FST, FIS, FIT) +- Tajima's D +- Population structure analysis + +## Clustering and Machine Learning + +### Bio.Cluster - Clustering Algorithms + +Statistical clustering for gene expression and other biological data: + +**Hierarchical clustering:** +```python +from Bio.Cluster import treecluster + +tree = treecluster(data, method='a', dist='e') +# method: 'a'=average, 's'=single, 'm'=maximum, 'c'=centroid +# dist: 'e'=Euclidean, 'c'=correlation, 'a'=absolute correlation +``` + +**k-means clustering:** +```python +from Bio.Cluster import kcluster + +clusterid, error, nfound = kcluster(data, nclusters=5, npass=100) +``` + +**Self-Organizing Maps (SOM):** +```python +from Bio.Cluster import somcluster + +clusterid, celldata = somcluster(data, nx=3, ny=3) +``` + +**Principal Component Analysis:** +```python +from Bio.Cluster import pca + +columnmean, coordinates, components, eigenvalues = pca(data) +``` + +## Visualization + +### Bio.Graphics - Genomic Visualization + +Tools for creating publication-quality biological graphics. 
+ +**GenomeDiagram - Circular and linear genome maps:** +```python +from Bio.Graphics import GenomeDiagram +from Bio import SeqIO + +record = SeqIO.read("genome.gb", "genbank") + +gd_diagram = GenomeDiagram.Diagram("Genome Map") +gd_track = gd_diagram.new_track(1, greytrack=True) +gd_feature_set = gd_track.new_set() + +# Add features +for feature in record.features: + if feature.type == "gene": + gd_feature_set.add_feature(feature, color="blue", label=True) + +gd_diagram.draw(format="linear", pagesize='A4', fragments=1) +gd_diagram.write("genome_map.pdf", "PDF") +``` + +**Chromosomes - Chromosome visualization:** +```python +from Bio.Graphics.BasicChromosome import Chromosome + +chr = Chromosome("Chromosome 1") +chr.add("gene1", 1000, 2000, color="red") +chr.add("gene2", 3000, 4500, color="blue") +``` + +## Phenotype Analysis + +### Bio.phenotype - Phenotypic Microarray Analysis + +Tools for analyzing phenotypic microarray data (e.g., Biolog plates): + +**Capabilities:** +- Parse PM plate data +- Growth curve analysis +- Compare phenotypic profiles +- Calculate similarity metrics diff --git a/scientific-packages/biopython/scripts/alignment_phylogeny.py b/scientific-packages/biopython/scripts/alignment_phylogeny.py new file mode 100644 index 0000000..b46790f --- /dev/null +++ b/scientific-packages/biopython/scripts/alignment_phylogeny.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +""" +Sequence alignment and phylogenetic analysis using BioPython. + +This script demonstrates: +- Pairwise sequence alignment +- Multiple sequence alignment I/O +- Distance matrix calculation +- Phylogenetic tree construction +- Tree manipulation and visualization +""" + +from Bio import Align, AlignIO, Phylo +from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor +from Bio.Phylo.TreeConstruction import ParsimonyScorer, NNITreeSearcher +from Bio.Seq import Seq +import matplotlib.pyplot as plt + + +def pairwise_alignment_example(): + """Demonstrate pairwise sequence alignment.""" + + print("Pairwise Sequence Alignment") + print("=" * 60) + + # Create aligner + aligner = Align.PairwiseAligner() + + # Set parameters + aligner.mode = "global" # or 'local' for local alignment + aligner.match_score = 2 + aligner.mismatch_score = -1 + aligner.open_gap_score = -2 + aligner.extend_gap_score = -0.5 + + # Sequences to align + seq1 = "ACGTACGTACGT" + seq2 = "ACGTTACGTGT" + + print(f"Sequence 1: {seq1}") + print(f"Sequence 2: {seq2}") + print() + + # Perform alignment + alignments = aligner.align(seq1, seq2) + + # Show results + print(f"Number of optimal alignments: {len(alignments)}") + print(f"Best alignment score: {alignments.score:.1f}") + print() + + # Display best alignment + print("Best alignment:") + print(alignments[0]) + print() + + +def local_alignment_example(): + """Demonstrate local alignment (Smith-Waterman).""" + + print("Local Sequence Alignment") + print("=" * 60) + + aligner = Align.PairwiseAligner() + aligner.mode = "local" + aligner.match_score = 2 + aligner.mismatch_score = -1 + aligner.open_gap_score = -2 + aligner.extend_gap_score = -0.5 + + seq1 = "AAAAACGTACGTACGTAAAAA" + seq2 = "TTTTTTACGTACGTTTTTTT" + + print(f"Sequence 1: {seq1}") + print(f"Sequence 2: {seq2}") + print() + + alignments = aligner.align(seq1, seq2) + + print(f"Best local alignment score: {alignments.score:.1f}") + print() + print("Best local alignment:") + print(alignments[0]) + print() + + +def read_and_analyze_alignment(alignment_file, format="fasta"): + """Read and analyze a multiple sequence 
alignment.""" + + print(f"Reading alignment from: {alignment_file}") + print("-" * 60) + + # Read alignment + alignment = AlignIO.read(alignment_file, format) + + print(f"Number of sequences: {len(alignment)}") + print(f"Alignment length: {alignment.get_alignment_length()}") + print() + + # Display alignment + print("Alignment preview:") + for record in alignment[:5]: # Show first 5 sequences + print(f"{record.id[:15]:15s} {record.seq[:50]}...") + + print() + + # Calculate some statistics + analyze_alignment_statistics(alignment) + + return alignment + + +def analyze_alignment_statistics(alignment): + """Calculate statistics for an alignment.""" + + print("Alignment Statistics:") + print("-" * 60) + + # Get alignment length + length = alignment.get_alignment_length() + + # Count gaps + total_gaps = sum(str(record.seq).count("-") for record in alignment) + gap_percentage = (total_gaps / (length * len(alignment))) * 100 + + print(f"Total positions: {length}") + print(f"Number of sequences: {len(alignment)}") + print(f"Total gaps: {total_gaps} ({gap_percentage:.1f}%)") + print() + + # Calculate conservation at each position + conserved_positions = 0 + for i in range(length): + column = alignment[:, i] + # Count most common residue + if column.count(max(set(column), key=column.count)) == len(alignment): + conserved_positions += 1 + + conservation = (conserved_positions / length) * 100 + print(f"Fully conserved positions: {conserved_positions} ({conservation:.1f}%)") + print() + + +def calculate_distance_matrix(alignment): + """Calculate distance matrix from alignment.""" + + print("Calculating Distance Matrix") + print("-" * 60) + + calculator = DistanceCalculator("identity") + dm = calculator.get_distance(alignment) + + print("Distance matrix:") + print(dm) + print() + + return dm + + +def build_upgma_tree(alignment): + """Build phylogenetic tree using UPGMA.""" + + print("Building UPGMA Tree") + print("=" * 60) + + # Calculate distance matrix + calculator = DistanceCalculator("identity") + dm = calculator.get_distance(alignment) + + # Construct tree + constructor = DistanceTreeConstructor(calculator) + tree = constructor.upgma(dm) + + print("UPGMA tree constructed") + print(f"Number of terminals: {tree.count_terminals()}") + print() + + return tree + + +def build_nj_tree(alignment): + """Build phylogenetic tree using Neighbor-Joining.""" + + print("Building Neighbor-Joining Tree") + print("=" * 60) + + # Calculate distance matrix + calculator = DistanceCalculator("identity") + dm = calculator.get_distance(alignment) + + # Construct tree + constructor = DistanceTreeConstructor(calculator) + tree = constructor.nj(dm) + + print("Neighbor-Joining tree constructed") + print(f"Number of terminals: {tree.count_terminals()}") + print() + + return tree + + +def visualize_tree(tree, title="Phylogenetic Tree"): + """Visualize phylogenetic tree.""" + + print("Visualizing tree...") + print() + + # ASCII visualization + print("ASCII tree:") + Phylo.draw_ascii(tree) + print() + + # Matplotlib visualization + fig, ax = plt.subplots(figsize=(10, 8)) + Phylo.draw(tree, axes=ax, do_show=False) + ax.set_title(title) + plt.tight_layout() + plt.savefig("tree_visualization.png", dpi=300, bbox_inches="tight") + print("Tree saved to tree_visualization.png") + print() + + +def manipulate_tree(tree): + """Demonstrate tree manipulation operations.""" + + print("Tree Manipulation") + print("=" * 60) + + # Get terminals + terminals = tree.get_terminals() + print(f"Terminal nodes: {[t.name for t in terminals]}") + print() 
+ + # Get nonterminals + nonterminals = tree.get_nonterminals() + print(f"Number of internal nodes: {len(nonterminals)}") + print() + + # Calculate total branch length + total_length = tree.total_branch_length() + print(f"Total branch length: {total_length:.4f}") + print() + + # Find specific clade + if len(terminals) > 0: + target_name = terminals[0].name + found = tree.find_any(name=target_name) + print(f"Found clade: {found.name}") + print() + + # Ladderize tree (sort branches) + tree.ladderize() + print("Tree ladderized (branches sorted)") + print() + + # Root at midpoint + tree.root_at_midpoint() + print("Tree rooted at midpoint") + print() + + return tree + + +def read_and_analyze_tree(tree_file, format="newick"): + """Read and analyze a phylogenetic tree.""" + + print(f"Reading tree from: {tree_file}") + print("-" * 60) + + tree = Phylo.read(tree_file, format) + + print(f"Tree format: {format}") + print(f"Number of terminals: {tree.count_terminals()}") + print(f"Is bifurcating: {tree.is_bifurcating()}") + print(f"Total branch length: {tree.total_branch_length():.4f}") + print() + + # Show tree structure + print("Tree structure:") + Phylo.draw_ascii(tree) + print() + + return tree + + +def compare_trees(tree1, tree2): + """Compare two phylogenetic trees.""" + + print("Comparing Trees") + print("=" * 60) + + # Get terminal names + terminals1 = {t.name for t in tree1.get_terminals()} + terminals2 = {t.name for t in tree2.get_terminals()} + + print(f"Tree 1 terminals: {len(terminals1)}") + print(f"Tree 2 terminals: {len(terminals2)}") + print(f"Shared terminals: {len(terminals1 & terminals2)}") + print(f"Unique to tree 1: {len(terminals1 - terminals2)}") + print(f"Unique to tree 2: {len(terminals2 - terminals1)}") + print() + + +def create_example_alignment(): + """Create an example alignment for demonstration.""" + + from Bio.Seq import Seq + from Bio.SeqRecord import SeqRecord + from Bio.Align import MultipleSeqAlignment + + sequences = [ + SeqRecord(Seq("ACTGCTAGCTAGCTAG"), id="seq1"), + SeqRecord(Seq("ACTGCTAGCT-GCTAG"), id="seq2"), + SeqRecord(Seq("ACTGCTAGCTAGCTGG"), id="seq3"), + SeqRecord(Seq("ACTGCT-GCTAGCTAG"), id="seq4"), + ] + + alignment = MultipleSeqAlignment(sequences) + + # Save alignment + AlignIO.write(alignment, "example_alignment.fasta", "fasta") + print("Created example alignment: example_alignment.fasta") + print() + + return alignment + + +def example_workflow(): + """Demonstrate complete alignment and phylogeny workflow.""" + + print("=" * 60) + print("BioPython Alignment & Phylogeny Workflow") + print("=" * 60) + print() + + # Pairwise alignment examples + pairwise_alignment_example() + print() + local_alignment_example() + print() + + # Create example data + alignment = create_example_alignment() + + # Analyze alignment + analyze_alignment_statistics(alignment) + + # Calculate distance matrix + dm = calculate_distance_matrix(alignment) + + # Build trees + upgma_tree = build_upgma_tree(alignment) + nj_tree = build_nj_tree(alignment) + + # Manipulate tree + manipulate_tree(upgma_tree) + + # Visualize + visualize_tree(upgma_tree, "UPGMA Tree") + + print("Workflow completed!") + print() + + +if __name__ == "__main__": + example_workflow() + + print("Note: For real analyses, use actual alignment files.") + print("Supported alignment formats: clustal, phylip, stockholm, nexus, fasta") + print("Supported tree formats: newick, nexus, phyloxml, nexml") diff --git a/scientific-packages/biopython/scripts/blast_search.py 
b/scientific-packages/biopython/scripts/blast_search.py new file mode 100644 index 0000000..7e05e94 --- /dev/null +++ b/scientific-packages/biopython/scripts/blast_search.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +""" +BLAST searches and result parsing using BioPython. + +This script demonstrates: +- Running BLAST searches via NCBI (qblast) +- Parsing BLAST XML output +- Filtering and analyzing results +- Working with alignments and HSPs +""" + +from Bio.Blast import NCBIWWW, NCBIXML +from Bio import SeqIO + + +def run_blast_online(sequence, program="blastn", database="nt", expect=0.001): + """ + Run BLAST search via NCBI's qblast. + + Parameters: + - sequence: Sequence string or Seq object + - program: blastn, blastp, blastx, tblastn, tblastx + - database: nt (nucleotide), nr (protein), refseq_rna, etc. + - expect: E-value threshold + """ + + print(f"Running {program} search against {database} database...") + print(f"E-value threshold: {expect}") + print("-" * 60) + + # Run BLAST + result_handle = NCBIWWW.qblast( + program=program, + database=database, + sequence=sequence, + expect=expect, + hitlist_size=50, # Number of sequences to show alignments for + ) + + # Save results + output_file = "blast_results.xml" + with open(output_file, "w") as out: + out.write(result_handle.read()) + + result_handle.close() + + print(f"BLAST search complete. Results saved to {output_file}") + print() + + return output_file + + +def parse_blast_results(xml_file, max_hits=10, evalue_threshold=0.001): + """Parse BLAST XML results.""" + + print(f"Parsing BLAST results from: {xml_file}") + print(f"E-value threshold: {evalue_threshold}") + print("=" * 60) + + with open(xml_file) as result_handle: + blast_record = NCBIXML.read(result_handle) + + print(f"Query: {blast_record.query}") + print(f"Query length: {blast_record.query_length} residues") + print(f"Database: {blast_record.database}") + print(f"Number of alignments: {len(blast_record.alignments)}") + print() + + hit_count = 0 + + for alignment in blast_record.alignments: + for hsp in alignment.hsps: + if hsp.expect <= evalue_threshold: + hit_count += 1 + + if hit_count <= max_hits: + print(f"Hit {hit_count}:") + print(f" Sequence: {alignment.title}") + print(f" Length: {alignment.length}") + print(f" E-value: {hsp.expect:.2e}") + print(f" Score: {hsp.score}") + print(f" Identities: {hsp.identities}/{hsp.align_length} ({hsp.identities / hsp.align_length * 100:.1f}%)") + print(f" Positives: {hsp.positives}/{hsp.align_length} ({hsp.positives / hsp.align_length * 100:.1f}%)") + print(f" Gaps: {hsp.gaps}/{hsp.align_length}") + print(f" Query range: {hsp.query_start} - {hsp.query_end}") + print(f" Subject range: {hsp.sbjct_start} - {hsp.sbjct_end}") + print() + + # Show alignment (first 100 characters) + print(" Alignment preview:") + print(f" Query: {hsp.query[:100]}") + print(f" Match: {hsp.match[:100]}") + print(f" Sbjct: {hsp.sbjct[:100]}") + print() + + print(f"Total significant hits (E-value <= {evalue_threshold}): {hit_count}") + print() + + return blast_record + + +def parse_multiple_queries(xml_file): + """Parse BLAST results with multiple queries.""" + + print(f"Parsing multiple queries from: {xml_file}") + print("=" * 60) + + with open(xml_file) as result_handle: + blast_records = NCBIXML.parse(result_handle) + + for i, blast_record in enumerate(blast_records, 1): + print(f"\nQuery {i}: {blast_record.query}") + print(f" Number of hits: {len(blast_record.alignments)}") + + if blast_record.alignments: + best_hit = blast_record.alignments[0] + best_hsp 
= best_hit.hsps[0] + print(f" Best hit: {best_hit.title[:80]}...") + print(f" Best E-value: {best_hsp.expect:.2e}") + + +def filter_blast_results(blast_record, min_identity=0.7, min_coverage=0.5): + """Filter BLAST results by identity and coverage.""" + + print(f"Filtering results:") + print(f" Minimum identity: {min_identity * 100}%") + print(f" Minimum coverage: {min_coverage * 100}%") + print("-" * 60) + + filtered_hits = [] + + for alignment in blast_record.alignments: + for hsp in alignment.hsps: + identity_fraction = hsp.identities / hsp.align_length + coverage = hsp.align_length / blast_record.query_length + + if identity_fraction >= min_identity and coverage >= min_coverage: + filtered_hits.append( + { + "title": alignment.title, + "length": alignment.length, + "evalue": hsp.expect, + "identity": identity_fraction, + "coverage": coverage, + "alignment": alignment, + "hsp": hsp, + } + ) + + print(f"Found {len(filtered_hits)} hits matching criteria") + print() + + # Sort by E-value + filtered_hits.sort(key=lambda x: x["evalue"]) + + # Display top hits + for i, hit in enumerate(filtered_hits[:5], 1): + print(f"{i}. {hit['title'][:80]}") + print(f" Identity: {hit['identity']*100:.1f}%, Coverage: {hit['coverage']*100:.1f}%, E-value: {hit['evalue']:.2e}") + print() + + return filtered_hits + + +def extract_hit_sequences(blast_record, output_file="blast_hits.fasta"): + """Extract aligned sequences from BLAST results.""" + + print(f"Extracting hit sequences to {output_file}...") + + from Bio.Seq import Seq + from Bio.SeqRecord import SeqRecord + + records = [] + + for i, alignment in enumerate(blast_record.alignments[:10]): # Top 10 hits + hsp = alignment.hsps[0] # Best HSP for this alignment + + # Extract accession from title + accession = alignment.title.split()[0] + + # Create SeqRecord from aligned subject sequence + record = SeqRecord( + Seq(hsp.sbjct.replace("-", "")), # Remove gaps + id=accession, + description=f"E-value: {hsp.expect:.2e}, Identity: {hsp.identities}/{hsp.align_length}", + ) + + records.append(record) + + # Write to FASTA + SeqIO.write(records, output_file, "fasta") + + print(f"Extracted {len(records)} sequences") + print() + + +def analyze_blast_statistics(blast_record): + """Compute statistics from BLAST results.""" + + print("BLAST Result Statistics:") + print("-" * 60) + + if not blast_record.alignments: + print("No hits found") + return + + evalues = [] + identities = [] + scores = [] + + for alignment in blast_record.alignments: + for hsp in alignment.hsps: + evalues.append(hsp.expect) + identities.append(hsp.identities / hsp.align_length) + scores.append(hsp.score) + + import statistics + + print(f"Total HSPs: {len(evalues)}") + print(f"\nE-values:") + print(f" Min: {min(evalues):.2e}") + print(f" Max: {max(evalues):.2e}") + print(f" Median: {statistics.median(evalues):.2e}") + print(f"\nIdentity percentages:") + print(f" Min: {min(identities)*100:.1f}%") + print(f" Max: {max(identities)*100:.1f}%") + print(f" Mean: {statistics.mean(identities)*100:.1f}%") + print(f"\nBit scores:") + print(f" Min: {min(scores):.1f}") + print(f" Max: {max(scores):.1f}") + print(f" Mean: {statistics.mean(scores):.1f}") + print() + + +def example_workflow(): + """Demonstrate BLAST workflow.""" + + print("=" * 60) + print("BioPython BLAST Example Workflow") + print("=" * 60) + print() + + # Example sequence (human beta-globin) + example_sequence = """ + 
ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTGTGGGGCAAGGTGAACGTGGATGAAGTTGGTGGTGAGGCCCTGGGCAGGCTGCTGGTGGTCTACCCTTGGACCCAGAGGTTCTTTGAGTCCTTTGGGGATCTGTCCACTCCTGATGCTGTTATGGGCAACCCTAAGGTGAAGGCTCATGGCAAGAAAGTGCTCGGTGCCTTTAGTGATGGCCTGGCTCACCTGGACAACCTCAAGGGCACCTTTGCCACACTGAGTGAGCTGCACTGTGACAAGCTGCACGTGGATCCTGAGAACTTCAGGCTCCTGGGCAACGTGCTGGTCTGTGTGCTGGCCCATCACTTTGGCAAAGAATTCACCCCACCAGTGCAGGCTGCCTATCAGAAAGTGGTGGCTGGTGTGGCTAATGCCCTGGCCCACAAGTATCACTAAGCTCGCTTTCTTGCTGTCCAATTTCTATTAAAGGTTCCTTTGTTCCCTAAGTCCAACTACTAAACTGGGGGATATTATGAAGGGCCTTGAGCATCTGGATTCTGCCTAATAAAAAACATTTATTTTCATTGC + """.replace("\n", "").replace(" ", "") + + print("Example: Human beta-globin sequence") + print(f"Length: {len(example_sequence)} bp") + print() + + # Note: Uncomment to run actual BLAST search (takes time) + # xml_file = run_blast_online(example_sequence, program="blastn", database="nt", expect=0.001) + + # For demonstration, use a pre-existing results file + print("To run a real BLAST search, uncomment the run_blast_online() line") + print("For now, demonstrating parsing with example results file") + print() + + # If you have results, parse them: + # blast_record = parse_blast_results("blast_results.xml", max_hits=5) + # filtered = filter_blast_results(blast_record, min_identity=0.9) + # analyze_blast_statistics(blast_record) + # extract_hit_sequences(blast_record) + + +if __name__ == "__main__": + example_workflow() + + print() + print("Note: BLAST searches can take several minutes.") + print("For production use, consider running local BLAST instead.") diff --git a/scientific-packages/biopython/scripts/file_io.py b/scientific-packages/biopython/scripts/file_io.py new file mode 100644 index 0000000..f1033ed --- /dev/null +++ b/scientific-packages/biopython/scripts/file_io.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +""" +File I/O operations using BioPython SeqIO. 
+ +This script demonstrates: +- Reading sequences from various formats +- Writing sequences to files +- Converting between formats +- Filtering and processing sequences +- Working with large files efficiently +""" + +from Bio import SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord + + +def read_sequences(filename, format_type): + """Read and display sequences from a file.""" + + print(f"Reading {format_type} file: {filename}") + print("-" * 60) + + count = 0 + for record in SeqIO.parse(filename, format_type): + count += 1 + print(f"ID: {record.id}") + print(f"Name: {record.name}") + print(f"Description: {record.description}") + print(f"Sequence length: {len(record.seq)}") + print(f"Sequence: {record.seq[:50]}...") + print() + + # Only show first 3 sequences + if count >= 3: + break + + # Count total sequences + total = len(list(SeqIO.parse(filename, format_type))) + print(f"Total sequences in file: {total}") + print() + + +def read_single_sequence(filename, format_type): + """Read a single sequence from a file.""" + + record = SeqIO.read(filename, format_type) + + print("Single sequence record:") + print(f"ID: {record.id}") + print(f"Sequence: {record.seq}") + print() + + +def write_sequences(records, output_filename, format_type): + """Write sequences to a file.""" + + count = SeqIO.write(records, output_filename, format_type) + print(f"Wrote {count} sequences to {output_filename} in {format_type} format") + print() + + +def convert_format(input_file, input_format, output_file, output_format): + """Convert sequences from one format to another.""" + + count = SeqIO.convert(input_file, input_format, output_file, output_format) + print(f"Converted {count} sequences from {input_format} to {output_format}") + print() + + +def filter_sequences(input_file, format_type, min_length=100, max_length=1000): + """Filter sequences by length.""" + + filtered = [] + + for record in SeqIO.parse(input_file, format_type): + if min_length <= len(record.seq) <= max_length: + filtered.append(record) + + print(f"Found {len(filtered)} sequences between {min_length} and {max_length} bp") + return filtered + + +def extract_subsequence(input_file, format_type, seq_id, start, end): + """Extract a subsequence from a specific record.""" + + # Index for efficient access + record_dict = SeqIO.index(input_file, format_type) + + if seq_id in record_dict: + record = record_dict[seq_id] + subseq = record.seq[start:end] + print(f"Extracted subsequence from {seq_id} ({start}:{end}):") + print(subseq) + return subseq + else: + print(f"Sequence {seq_id} not found") + return None + + +def create_sequence_records(): + """Create SeqRecord objects from scratch.""" + + # Simple record + simple_record = SeqRecord( + Seq("ATGCATGCATGC"), + id="seq001", + name="MySequence", + description="Example sequence" + ) + + # Record with annotations + annotated_record = SeqRecord( + Seq("ATGGTGCATCTGACTCCTGAGGAG"), + id="seq002", + name="GeneX", + description="Important gene" + ) + annotated_record.annotations["molecule_type"] = "DNA" + annotated_record.annotations["organism"] = "Homo sapiens" + + return [simple_record, annotated_record] + + +def index_large_file(filename, format_type): + """Index a large file for random access without loading into memory.""" + + # Create index + record_index = SeqIO.index(filename, format_type) + + print(f"Indexed {len(record_index)} sequences") + print(f"Available IDs: {list(record_index.keys())[:10]}...") + print() + + # Access specific record by ID + if len(record_index) > 0: + first_id = 
list(record_index.keys())[0] + record = record_index[first_id] + print(f"Accessed record: {record.id}") + print() + + # Close index + record_index.close() + + +def parse_with_quality_scores(fastq_file): + """Parse FASTQ files with quality scores.""" + + print("Parsing FASTQ with quality scores:") + print("-" * 60) + + for record in SeqIO.parse(fastq_file, "fastq"): + print(f"ID: {record.id}") + print(f"Sequence: {record.seq[:50]}...") + print(f"Quality scores (first 10): {record.letter_annotations['phred_quality'][:10]}") + + # Calculate average quality + avg_quality = sum(record.letter_annotations["phred_quality"]) / len(record) + print(f"Average quality: {avg_quality:.2f}") + print() + break # Just show first record + + +def batch_process_large_file(input_file, format_type, batch_size=100): + """Process large files in batches to manage memory.""" + + batch = [] + count = 0 + + for record in SeqIO.parse(input_file, format_type): + batch.append(record) + count += 1 + + if len(batch) == batch_size: + # Process batch + print(f"Processing batch of {len(batch)} sequences...") + # Do something with batch + batch = [] # Clear for next batch + + # Process remaining records + if batch: + print(f"Processing final batch of {len(batch)} sequences...") + + print(f"Total sequences processed: {count}") + + +def example_workflow(): + """Demonstrate a complete workflow.""" + + print("=" * 60) + print("BioPython SeqIO Workflow Example") + print("=" * 60) + print() + + # Create example sequences + records = create_sequence_records() + + # Write as FASTA + write_sequences(records, "example_output.fasta", "fasta") + + # Write as GenBank + write_sequences(records, "example_output.gb", "genbank") + + # Convert FASTA to GenBank (would work if file exists) + # convert_format("input.fasta", "fasta", "output.gb", "genbank") + + print("Example workflow completed!") + + +if __name__ == "__main__": + example_workflow() + + print() + print("Note: This script demonstrates BioPython SeqIO operations.") + print("Uncomment and adapt the functions for your specific files.") diff --git a/scientific-packages/biopython/scripts/ncbi_entrez.py b/scientific-packages/biopython/scripts/ncbi_entrez.py new file mode 100644 index 0000000..f9fe21f --- /dev/null +++ b/scientific-packages/biopython/scripts/ncbi_entrez.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +""" +NCBI Entrez database access using BioPython. + +This script demonstrates: +- Searching NCBI databases +- Downloading sequences by accession +- Retrieving PubMed articles +- Batch downloading with WebEnv +- Proper error handling and rate limiting +""" + +import time +from Bio import Entrez, SeqIO + +# IMPORTANT: Always set your email +Entrez.email = "your.email@example.com" # Change this! 
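# Optional: an NCBI API key raises the E-utilities rate limit from 3 to
# 10 requests per second. The value below is a placeholder, not a real key;
# leave it commented out if you do not have one.
# Entrez.api_key = "your-ncbi-api-key"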
+ + +def search_nucleotide(query, max_results=10): + """Search NCBI nucleotide database.""" + + print(f"Searching nucleotide database for: {query}") + print("-" * 60) + + handle = Entrez.esearch(db="nucleotide", term=query, retmax=max_results) + record = Entrez.read(handle) + handle.close() + + print(f"Found {record['Count']} total matches") + print(f"Returning top {len(record['IdList'])} IDs:") + print(record["IdList"]) + print() + + return record["IdList"] + + +def fetch_sequence_by_accession(accession): + """Download a sequence by accession number.""" + + print(f"Fetching sequence: {accession}") + + try: + handle = Entrez.efetch( + db="nucleotide", id=accession, rettype="gb", retmode="text" + ) + record = SeqIO.read(handle, "genbank") + handle.close() + + print(f"Successfully retrieved: {record.id}") + print(f"Description: {record.description}") + print(f"Length: {len(record.seq)} bp") + print(f"Organism: {record.annotations.get('organism', 'Unknown')}") + print() + + return record + + except Exception as e: + print(f"Error fetching {accession}: {e}") + return None + + +def fetch_multiple_sequences(id_list, output_file="downloaded_sequences.fasta"): + """Download multiple sequences and save to file.""" + + print(f"Fetching {len(id_list)} sequences...") + + try: + # For >200 IDs, efetch automatically uses POST + handle = Entrez.efetch( + db="nucleotide", id=id_list, rettype="fasta", retmode="text" + ) + + # Parse and save + records = list(SeqIO.parse(handle, "fasta")) + handle.close() + + SeqIO.write(records, output_file, "fasta") + + print(f"Successfully downloaded {len(records)} sequences to {output_file}") + print() + + return records + + except Exception as e: + print(f"Error fetching sequences: {e}") + return [] + + +def search_and_download(query, output_file, max_results=100): + """Complete workflow: search and download sequences.""" + + print(f"Searching and downloading: {query}") + print("=" * 60) + + # Search + handle = Entrez.esearch(db="nucleotide", term=query, retmax=max_results) + record = Entrez.read(handle) + handle.close() + + id_list = record["IdList"] + print(f"Found {len(id_list)} sequences") + + if not id_list: + print("No results found") + return + + # Download in batches to be polite + batch_size = 100 + all_records = [] + + for start in range(0, len(id_list), batch_size): + end = min(start + batch_size, len(id_list)) + batch_ids = id_list[start:end] + + print(f"Downloading batch {start // batch_size + 1} ({len(batch_ids)} sequences)...") + + handle = Entrez.efetch( + db="nucleotide", id=batch_ids, rettype="fasta", retmode="text" + ) + batch_records = list(SeqIO.parse(handle, "fasta")) + handle.close() + + all_records.extend(batch_records) + + # Be polite - wait between requests + time.sleep(0.5) + + # Save all records + SeqIO.write(all_records, output_file, "fasta") + print(f"Downloaded {len(all_records)} sequences to {output_file}") + print() + + +def use_history_for_large_queries(query, max_results=1000): + """Use NCBI History server for large queries.""" + + print("Using NCBI History server for large query") + print("-" * 60) + + # Search with history + search_handle = Entrez.esearch( + db="nucleotide", term=query, retmax=max_results, usehistory="y" + ) + search_results = Entrez.read(search_handle) + search_handle.close() + + count = int(search_results["Count"]) + webenv = search_results["WebEnv"] + query_key = search_results["QueryKey"] + + print(f"Found {count} total sequences") + print(f"WebEnv: {webenv[:20]}...") + print(f"QueryKey: {query_key}") + print() + 
+ # Fetch in batches using history + batch_size = 500 + all_records = [] + + for start in range(0, min(count, max_results), batch_size): + end = min(start + batch_size, max_results) + + print(f"Downloading records {start + 1} to {end}...") + + fetch_handle = Entrez.efetch( + db="nucleotide", + rettype="fasta", + retmode="text", + retstart=start, + retmax=batch_size, + webenv=webenv, + query_key=query_key, + ) + + batch_records = list(SeqIO.parse(fetch_handle, "fasta")) + fetch_handle.close() + + all_records.extend(batch_records) + + # Be polite + time.sleep(0.5) + + print(f"Downloaded {len(all_records)} sequences total") + return all_records + + +def search_pubmed(query, max_results=10): + """Search PubMed for articles.""" + + print(f"Searching PubMed for: {query}") + print("-" * 60) + + handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results) + record = Entrez.read(handle) + handle.close() + + id_list = record["IdList"] + print(f"Found {record['Count']} total articles") + print(f"Returning {len(id_list)} PMIDs:") + print(id_list) + print() + + return id_list + + +def fetch_pubmed_abstracts(pmid_list): + """Fetch PubMed article summaries.""" + + print(f"Fetching summaries for {len(pmid_list)} articles...") + + handle = Entrez.efetch(db="pubmed", id=pmid_list, rettype="abstract", retmode="text") + abstracts = handle.read() + handle.close() + + print(abstracts[:500]) # Show first 500 characters + print("...") + print() + + +def get_database_info(database="nucleotide"): + """Get information about an NCBI database.""" + + print(f"Getting info for database: {database}") + print("-" * 60) + + handle = Entrez.einfo(db=database) + record = Entrez.read(handle) + handle.close() + + db_info = record["DbInfo"] + print(f"Name: {db_info['DbName']}") + print(f"Description: {db_info['Description']}") + print(f"Record count: {db_info['Count']}") + print(f"Last update: {db_info['LastUpdate']}") + print() + + +def link_databases(db_from, db_to, id_): + """Find related records in other databases.""" + + print(f"Finding links from {db_from} ID {id_} to {db_to}") + print("-" * 60) + + handle = Entrez.elink(dbfrom=db_from, db=db_to, id=id_) + record = Entrez.read(handle) + handle.close() + + if record[0]["LinkSetDb"]: + linked_ids = [link["Id"] for link in record[0]["LinkSetDb"][0]["Link"]] + print(f"Found {len(linked_ids)} linked records") + print(f"IDs: {linked_ids[:10]}") + else: + print("No linked records found") + + print() + + +def example_workflow(): + """Demonstrate complete Entrez workflow.""" + + print("=" * 60) + print("BioPython Entrez Example Workflow") + print("=" * 60) + print() + + # Note: These are examples - uncomment to run with your email set + + # # Example 1: Search and get IDs + # ids = search_nucleotide("Homo sapiens[Organism] AND COX1[Gene]", max_results=5) + # + # # Example 2: Fetch a specific sequence + # fetch_sequence_by_accession("NM_001301717") + # + # # Example 3: Complete search and download + # search_and_download("Escherichia coli[Organism] AND 16S", "ecoli_16s.fasta", max_results=50) + # + # # Example 4: PubMed search + # pmids = search_pubmed("CRISPR[Title] AND 2023[PDAT]", max_results=5) + # fetch_pubmed_abstracts(pmids[:2]) + # + # # Example 5: Get database info + # get_database_info("nucleotide") + + print("Examples are commented out. 
Uncomment and set your email to run.") + + +if __name__ == "__main__": + example_workflow() + + print() + print("IMPORTANT: Always set Entrez.email before using these functions!") + print("NCBI requires an email address for their E-utilities.") diff --git a/scientific-packages/biopython/scripts/sequence_operations.py b/scientific-packages/biopython/scripts/sequence_operations.py new file mode 100644 index 0000000..e583eff --- /dev/null +++ b/scientific-packages/biopython/scripts/sequence_operations.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Common sequence operations using BioPython. + +This script demonstrates basic sequence manipulation tasks like: +- Creating and manipulating Seq objects +- Transcription and translation +- Complement and reverse complement +- Calculating GC content and melting temperature +""" + +from Bio.Seq import Seq +from Bio.SeqUtils import gc_fraction, MeltingTemp as mt + + +def demonstrate_seq_operations(): + """Show common Seq object operations.""" + + # Create DNA sequence + dna_seq = Seq("ATGGTGCATCTGACTCCTGAGGAGAAGTCTGCCGTTACTGCCCTG") + + print("Original DNA sequence:") + print(dna_seq) + print() + + # Transcription (DNA -> RNA) + rna_seq = dna_seq.transcribe() + print("Transcribed to RNA:") + print(rna_seq) + print() + + # Translation (DNA -> Protein) + protein_seq = dna_seq.translate() + print("Translated to protein:") + print(protein_seq) + print() + + # Translation with stop codon handling + protein_to_stop = dna_seq.translate(to_stop=True) + print("Translated to first stop codon:") + print(protein_to_stop) + print() + + # Complement + complement = dna_seq.complement() + print("Complement:") + print(complement) + print() + + # Reverse complement + reverse_complement = dna_seq.reverse_complement() + print("Reverse complement:") + print(reverse_complement) + print() + + # GC content + gc = gc_fraction(dna_seq) * 100 + print(f"GC content: {gc:.2f}%") + print() + + # Melting temperature + tm = mt.Tm_NN(dna_seq) + print(f"Melting temperature (nearest-neighbor): {tm:.2f}°C") + print() + + # Sequence searching + codon_start = dna_seq.find("ATG") + print(f"Start codon (ATG) position: {codon_start}") + + # Count occurrences + g_count = dna_seq.count("G") + print(f"Number of G nucleotides: {g_count}") + print() + + +def translate_with_genetic_code(): + """Demonstrate translation with different genetic codes.""" + + dna_seq = Seq("ATGGTGCATCTGACTCCTGAGGAGAAGTCT") + + # Standard genetic code (table 1) + standard = dna_seq.translate(table=1) + print("Standard genetic code translation:") + print(standard) + + # Vertebrate mitochondrial code (table 2) + mito = dna_seq.translate(table=2) + print("Vertebrate mitochondrial code translation:") + print(mito) + print() + + +def working_with_codons(): + """Access genetic code tables.""" + from Bio.Data import CodonTable + + # Get standard genetic code + standard_table = CodonTable.unambiguous_dna_by_id[1] + + print("Standard genetic code:") + print(f"Start codons: {standard_table.start_codons}") + print(f"Stop codons: {standard_table.stop_codons}") + print() + + # Show some codon translations + print("Example codons:") + for codon in ["ATG", "TGG", "TAA", "TAG", "TGA"]: + if codon in standard_table.stop_codons: + print(f"{codon} -> STOP") + else: + aa = standard_table.forward_table.get(codon, "Unknown") + print(f"{codon} -> {aa}") + + +if __name__ == "__main__": + print("=" * 60) + print("BioPython Sequence Operations Demo") + print("=" * 60) + print() + + demonstrate_seq_operations() + print("-" * 60) + 
translate_with_genetic_code() + print("-" * 60) + working_with_codons() diff --git a/scientific-packages/bioservices/SKILL.md b/scientific-packages/bioservices/SKILL.md new file mode 100644 index 0000000..8905127 --- /dev/null +++ b/scientific-packages/bioservices/SKILL.md @@ -0,0 +1,355 @@ +--- +name: bioservices +description: Toolkit for accessing 40+ biological web services and databases programmatically. Use when working with protein sequences, gene pathways (KEGG), identifier mapping (UniProt), compound databases (ChEBI, ChEMBL), sequence analysis (BLAST), pathway interactions, gene ontology, or any bioinformatics data retrieval tasks requiring integration across multiple biological databases. +--- + +# BioServices + +## Overview + +BioServices is a Python package providing programmatic access to approximately 40 bioinformatics web services and databases. Use this skill to retrieve biological data, perform cross-database queries, map identifiers, analyze sequences, and integrate multiple biological resources in Python workflows. The package handles both REST and SOAP/WSDL protocols transparently. + +## When to Use This Skill + +Apply this skill when tasks involve: +- Retrieving protein sequences, annotations, or structures from UniProt, PDB, Pfam +- Analyzing metabolic pathways and gene functions via KEGG or Reactome +- Searching compound databases (ChEBI, ChEMBL, PubChem) for chemical information +- Converting identifiers between different biological databases (KEGG↔UniProt, compound IDs) +- Running sequence similarity searches (BLAST, MUSCLE alignment) +- Querying gene ontology terms (QuickGO, GO annotations) +- Accessing protein-protein interaction data (PSICQUIC, IntactComplex) +- Mining genomic data (BioMart, ArrayExpress, ENA) +- Integrating data from multiple bioinformatics resources in a single workflow + +## Core Capabilities + +### 1. Protein Analysis + +Retrieve protein information, sequences, and functional annotations: + +```python +from bioservices import UniProt + +u = UniProt(verbose=False) + +# Search for protein by name +results = u.search("ZAP70_HUMAN", frmt="tab", columns="id,genes,organism") + +# Retrieve FASTA sequence +sequence = u.retrieve("P43403", "fasta") + +# Map identifiers between databases +kegg_ids = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query="P43403") +``` + +**Key methods:** +- `search()`: Query UniProt with flexible search terms +- `retrieve()`: Get protein entries in various formats (FASTA, XML, tab) +- `mapping()`: Convert identifiers between databases + +Reference: `references/services_reference.md` for complete UniProt API details. + +### 2. 
Pathway Discovery and Analysis + +Access KEGG pathway information for genes and organisms: + +```python +from bioservices import KEGG + +k = KEGG() +k.organism = "hsa" # Set to human + +# Search for organisms +k.lookfor_organism("droso") # Find Drosophila species + +# Find pathways by name +k.lookfor_pathway("B cell") # Returns matching pathway IDs + +# Get pathways containing specific genes +pathways = k.get_pathway_by_gene("7535", "hsa") # ZAP70 gene + +# Retrieve and parse pathway data +data = k.get("hsa04660") +parsed = k.parse(data) + +# Extract pathway interactions +interactions = k.parse_kgml_pathway("hsa04660") +relations = interactions['relations'] # Protein-protein interactions + +# Convert to Simple Interaction Format +sif_data = k.pathway2sif("hsa04660") +``` + +**Key methods:** +- `lookfor_organism()`, `lookfor_pathway()`: Search by name +- `get_pathway_by_gene()`: Find pathways containing genes +- `parse_kgml_pathway()`: Extract structured pathway data +- `pathway2sif()`: Get protein interaction networks + +Reference: `references/workflow_patterns.md` for complete pathway analysis workflows. + +### 3. Compound Database Searches + +Search and cross-reference compounds across multiple databases: + +```python +from bioservices import KEGG, UniChem + +k = KEGG() + +# Search compounds by name +results = k.find("compound", "Geldanamycin") # Returns cpd:C11222 + +# Get compound information with database links +compound_info = k.get("cpd:C11222") # Includes ChEBI links + +# Cross-reference KEGG → ChEMBL using UniChem +u = UniChem() +chembl_id = u.get_compound_id_from_kegg("C11222") # Returns CHEMBL278315 +``` + +**Common workflow:** +1. Search compound by name in KEGG +2. Extract KEGG compound ID +3. Use UniChem for KEGG → ChEMBL mapping +4. ChEBI IDs are often provided in KEGG entries + +Reference: `references/identifier_mapping.md` for complete cross-database mapping guide. + +### 4. Sequence Analysis + +Run BLAST searches and sequence alignments: + +```python +from bioservices import NCBIblast + +s = NCBIblast(verbose=False) + +# Run BLASTP against UniProtKB +jobid = s.run( + program="blastp", + sequence=protein_sequence, + stype="protein", + database="uniprotkb", + email="your.email@example.com" # Required by NCBI +) + +# Check job status and retrieve results +s.getStatus(jobid) +results = s.getResult(jobid, "out") +``` + +**Note:** BLAST jobs are asynchronous. Check status before retrieving results. + +### 5. Identifier Mapping + +Convert identifiers between different biological databases: + +```python +from bioservices import UniProt, KEGG + +# UniProt mapping (many database pairs supported) +u = UniProt() +results = u.mapping( + fr="UniProtKB_AC-ID", # Source database + to="KEGG", # Target database + query="P43403" # Identifier(s) to convert +) + +# KEGG gene ID → UniProt +kegg_to_uniprot = u.mapping(fr="KEGG", to="UniProtKB_AC-ID", query="hsa:7535") + +# For compounds, use UniChem +from bioservices import UniChem +u = UniChem() +chembl_from_kegg = u.get_compound_id_from_kegg("C11222") +``` + +**Supported mappings (UniProt):** +- UniProtKB ↔ KEGG +- UniProtKB ↔ Ensembl +- UniProtKB ↔ PDB +- UniProtKB ↔ RefSeq +- And many more (see `references/identifier_mapping.md`) + +### 6. 
Gene Ontology Queries + +Access GO terms and annotations: + +```python +from bioservices import QuickGO + +g = QuickGO(verbose=False) + +# Retrieve GO term information +term_info = g.Term("GO:0003824", frmt="obo") + +# Search annotations +annotations = g.Annotation(protein="P43403", format="tsv") +``` + +### 7. Protein-Protein Interactions + +Query interaction databases via PSICQUIC: + +```python +from bioservices import PSICQUIC + +s = PSICQUIC(verbose=False) + +# Query specific database (e.g., MINT) +interactions = s.query("mint", "ZAP70 AND species:9606") + +# List available interaction databases +databases = s.activeDBs +``` + +**Available databases:** MINT, IntAct, BioGRID, DIP, and 30+ others. + +## Multi-Service Integration Workflows + +BioServices excels at combining multiple services for comprehensive analysis. Common integration patterns: + +### Complete Protein Analysis Pipeline + +Execute a full protein characterization workflow: + +```bash +python scripts/protein_analysis_workflow.py ZAP70_HUMAN your.email@example.com +``` + +This script demonstrates: +1. UniProt search for protein entry +2. FASTA sequence retrieval +3. BLAST similarity search +4. KEGG pathway discovery +5. PSICQUIC interaction mapping + +### Pathway Network Analysis + +Analyze all pathways for an organism: + +```bash +python scripts/pathway_analysis.py hsa output_directory/ +``` + +Extracts and analyzes: +- All pathway IDs for organism +- Protein-protein interactions per pathway +- Interaction type distributions +- Exports to CSV/SIF formats + +### Cross-Database Compound Search + +Map compound identifiers across databases: + +```bash +python scripts/compound_cross_reference.py Geldanamycin +``` + +Retrieves: +- KEGG compound ID +- ChEBI identifier +- ChEMBL identifier +- Basic compound properties + +### Batch Identifier Conversion + +Convert multiple identifiers at once: + +```bash +python scripts/batch_id_converter.py input_ids.txt --from UniProtKB_AC-ID --to KEGG +``` + +## Best Practices + +### Output Format Handling + +Different services return data in various formats: +- **XML**: Parse using BeautifulSoup (most SOAP services) +- **Tab-separated (TSV)**: Pandas DataFrames for tabular data +- **Dictionary/JSON**: Direct Python manipulation +- **FASTA**: BioPython integration for sequence analysis + +### Rate Limiting and Verbosity + +Control API request behavior: + +```python +from bioservices import KEGG + +k = KEGG(verbose=False) # Suppress HTTP request details +k.TIMEOUT = 30 # Adjust timeout for slow connections +``` + +### Error Handling + +Wrap service calls in try-except blocks: + +```python +try: + results = u.search("ambiguous_query") + if results: + # Process results + pass +except Exception as e: + print(f"Search failed: {e}") +``` + +### Organism Codes + +Use standard organism abbreviations: +- `hsa`: Homo sapiens (human) +- `mmu`: Mus musculus (mouse) +- `dme`: Drosophila melanogaster +- `sce`: Saccharomyces cerevisiae (yeast) + +List all organisms: `k.list("organism")` or `k.organismIds` + +### Integration with Other Tools + +BioServices works well with: +- **BioPython**: Sequence analysis on retrieved FASTA data +- **Pandas**: Tabular data manipulation +- **PyMOL**: 3D structure visualization (retrieve PDB IDs) +- **NetworkX**: Network analysis of pathway interactions +- **Galaxy**: Custom tool wrappers for workflow platforms + +## Resources + +### scripts/ + +Executable Python scripts demonstrating complete workflows: + +- `protein_analysis_workflow.py`: End-to-end protein characterization 
+- `pathway_analysis.py`: KEGG pathway discovery and network extraction +- `compound_cross_reference.py`: Multi-database compound searching +- `batch_id_converter.py`: Bulk identifier mapping utility + +Scripts can be executed directly or adapted for specific use cases. + +### references/ + +Detailed documentation loaded as needed: + +- `services_reference.md`: Comprehensive list of all 40+ services with methods +- `workflow_patterns.md`: Detailed multi-step analysis workflows +- `identifier_mapping.md`: Complete guide to cross-database ID conversion + +Load references when working with specific services or complex integration tasks. + +## Installation + +```bash +pip install bioservices +``` + +Dependencies are automatically managed. Package is tested on Python 3.9-3.12. + +## Additional Information + +For detailed API documentation and advanced features, refer to: +- Official documentation: https://bioservices.readthedocs.io/ +- Source code: https://github.com/cokelaer/bioservices +- Service-specific references in `references/services_reference.md` diff --git a/scientific-packages/bioservices/references/identifier_mapping.md b/scientific-packages/bioservices/references/identifier_mapping.md new file mode 100644 index 0000000..6fb9b38 --- /dev/null +++ b/scientific-packages/bioservices/references/identifier_mapping.md @@ -0,0 +1,685 @@ +# BioServices: Identifier Mapping Guide + +This document provides comprehensive information about converting identifiers between different biological databases using BioServices. + +## Table of Contents + +1. [Overview](#overview) +2. [UniProt Mapping Service](#uniprot-mapping-service) +3. [UniChem Compound Mapping](#unichem-compound-mapping) +4. [KEGG Identifier Conversions](#kegg-identifier-conversions) +5. [Common Mapping Patterns](#common-mapping-patterns) +6. [Troubleshooting](#troubleshooting) + +--- + +## Overview + +Biological databases use different identifier systems. Cross-referencing requires mapping between these systems. BioServices provides multiple approaches: + +1. **UniProt Mapping**: Comprehensive protein/gene ID conversion +2. **UniChem**: Chemical compound ID mapping +3. **KEGG**: Built-in cross-references in entries +4. **PICR**: Protein identifier cross-reference service + +--- + +## UniProt Mapping Service + +The UniProt mapping service is the most comprehensive tool for protein and gene identifier conversion. + +### Basic Usage + +```python +from bioservices import UniProt + +u = UniProt() + +# Map single ID +result = u.mapping( + fr="UniProtKB_AC-ID", # Source database + to="KEGG", # Target database + query="P43403" # Identifier to convert +) + +print(result) +# Output: {'P43403': ['hsa:7535']} +``` + +### Batch Mapping + +```python +# Map multiple IDs (comma-separated) +ids = ["P43403", "P04637", "P53779"] +result = u.mapping( + fr="UniProtKB_AC-ID", + to="KEGG", + query=",".join(ids) +) + +for uniprot_id, kegg_ids in result.items(): + print(f"{uniprot_id} → {kegg_ids}") +``` + +### Supported Database Pairs + +UniProt supports mapping between 100+ database pairs. 
Key ones include: + +#### Protein/Gene Databases + +| Source Format | Code | Target Format | Code | +|---------------|------|---------------|------| +| UniProtKB AC/ID | `UniProtKB_AC-ID` | KEGG | `KEGG` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | Ensembl | `Ensembl` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | Ensembl Protein | `Ensembl_Protein` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | Ensembl Transcript | `Ensembl_Transcript` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | RefSeq Protein | `RefSeq_Protein` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | RefSeq Nucleotide | `RefSeq_Nucleotide` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | GeneID (Entrez) | `GeneID` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | HGNC | `HGNC` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | MGI | `MGI` | +| KEGG | `KEGG` | UniProtKB | `UniProtKB` | +| Ensembl | `Ensembl` | UniProtKB | `UniProtKB` | +| GeneID | `GeneID` | UniProtKB | `UniProtKB` | + +#### Structural Databases + +| Source | Code | Target | Code | +|--------|------|--------|------| +| UniProtKB AC/ID | `UniProtKB_AC-ID` | PDB | `PDB` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | Pfam | `Pfam` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | InterPro | `InterPro` | +| PDB | `PDB` | UniProtKB | `UniProtKB` | + +#### Expression & Proteomics + +| Source | Code | Target | Code | +|--------|------|--------|------| +| UniProtKB AC/ID | `UniProtKB_AC-ID` | PRIDE | `PRIDE` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | ProteomicsDB | `ProteomicsDB` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | PaxDb | `PaxDb` | + +#### Organism-Specific + +| Source | Code | Target | Code | +|--------|------|--------|------| +| UniProtKB AC/ID | `UniProtKB_AC-ID` | FlyBase | `FlyBase` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | WormBase | `WormBase` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | SGD | `SGD` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | ZFIN | `ZFIN` | + +#### Other Useful Mappings + +| Source | Code | Target | Code | +|--------|------|--------|------| +| UniProtKB AC/ID | `UniProtKB_AC-ID` | GO | `GO` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | Reactome | `Reactome` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | STRING | `STRING` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | BioGRID | `BioGRID` | +| UniProtKB AC/ID | `UniProtKB_AC-ID` | OMA | `OMA` | + +### Complete List of Database Codes + +To get the complete, up-to-date list: + +```python +from bioservices import UniProt + +u = UniProt() + +# This information is in the UniProt REST API documentation +# Common patterns: +# - Source databases typically end in source database name +# - UniProtKB uses "UniProtKB_AC-ID" or "UniProtKB" +# - Most other databases use their standard abbreviation +``` + +### Common Database Codes Reference + +**Gene/Protein Identifiers:** +- `UniProtKB_AC-ID`: UniProt accession/ID +- `UniProtKB`: UniProt accession +- `KEGG`: KEGG gene IDs (e.g., hsa:7535) +- `GeneID`: NCBI Gene (Entrez) IDs +- `Ensembl`: Ensembl gene IDs +- `Ensembl_Protein`: Ensembl protein IDs +- `Ensembl_Transcript`: Ensembl transcript IDs +- `RefSeq_Protein`: RefSeq protein IDs (NP_) +- `RefSeq_Nucleotide`: RefSeq nucleotide IDs (NM_) + +**Gene Nomenclature:** +- `HGNC`: Human Gene Nomenclature Committee +- `MGI`: Mouse Genome Informatics +- `RGD`: Rat Genome Database +- `SGD`: Saccharomyces Genome Database +- `FlyBase`: Drosophila database +- `WormBase`: C. 
elegans database +- `ZFIN`: Zebrafish database + +**Structure:** +- `PDB`: Protein Data Bank +- `Pfam`: Protein families +- `InterPro`: Protein domains +- `SUPFAM`: Superfamily +- `PROSITE`: Protein motifs + +**Pathways & Networks:** +- `Reactome`: Reactome pathways +- `BioCyc`: BioCyc pathways +- `PathwayCommons`: Pathway Commons +- `STRING`: Protein-protein networks +- `BioGRID`: Interaction database + +### Mapping Examples + +#### UniProt → KEGG + +```python +from bioservices import UniProt + +u = UniProt() + +# Single mapping +result = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query="P43403") +print(result) # {'P43403': ['hsa:7535']} +``` + +#### KEGG → UniProt + +```python +# Reverse mapping +result = u.mapping(fr="KEGG", to="UniProtKB", query="hsa:7535") +print(result) # {'hsa:7535': ['P43403']} +``` + +#### UniProt → Ensembl + +```python +# To Ensembl gene IDs +result = u.mapping(fr="UniProtKB_AC-ID", to="Ensembl", query="P43403") +print(result) # {'P43403': ['ENSG00000115085']} + +# To Ensembl protein IDs +result = u.mapping(fr="UniProtKB_AC-ID", to="Ensembl_Protein", query="P43403") +print(result) # {'P43403': ['ENSP00000381359']} +``` + +#### UniProt → PDB + +```python +# Find 3D structures +result = u.mapping(fr="UniProtKB_AC-ID", to="PDB", query="P04637") +print(result) # {'P04637': ['1A1U', '1AIE', '1C26', ...]} +``` + +#### UniProt → RefSeq + +```python +# Get RefSeq protein IDs +result = u.mapping(fr="UniProtKB_AC-ID", to="RefSeq_Protein", query="P43403") +print(result) # {'P43403': ['NP_001070.2']} +``` + +#### Gene Name → UniProt (via search, then mapping) + +```python +# First search for gene +search_result = u.search("gene:ZAP70 AND organism:9606", frmt="tab", columns="id") +lines = search_result.strip().split("\n") +if len(lines) > 1: + uniprot_id = lines[1].split("\t")[0] + + # Then map to other databases + kegg_id = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id) + print(kegg_id) +``` + +--- + +## UniChem Compound Mapping + +UniChem specializes in mapping chemical compound identifiers across databases. 
+ +### Source Database IDs + +| Source ID | Database | +|-----------|----------| +| 1 | ChEMBL | +| 2 | DrugBank | +| 3 | PDB | +| 4 | IUPHAR/BPS Guide to Pharmacology | +| 5 | PubChem | +| 6 | KEGG | +| 7 | ChEBI | +| 8 | NIH Clinical Collection | +| 14 | FDA/SRS | +| 22 | PubChem | + +### Basic Usage + +```python +from bioservices import UniChem + +u = UniChem() + +# Get ChEMBL ID from KEGG compound ID +chembl_id = u.get_compound_id_from_kegg("C11222") +print(chembl_id) # CHEMBL278315 +``` + +### All Compound IDs + +```python +# Get all identifiers for a compound +# src_compound_id: compound ID, src_id: source database ID +all_ids = u.get_all_compound_ids("CHEMBL278315", src_id=1) # 1 = ChEMBL + +for mapping in all_ids: + src_name = mapping['src_name'] + src_compound_id = mapping['src_compound_id'] + print(f"{src_name}: {src_compound_id}") +``` + +### Specific Database Conversion + +```python +# Convert between specific databases +# from_src_id=6 (KEGG), to_src_id=1 (ChEMBL) +result = u.get_src_compound_ids("C11222", from_src_id=6, to_src_id=1) +print(result) +``` + +### Common Compound Mappings + +#### KEGG → ChEMBL + +```python +u = UniChem() +chembl_id = u.get_compound_id_from_kegg("C00031") # D-Glucose +print(f"ChEMBL: {chembl_id}") +``` + +#### ChEMBL → PubChem + +```python +result = u.get_src_compound_ids("CHEMBL278315", from_src_id=1, to_src_id=22) +if result: + pubchem_id = result[0]['src_compound_id'] + print(f"PubChem: {pubchem_id}") +``` + +#### ChEBI → DrugBank + +```python +result = u.get_src_compound_ids("5292", from_src_id=7, to_src_id=2) +if result: + drugbank_id = result[0]['src_compound_id'] + print(f"DrugBank: {drugbank_id}") +``` + +--- + +## KEGG Identifier Conversions + +KEGG entries contain cross-references that can be extracted by parsing. 
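As a lighter-weight alternative to scanning raw entry text line by line (shown in the next subsection), the `parse()` helper returns a dictionary whose `DBLINKS` entry usually carries these cross-references directly. Treat the sketch below as an assumption about the parsed layout: the exact keys vary by entry type, so inspect `entry.keys()` and fall back to manual parsing if a field is missing.

```python
from bioservices import KEGG

k = KEGG()

# Parse a gene entry into a dictionary; DBLINKS typically holds cross-references
entry = k.parse(k.get("hsa:7535"))
dblinks = entry.get("DBLINKS", {})  # layout may vary, e.g. {'UniProt': 'P43403', ...}

uniprot_id = dblinks.get("UniProt")
print(f"UniProt: {uniprot_id}")
```
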
+ +### Extract Database Links from KEGG Entry + +```python +from bioservices import KEGG + +k = KEGG() + +# Get compound entry +entry = k.get("cpd:C11222") + +# Parse for specific database +chebi_id = None +uniprot_ids = [] + +for line in entry.split("\n"): + if "ChEBI:" in line: + # Extract ChEBI ID + parts = line.split("ChEBI:") + if len(parts) > 1: + chebi_id = parts[1].strip().split()[0] + +# For genes/proteins +gene_entry = k.get("hsa:7535") +for line in gene_entry.split("\n"): + if line.startswith(" "): # Database links section + if "UniProt:" in line: + parts = line.split("UniProt:") + if len(parts) > 1: + uniprot_id = parts[1].strip() + uniprot_ids.append(uniprot_id) +``` + +### KEGG Gene ID Components + +KEGG gene IDs have format `organism:gene_id`: + +```python +kegg_id = "hsa:7535" +organism, gene_id = kegg_id.split(":") + +print(f"Organism: {organism}") # hsa (human) +print(f"Gene ID: {gene_id}") # 7535 +``` + +### KEGG Pathway to Genes + +```python +k = KEGG() + +# Get pathway entry +pathway = k.get("path:hsa04660") + +# Parse for gene list +genes = [] +in_gene_section = False + +for line in pathway.split("\n"): + if line.startswith("GENE"): + in_gene_section = True + + if in_gene_section: + if line.startswith(" " * 12): # Gene line + parts = line.strip().split() + if parts: + gene_id = parts[0] + genes.append(f"hsa:{gene_id}") + elif not line.startswith(" "): + break + +print(f"Found {len(genes)} genes") +``` + +--- + +## Common Mapping Patterns + +### Pattern 1: Gene Symbol → Multiple Database IDs + +```python +from bioservices import UniProt + +def gene_symbol_to_ids(gene_symbol, organism="9606"): + """Convert gene symbol to multiple database IDs.""" + u = UniProt() + + # Search for gene + query = f"gene:{gene_symbol} AND organism:{organism}" + result = u.search(query, frmt="tab", columns="id") + + lines = result.strip().split("\n") + if len(lines) < 2: + return None + + uniprot_id = lines[1].split("\t")[0] + + # Map to multiple databases + ids = { + 'uniprot': uniprot_id, + 'kegg': u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id), + 'ensembl': u.mapping(fr="UniProtKB_AC-ID", to="Ensembl", query=uniprot_id), + 'refseq': u.mapping(fr="UniProtKB_AC-ID", to="RefSeq_Protein", query=uniprot_id), + 'pdb': u.mapping(fr="UniProtKB_AC-ID", to="PDB", query=uniprot_id) + } + + return ids + +# Usage +ids = gene_symbol_to_ids("ZAP70") +print(ids) +``` + +### Pattern 2: Compound Name → All Database IDs + +```python +from bioservices import KEGG, UniChem, ChEBI + +def compound_name_to_ids(compound_name): + """Search compound and get all database IDs.""" + k = KEGG() + + # Search KEGG + results = k.find("compound", compound_name) + if not results: + return None + + # Extract KEGG ID + kegg_id = results.strip().split("\n")[0].split("\t")[0].replace("cpd:", "") + + # Get KEGG entry for ChEBI + entry = k.get(f"cpd:{kegg_id}") + chebi_id = None + for line in entry.split("\n"): + if "ChEBI:" in line: + parts = line.split("ChEBI:") + if len(parts) > 1: + chebi_id = parts[1].strip().split()[0] + break + + # Get ChEMBL from UniChem + u = UniChem() + try: + chembl_id = u.get_compound_id_from_kegg(kegg_id) + except: + chembl_id = None + + return { + 'kegg': kegg_id, + 'chebi': chebi_id, + 'chembl': chembl_id + } + +# Usage +ids = compound_name_to_ids("Geldanamycin") +print(ids) +``` + +### Pattern 3: Batch ID Conversion with Error Handling + +```python +from bioservices import UniProt + +def safe_batch_mapping(ids, from_db, to_db, chunk_size=100): + """Safely map IDs with error handling 
and chunking.""" + u = UniProt() + all_results = {} + + for i in range(0, len(ids), chunk_size): + chunk = ids[i:i+chunk_size] + query = ",".join(chunk) + + try: + results = u.mapping(fr=from_db, to=to_db, query=query) + all_results.update(results) + print(f"✓ Processed {min(i+chunk_size, len(ids))}/{len(ids)}") + + except Exception as e: + print(f"✗ Error at chunk {i}: {e}") + + # Try individual IDs in failed chunk + for single_id in chunk: + try: + result = u.mapping(fr=from_db, to=to_db, query=single_id) + all_results.update(result) + except: + all_results[single_id] = None + + return all_results + +# Usage +uniprot_ids = ["P43403", "P04637", "P53779", "INVALID123"] +mapping = safe_batch_mapping(uniprot_ids, "UniProtKB_AC-ID", "KEGG") +``` + +### Pattern 4: Multi-Hop Mapping + +Sometimes you need to map through intermediate databases: + +```python +from bioservices import UniProt + +def multi_hop_mapping(gene_symbol, organism="9606"): + """Gene symbol → UniProt → KEGG → Pathways.""" + u = UniProt() + k = KEGG() + + # Step 1: Gene symbol → UniProt + query = f"gene:{gene_symbol} AND organism:{organism}" + result = u.search(query, frmt="tab", columns="id") + + lines = result.strip().split("\n") + if len(lines) < 2: + return None + + uniprot_id = lines[1].split("\t")[0] + + # Step 2: UniProt → KEGG + kegg_mapping = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id) + if not kegg_mapping or uniprot_id not in kegg_mapping: + return None + + kegg_id = kegg_mapping[uniprot_id][0] + + # Step 3: KEGG → Pathways + organism_code, gene_id = kegg_id.split(":") + pathways = k.get_pathway_by_gene(gene_id, organism_code) + + return { + 'gene': gene_symbol, + 'uniprot': uniprot_id, + 'kegg': kegg_id, + 'pathways': pathways + } + +# Usage +result = multi_hop_mapping("TP53") +print(result) +``` + +--- + +## Troubleshooting + +### Issue 1: No Mapping Found + +**Symptom:** Mapping returns empty or None + +**Solutions:** +1. Verify source ID exists in source database +2. Check database code spelling +3. Try reverse mapping +4. Some IDs may not have mappings in all databases + +```python +result = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query="P43403") + +if not result or 'P43403' not in result: + print("No mapping found. Try:") + print("1. Verify ID exists: u.search('P43403')") + print("2. 
Check if protein has KEGG annotation") +``` + +### Issue 2: Too Many IDs in Batch + +**Symptom:** Batch mapping fails or times out + +**Solution:** Split into smaller chunks + +```python +def chunked_mapping(ids, from_db, to_db, chunk_size=50): + all_results = {} + + for i in range(0, len(ids), chunk_size): + chunk = ids[i:i+chunk_size] + result = u.mapping(fr=from_db, to=to_db, query=",".join(chunk)) + all_results.update(result) + + return all_results +``` + +### Issue 3: Multiple Target IDs + +**Symptom:** One source ID maps to multiple target IDs + +**Solution:** Handle as list + +```python +result = u.mapping(fr="UniProtKB_AC-ID", to="PDB", query="P04637") +# Result: {'P04637': ['1A1U', '1AIE', '1C26', ...]} + +pdb_ids = result['P04637'] +print(f"Found {len(pdb_ids)} PDB structures") + +for pdb_id in pdb_ids: + print(f" {pdb_id}") +``` + +### Issue 4: Organism Ambiguity + +**Symptom:** Gene symbol maps to multiple organisms + +**Solution:** Always specify organism in searches + +```python +# Bad: Ambiguous +result = u.search("gene:TP53") # Many organisms have TP53 + +# Good: Specific +result = u.search("gene:TP53 AND organism:9606") # Human only +``` + +### Issue 5: Deprecated IDs + +**Symptom:** Old database IDs don't map + +**Solution:** Update to current IDs first + +```python +# Check if ID is current +entry = u.retrieve("P43403", frmt="txt") + +# Look for secondary accessions +for line in entry.split("\n"): + if line.startswith("AC"): + print(line) # Shows primary and secondary accessions +``` + +--- + +## Best Practices + +1. **Always validate inputs** before batch processing +2. **Handle None/empty results** gracefully +3. **Use chunking** for large ID lists (50-100 per chunk) +4. **Cache results** for repeated queries +5. **Specify organism** when possible to avoid ambiguity +6. **Log failures** in batch processing for later retry +7. **Add delays** between large batches to respect API limits + +```python +import time + +def polite_batch_mapping(ids, from_db, to_db): + """Batch mapping with rate limiting.""" + results = {} + + for i in range(0, len(ids), 50): + chunk = ids[i:i+50] + result = u.mapping(fr=from_db, to=to_db, query=",".join(chunk)) + results.update(result) + + time.sleep(0.5) # Be nice to the API + + return results +``` + +--- + +For complete working examples, see: +- `scripts/batch_id_converter.py`: Command-line batch conversion tool +- `workflow_patterns.md`: Integration into larger workflows diff --git a/scientific-packages/bioservices/references/services_reference.md b/scientific-packages/bioservices/references/services_reference.md new file mode 100644 index 0000000..26baf71 --- /dev/null +++ b/scientific-packages/bioservices/references/services_reference.md @@ -0,0 +1,634 @@ +# BioServices: Complete Services Reference + +This document provides a comprehensive reference for all major services available in BioServices, including key methods, parameters, and use cases. + +## Protein & Gene Resources + +### UniProt + +Protein sequence and functional information database. 
+ +**Initialization:** +```python +from bioservices import UniProt +u = UniProt(verbose=False) +``` + +**Key Methods:** + +- `search(query, frmt="tab", columns=None, limit=None, sort=None, compress=False, include=False, **kwargs)` + - Search UniProt with flexible query syntax + - `frmt`: "tab", "fasta", "xml", "rdf", "gff", "txt" + - `columns`: Comma-separated list (e.g., "id,genes,organism,length") + - Returns: String in requested format + +- `retrieve(uniprot_id, frmt="txt")` + - Retrieve specific UniProt entry + - `frmt`: "txt", "fasta", "xml", "rdf", "gff" + - Returns: Entry data in requested format + +- `mapping(fr="UniProtKB_AC-ID", to="KEGG", query="P43403")` + - Convert identifiers between databases + - `fr`/`to`: Database identifiers (see identifier_mapping.md) + - `query`: Single ID or comma-separated list + - Returns: Dictionary mapping input to output IDs + +- `searchUniProtId(pattern, columns="entry name,length,organism", limit=100)` + - Convenience method for ID-based searches + - Returns: Tab-separated values + +**Common columns:** id, entry name, genes, organism, protein names, length, sequence, go-id, ec, pathway, interactor + +**Use cases:** +- Protein sequence retrieval for BLAST +- Functional annotation lookup +- Cross-database identifier mapping +- Batch protein information retrieval + +--- + +### KEGG (Kyoto Encyclopedia of Genes and Genomes) + +Metabolic pathways, genes, and organisms database. + +**Initialization:** +```python +from bioservices import KEGG +k = KEGG() +k.organism = "hsa" # Set default organism +``` + +**Key Methods:** + +- `list(database)` + - List entries in KEGG database + - `database`: "organism", "pathway", "module", "disease", "drug", "compound" + - Returns: Multi-line string with entries + +- `find(database, query)` + - Search database by keywords + - Returns: List of matching entries with IDs + +- `get(entry_id)` + - Retrieve entry by ID + - Supports genes, pathways, compounds, etc. + - Returns: Raw entry text + +- `parse(data)` + - Parse KEGG entry into dictionary + - Returns: Dict with structured data + +- `lookfor_organism(name)` + - Search organisms by name pattern + - Returns: List of matching organism codes + +- `lookfor_pathway(name)` + - Search pathways by name + - Returns: List of pathway IDs + +- `get_pathway_by_gene(gene_id, organism)` + - Find pathways containing gene + - Returns: List of pathway IDs + +- `parse_kgml_pathway(pathway_id)` + - Parse pathway KGML for interactions + - Returns: Dict with "entries" and "relations" + +- `pathway2sif(pathway_id)` + - Extract Simple Interaction Format data + - Filters for activation/inhibition + - Returns: List of interaction tuples + +**Organism codes:** +- hsa: Homo sapiens +- mmu: Mus musculus +- dme: Drosophila melanogaster +- sce: Saccharomyces cerevisiae +- eco: Escherichia coli + +**Use cases:** +- Pathway analysis and visualization +- Gene function annotation +- Metabolic network reconstruction +- Protein-protein interaction extraction + +--- + +### HGNC (Human Gene Nomenclature Committee) + +Official human gene naming authority. + +**Initialization:** +```python +from bioservices import HGNC +h = HGNC() +``` + +**Key Methods:** +- `search(query)`: Search gene symbols/names +- `fetch(format, query)`: Retrieve gene information + +**Use cases:** +- Standardizing human gene names +- Looking up official gene symbols + +--- + +### MyGeneInfo + +Gene annotation and query service. 
+ +**Initialization:** +```python +from bioservices import MyGeneInfo +m = MyGeneInfo() +``` + +**Key Methods:** +- `querymany(ids, scopes, fields, species)`: Batch gene queries +- `getgene(geneid)`: Get gene annotation + +**Use cases:** +- Batch gene annotation retrieval +- Gene ID conversion + +--- + +## Chemical Compound Resources + +### ChEBI (Chemical Entities of Biological Interest) + +Dictionary of molecular entities. + +**Initialization:** +```python +from bioservices import ChEBI +c = ChEBI() +``` + +**Key Methods:** +- `getCompleteEntity(chebi_id)`: Full compound information +- `getLiteEntity(chebi_id)`: Basic information +- `getCompleteEntityByList(chebi_ids)`: Batch retrieval + +**Use cases:** +- Small molecule information +- Chemical structure data +- Compound property lookup + +--- + +### ChEMBL + +Bioactive drug-like compound database. + +**Initialization:** +```python +from bioservices import ChEMBL +c = ChEMBL() +``` + +**Key Methods:** +- `get_compound_by_chemblId(chembl_id)`: Compound details +- `get_target_by_chemblId(chembl_id)`: Target information +- `get_assays()`: Bioassay data + +**Use cases:** +- Drug discovery data +- Bioactivity information +- Target-compound relationships + +--- + +### UniChem + +Chemical identifier mapping service. + +**Initialization:** +```python +from bioservices import UniChem +u = UniChem() +``` + +**Key Methods:** +- `get_compound_id_from_kegg(kegg_id)`: KEGG → ChEMBL +- `get_all_compound_ids(src_compound_id, src_id)`: Get all IDs +- `get_src_compound_ids(src_compound_id, from_src_id, to_src_id)`: Convert IDs + +**Source IDs:** +- 1: ChEMBL +- 2: DrugBank +- 3: PDB +- 6: KEGG +- 7: ChEBI +- 22: PubChem + +**Use cases:** +- Cross-database compound ID mapping +- Linking chemical databases + +--- + +### PubChem + +Chemical compound database from NIH. + +**Initialization:** +```python +from bioservices import PubChem +p = PubChem() +``` + +**Key Methods:** +- `get_compounds(identifier, namespace)`: Retrieve compounds +- `get_properties(properties, identifier, namespace)`: Get properties + +**Use cases:** +- Chemical structure retrieval +- Compound property information + +--- + +## Sequence Analysis Tools + +### NCBIblast + +Sequence similarity searching. + +**Initialization:** +```python +from bioservices import NCBIblast +s = NCBIblast(verbose=False) +``` + +**Key Methods:** +- `run(program, sequence, stype, database, email, **params)` + - Submit BLAST job + - `program`: "blastp", "blastn", "blastx", "tblastn", "tblastx" + - `stype`: "protein" or "dna" + - `database`: "uniprotkb", "pdb", "refseq_protein", etc. + - `email`: Required by NCBI + - Returns: Job ID + +- `getStatus(jobid)` + - Check job status + - Returns: "RUNNING", "FINISHED", "ERROR" + +- `getResult(jobid, result_type)` + - Retrieve results + - `result_type`: "out" (default), "ids", "xml" + +**Important:** BLAST jobs are asynchronous. Always check status before retrieving results. + +**Use cases:** +- Protein homology searches +- Sequence similarity analysis +- Functional annotation by homology + +--- + +## Pathway & Interaction Resources + +### Reactome + +Pathway database. + +**Initialization:** +```python +from bioservices import Reactome +r = Reactome() +``` + +**Key Methods:** +- `get_pathway_by_id(pathway_id)`: Pathway details +- `search_pathway(query)`: Search pathways + +**Use cases:** +- Human pathway analysis +- Biological process annotation + +--- + +### PSICQUIC + +Protein interaction query service (federates 30+ databases). 
+ +**Initialization:** +```python +from bioservices import PSICQUIC +s = PSICQUIC() +``` + +**Key Methods:** +- `query(database, query_string)` + - Query specific interaction database + - Returns: PSI-MI TAB format + +- `activeDBs` + - Property listing available databases + - Returns: List of database names + +**Available databases:** MINT, IntAct, BioGRID, DIP, InnateDB, MatrixDB, MPIDB, UniProt, and 30+ more + +**Query syntax:** Supports AND, OR, species filters +- Example: "ZAP70 AND species:9606" + +**Use cases:** +- Protein-protein interaction discovery +- Network analysis +- Interactome mapping + +--- + +### IntactComplex + +Protein complex database. + +**Initialization:** +```python +from bioservices import IntactComplex +i = IntactComplex() +``` + +**Key Methods:** +- `search(query)`: Search complexes +- `details(complex_ac)`: Complex details + +**Use cases:** +- Protein complex composition +- Multi-protein assembly analysis + +--- + +### OmniPath + +Integrated signaling pathway database. + +**Initialization:** +```python +from bioservices import OmniPath +o = OmniPath() +``` + +**Key Methods:** +- `interactions(datasets, organisms)`: Get interactions +- `ptms(datasets, organisms)`: Post-translational modifications + +**Use cases:** +- Cell signaling analysis +- Regulatory network mapping + +--- + +## Gene Ontology + +### QuickGO + +Gene Ontology annotation service. + +**Initialization:** +```python +from bioservices import QuickGO +g = QuickGO() +``` + +**Key Methods:** +- `Term(go_id, frmt="obo")` + - Retrieve GO term information + - Returns: Term definition and metadata + +- `Annotation(protein=None, goid=None, format="tsv")` + - Get GO annotations + - Returns: Annotations in requested format + +**GO categories:** +- Biological Process (BP) +- Molecular Function (MF) +- Cellular Component (CC) + +**Use cases:** +- Functional annotation +- Enrichment analysis +- GO term lookup + +--- + +## Genomic Resources + +### BioMart + +Data mining tool for genomic data. + +**Initialization:** +```python +from bioservices import BioMart +b = BioMart() +``` + +**Key Methods:** +- `datasets(dataset)`: List available datasets +- `attributes(dataset)`: List attributes +- `query(query_xml)`: Execute BioMart query + +**Use cases:** +- Bulk genomic data retrieval +- Custom genome annotations +- SNP information + +--- + +### ArrayExpress + +Gene expression database. + +**Initialization:** +```python +from bioservices import ArrayExpress +a = ArrayExpress() +``` + +**Key Methods:** +- `queryExperiments(keywords)`: Search experiments +- `retrieveExperiment(accession)`: Get experiment data + +**Use cases:** +- Gene expression data +- Microarray analysis +- RNA-seq data retrieval + +--- + +### ENA (European Nucleotide Archive) + +Nucleotide sequence database. + +**Initialization:** +```python +from bioservices import ENA +e = ENA() +``` + +**Key Methods:** +- `search_data(query)`: Search sequences +- `retrieve_data(accession)`: Retrieve sequences + +**Use cases:** +- Nucleotide sequence retrieval +- Genome assembly access + +--- + +## Structural Biology + +### PDB (Protein Data Bank) + +3D protein structure database. + +**Initialization:** +```python +from bioservices import PDB +p = PDB() +``` + +**Key Methods:** +- `get_file(pdb_id, file_format)`: Download structure files +- `search(query)`: Search structures + +**File formats:** pdb, cif, xml + +**Use cases:** +- 3D structure retrieval +- Structure-based analysis +- PyMOL visualization + +--- + +### Pfam + +Protein family database. 
+ +**Initialization:** +```python +from bioservices import Pfam +p = Pfam() +``` + +**Key Methods:** +- `searchSequence(sequence)`: Find domains in sequence +- `getPfamEntry(pfam_id)`: Domain information + +**Use cases:** +- Protein domain identification +- Family classification +- Functional motif discovery + +--- + +## Specialized Resources + +### BioModels + +Systems biology model repository. + +**Initialization:** +```python +from bioservices import BioModels +b = BioModels() +``` + +**Key Methods:** +- `get_model_by_id(model_id)`: Retrieve SBML model + +**Use cases:** +- Systems biology modeling +- SBML model retrieval + +--- + +### COG (Clusters of Orthologous Genes) + +Orthologous gene classification. + +**Initialization:** +```python +from bioservices import COG +c = COG() +``` + +**Use cases:** +- Orthology analysis +- Functional classification + +--- + +### BiGG Models + +Metabolic network models. + +**Initialization:** +```python +from bioservices import BiGG +b = BiGG() +``` + +**Key Methods:** +- `list_models()`: Available models +- `get_model(model_id)`: Model details + +**Use cases:** +- Metabolic network analysis +- Flux balance analysis + +--- + +## General Patterns + +### Error Handling + +All services may throw exceptions. Wrap calls in try-except: + +```python +try: + result = service.method(params) + if result: + # Process result + pass +except Exception as e: + print(f"Error: {e}") +``` + +### Verbosity Control + +Most services support `verbose` parameter: +```python +service = Service(verbose=False) # Suppress HTTP logs +``` + +### Rate Limiting + +Services have timeouts and rate limits: +```python +service.TIMEOUT = 30 # Adjust timeout +service.DELAY = 1 # Delay between requests (if supported) +``` + +### Output Formats + +Common format parameters: +- `frmt`: "xml", "json", "tab", "txt", "fasta" +- `format`: Service-specific variants + +### Caching + +Some services cache results: +```python +service.CACHE = True # Enable caching +service.clear_cache() # Clear cache +``` + +## Additional Resources + +For detailed API documentation: +- Official docs: https://bioservices.readthedocs.io/ +- Individual service docs linked from main page +- Source code: https://github.com/cokelaer/bioservices diff --git a/scientific-packages/bioservices/references/workflow_patterns.md b/scientific-packages/bioservices/references/workflow_patterns.md new file mode 100644 index 0000000..0e79ef8 --- /dev/null +++ b/scientific-packages/bioservices/references/workflow_patterns.md @@ -0,0 +1,811 @@ +# BioServices: Common Workflow Patterns + +This document describes detailed multi-step workflows for common bioinformatics tasks using BioServices. + +## Table of Contents + +1. [Complete Protein Analysis Pipeline](#complete-protein-analysis-pipeline) +2. [Pathway Discovery and Network Analysis](#pathway-discovery-and-network-analysis) +3. [Compound Multi-Database Search](#compound-multi-database-search) +4. [Batch Identifier Conversion](#batch-identifier-conversion) +5. [Gene Functional Annotation](#gene-functional-annotation) +6. [Protein Interaction Network Construction](#protein-interaction-network-construction) +7. [Multi-Organism Comparative Analysis](#multi-organism-comparative-analysis) + +--- + +## Complete Protein Analysis Pipeline + +**Goal:** Given a protein name, retrieve sequence, find homologs, identify pathways, and discover interactions. 
+ +**Example:** Analyzing human ZAP70 protein + +### Step 1: UniProt Search and Identifier Retrieval + +```python +from bioservices import UniProt + +u = UniProt(verbose=False) + +# Search for protein by name +query = "ZAP70_HUMAN" +results = u.search(query, frmt="tab", columns="id,genes,organism,length") + +# Parse results +lines = results.strip().split("\n") +if len(lines) > 1: + header = lines[0] + data = lines[1].split("\t") + uniprot_id = data[0] # e.g., P43403 + gene_names = data[1] # e.g., ZAP70 + +print(f"UniProt ID: {uniprot_id}") +print(f"Gene names: {gene_names}") +``` + +**Output:** +- UniProt accession: P43403 +- Gene name: ZAP70 + +### Step 2: Sequence Retrieval + +```python +# Retrieve FASTA sequence +sequence = u.retrieve(uniprot_id, frmt="fasta") +print(sequence) + +# Extract just the sequence string (remove header) +seq_lines = sequence.split("\n") +sequence_only = "".join(seq_lines[1:]) # Skip FASTA header +``` + +**Output:** Complete protein sequence in FASTA format + +### Step 3: BLAST Similarity Search + +```python +from bioservices import NCBIblast +import time + +s = NCBIblast(verbose=False) + +# Submit BLAST job +jobid = s.run( + program="blastp", + sequence=sequence_only, + stype="protein", + database="uniprotkb", + email="your.email@example.com" +) + +print(f"BLAST Job ID: {jobid}") + +# Wait for completion +while True: + status = s.getStatus(jobid) + print(f"Status: {status}") + if status == "FINISHED": + break + elif status == "ERROR": + print("BLAST job failed") + break + time.sleep(5) + +# Retrieve results +if status == "FINISHED": + blast_results = s.getResult(jobid, "out") + print(blast_results[:500]) # Print first 500 characters +``` + +**Output:** BLAST alignment results showing similar proteins + +### Step 4: KEGG Pathway Discovery + +```python +from bioservices import KEGG + +k = KEGG() + +# Get KEGG gene ID from UniProt mapping +kegg_mapping = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id) +print(f"KEGG mapping: {kegg_mapping}") + +# Extract KEGG gene ID (e.g., hsa:7535) +if kegg_mapping: + kegg_gene_id = kegg_mapping[uniprot_id][0] if uniprot_id in kegg_mapping else None + + if kegg_gene_id: + # Find pathways containing this gene + organism = kegg_gene_id.split(":")[0] # e.g., "hsa" + gene_id = kegg_gene_id.split(":")[1] # e.g., "7535" + + pathways = k.get_pathway_by_gene(gene_id, organism) + print(f"Found {len(pathways)} pathways:") + + # Get pathway names + for pathway_id in pathways: + pathway_info = k.get(pathway_id) + # Parse NAME line + for line in pathway_info.split("\n"): + if line.startswith("NAME"): + pathway_name = line.replace("NAME", "").strip() + print(f" {pathway_id}: {pathway_name}") + break +``` + +**Output:** +- path:hsa04064 - NF-kappa B signaling pathway +- path:hsa04650 - Natural killer cell mediated cytotoxicity +- path:hsa04660 - T cell receptor signaling pathway +- path:hsa04662 - B cell receptor signaling pathway + +### Step 5: Protein-Protein Interactions + +```python +from bioservices import PSICQUIC + +p = PSICQUIC() + +# Query MINT database for human (taxid:9606) interactions +query = f"ZAP70 AND species:9606" +interactions = p.query("mint", query) + +# Parse PSI-MI TAB format results +if interactions: + interaction_lines = interactions.strip().split("\n") + print(f"Found {len(interaction_lines)} interactions") + + # Print first few interactions + for line in interaction_lines[:5]: + fields = line.split("\t") + protein_a = fields[0] + protein_b = fields[1] + interaction_type = fields[11] + print(f" 
{protein_a} - {protein_b}: {interaction_type}") +``` + +**Output:** List of proteins that interact with ZAP70 + +### Step 6: Gene Ontology Annotation + +```python +from bioservices import QuickGO + +g = QuickGO() + +# Get GO annotations for protein +annotations = g.Annotation(protein=uniprot_id, format="tsv") + +if annotations: + # Parse TSV results + lines = annotations.strip().split("\n") + print(f"Found {len(lines)-1} GO annotations") + + # Display first few annotations + for line in lines[1:6]: # Skip header + fields = line.split("\t") + go_id = fields[6] + go_term = fields[7] + go_aspect = fields[8] + print(f" {go_id}: {go_term} [{go_aspect}]") +``` + +**Output:** GO terms annotating ZAP70 function, process, and location + +### Complete Pipeline Summary + +**Inputs:** Protein name (e.g., "ZAP70_HUMAN") + +**Outputs:** +1. UniProt accession and gene name +2. Protein sequence (FASTA) +3. Similar proteins (BLAST results) +4. Biological pathways (KEGG) +5. Interaction partners (PSICQUIC) +6. Functional annotations (GO terms) + +**Script:** `scripts/protein_analysis_workflow.py` automates this entire pipeline. + +--- + +## Pathway Discovery and Network Analysis + +**Goal:** Analyze all pathways for an organism and extract protein interaction networks. + +**Example:** Human (hsa) pathway analysis + +### Step 1: Get All Pathways for Organism + +```python +from bioservices import KEGG + +k = KEGG() +k.organism = "hsa" + +# Get all pathway IDs +pathway_ids = k.pathwayIds +print(f"Found {len(pathway_ids)} pathways for {k.organism}") + +# Display first few +for pid in pathway_ids[:10]: + print(f" {pid}") +``` + +**Output:** List of ~300 human pathways + +### Step 2: Parse Pathway for Interactions + +```python +# Analyze specific pathway +pathway_id = "hsa04660" # T cell receptor signaling + +# Get KGML data +kgml_data = k.parse_kgml_pathway(pathway_id) + +# Extract entries (genes/proteins) +entries = kgml_data['entries'] +print(f"Pathway contains {len(entries)} entries") + +# Extract relations (interactions) +relations = kgml_data['relations'] +print(f"Found {len(relations)} relations") + +# Analyze relation types +relation_types = {} +for rel in relations: + rel_type = rel.get('name', 'unknown') + relation_types[rel_type] = relation_types.get(rel_type, 0) + 1 + +print("\nRelation type distribution:") +for rel_type, count in sorted(relation_types.items()): + print(f" {rel_type}: {count}") +``` + +**Output:** +- Entry count (genes/proteins in pathway) +- Relation count (interactions) +- Distribution of interaction types (activation, inhibition, binding, etc.) 
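+
+Each element of `relations` is a plain dictionary, and the following steps rely only on a handful of its keys (`entry1`, `entry2`, `link`, `name`). The snippet below illustrates that shape on a hypothetical relation; the exact key set can vary between bioservices versions, so treat it as a sketch rather than a schema.
+
+```python
+# Hypothetical relation as found in parse_kgml_pathway()['relations'] (illustrative values)
+example_relation = {
+    "entry1": "60",        # KGML entry id of the source node
+    "entry2": "63",        # KGML entry id of the target node
+    "link": "PPrel",       # relation class: PPrel = protein-protein relation
+    "name": "activation",  # interaction subtype
+}
+
+# The filter used in Step 3 reduces to a single key test:
+is_ppi = example_relation.get("link") == "PPrel"
+print(is_ppi)  # True
+```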
+ +### Step 3: Extract Protein-Protein Interactions + +```python +# Filter for specific interaction types +pprel_interactions = [ + rel for rel in relations + if rel.get('link') == 'PPrel' # Protein-protein relation +] + +print(f"Found {len(pprel_interactions)} protein-protein interactions") + +# Extract interaction details +for rel in pprel_interactions[:10]: + entry1 = rel['entry1'] + entry2 = rel['entry2'] + interaction_type = rel.get('name', 'unknown') + + print(f" {entry1} -> {entry2}: {interaction_type}") +``` + +**Output:** Directed protein-protein interactions with types + +### Step 4: Convert to Network Format (SIF) + +```python +# Get Simple Interaction Format (filters for key interactions) +sif_data = k.pathway2sif(pathway_id) + +# SIF format: source, interaction_type, target +print("\nSimple Interaction Format:") +for interaction in sif_data[:10]: + print(f" {interaction}") +``` + +**Output:** Network edges suitable for Cytoscape or NetworkX + +### Step 5: Batch Analysis of All Pathways + +```python +import pandas as pd + +# Analyze all pathways (this takes time!) +all_results = [] + +for pathway_id in pathway_ids[:50]: # Limit for example + try: + kgml = k.parse_kgml_pathway(pathway_id) + + result = { + 'pathway_id': pathway_id, + 'num_entries': len(kgml.get('entries', [])), + 'num_relations': len(kgml.get('relations', [])) + } + + all_results.append(result) + + except Exception as e: + print(f"Error parsing {pathway_id}: {e}") + +# Create DataFrame +df = pd.DataFrame(all_results) +print(df.describe()) + +# Find largest pathways +print("\nLargest pathways:") +print(df.nlargest(10, 'num_entries')[['pathway_id', 'num_entries', 'num_relations']]) +``` + +**Output:** Statistical summary of pathway sizes and interaction densities + +**Script:** `scripts/pathway_analysis.py` implements this workflow with export options. + +--- + +## Compound Multi-Database Search + +**Goal:** Search for compound by name and retrieve identifiers across KEGG, ChEBI, and ChEMBL. 
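+
+The essential identifier flow fits in a few lines. This sketch condenses the steps below into a single pass (name → KEGG compound ID → ChEMBL ID via UniChem) and assumes the first KEGG hit is the compound of interest; the detailed, error-handled version follows.
+
+```python
+from bioservices import KEGG, UniChem
+
+k = KEGG()
+u = UniChem()
+
+# Name -> KEGG compound ID (take the first hit and strip the "cpd:" prefix)
+hits = k.find("compound", "Geldanamycin")
+kegg_id = hits.strip().split("\n")[0].split("\t")[0].replace("cpd:", "")
+
+# KEGG ID -> ChEMBL ID through UniChem's cross-reference tables
+chembl_id = u.get_compound_id_from_kegg(kegg_id)
+print(kegg_id, chembl_id)   # e.g. C11222 CHEMBL278315
+```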
+ +**Example:** Geldanamycin (antibiotic) + +### Step 1: Search KEGG Compound Database + +```python +from bioservices import KEGG + +k = KEGG() + +# Search by compound name +compound_name = "Geldanamycin" +results = k.find("compound", compound_name) + +print(f"KEGG search results for '{compound_name}':") +print(results) + +# Extract compound ID +if results: + lines = results.strip().split("\n") + if lines: + kegg_id = lines[0].split("\t")[0] # e.g., cpd:C11222 + kegg_id_clean = kegg_id.replace("cpd:", "") # C11222 + print(f"\nKEGG Compound ID: {kegg_id_clean}") +``` + +**Output:** KEGG ID (e.g., C11222) + +### Step 2: Get KEGG Entry with Database Links + +```python +# Retrieve compound entry +compound_entry = k.get(kegg_id) + +# Parse entry for database links +chebi_id = None +for line in compound_entry.split("\n"): + if "ChEBI:" in line: + # Extract ChEBI ID + parts = line.split("ChEBI:") + if len(parts) > 1: + chebi_id = parts[1].strip().split()[0] + print(f"ChEBI ID: {chebi_id}") + break + +# Display entry snippet +print("\nKEGG Entry (first 500 chars):") +print(compound_entry[:500]) +``` + +**Output:** ChEBI ID (e.g., 5292) and compound information + +### Step 3: Cross-Reference to ChEMBL via UniChem + +```python +from bioservices import UniChem + +u = UniChem() + +# Convert KEGG → ChEMBL +try: + chembl_id = u.get_compound_id_from_kegg(kegg_id_clean) + print(f"ChEMBL ID: {chembl_id}") +except Exception as e: + print(f"UniChem lookup failed: {e}") + chembl_id = None +``` + +**Output:** ChEMBL ID (e.g., CHEMBL278315) + +### Step 4: Retrieve Detailed Information + +```python +# Get ChEBI information +if chebi_id: + from bioservices import ChEBI + c = ChEBI() + + try: + chebi_entity = c.getCompleteEntity(f"CHEBI:{chebi_id}") + print(f"\nChEBI Formula: {chebi_entity.Formulae}") + print(f"ChEBI Name: {chebi_entity.chebiAsciiName}") + except Exception as e: + print(f"ChEBI lookup failed: {e}") + +# Get ChEMBL information +if chembl_id: + from bioservices import ChEMBL + chembl = ChEMBL() + + try: + chembl_compound = chembl.get_compound_by_chemblId(chembl_id) + print(f"\nChEMBL Molecular Weight: {chembl_compound['molecule_properties']['full_mwt']}") + print(f"ChEMBL SMILES: {chembl_compound['molecule_structures']['canonical_smiles']}") + except Exception as e: + print(f"ChEMBL lookup failed: {e}") +``` + +**Output:** Chemical properties from multiple databases + +### Complete Compound Workflow Summary + +**Input:** Compound name (e.g., "Geldanamycin") + +**Output:** +- KEGG ID: C11222 +- ChEBI ID: 5292 +- ChEMBL ID: CHEMBL278315 +- Chemical formula +- Molecular weight +- SMILES structure + +**Script:** `scripts/compound_cross_reference.py` automates this workflow. + +--- + +## Batch Identifier Conversion + +**Goal:** Convert multiple identifiers between databases efficiently. 
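+
+A note on the return shape used throughout this section: `UniProt.mapping()` returns a plain dictionary keyed by the source identifier, with a list of target identifiers as each value (one source ID can map to several targets, and unmapped IDs are typically absent). The dictionary below uses illustrative values only.
+
+```python
+# Illustrative shape of the dict returned by UniProt.mapping()
+mapping = {
+    "P43403": ["hsa:7535"],   # ZAP70 -> its KEGG gene ID
+    "P04637": ["hsa:7157"],   # TP53  -> its KEGG gene ID
+}
+
+# Downstream code therefore checks membership before indexing:
+for source_id in ["P43403", "P04637", "Q9Y6K9"]:
+    targets = mapping.get(source_id, [])
+    print(source_id, "->", ";".join(targets) if targets else "no mapping")
+```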
+ +### Batch UniProt → KEGG Mapping + +```python +from bioservices import UniProt + +u = UniProt() + +# List of UniProt IDs +uniprot_ids = ["P43403", "P04637", "P53779", "Q9Y6K9"] + +# Batch mapping (comma-separated) +query_string = ",".join(uniprot_ids) +results = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=query_string) + +print("UniProt → KEGG mapping:") +for uniprot_id, kegg_ids in results.items(): + print(f" {uniprot_id} → {kegg_ids}") +``` + +**Output:** Dictionary mapping each UniProt ID to KEGG gene IDs + +### Batch File Processing + +```python +import csv + +# Read identifiers from file +def read_ids_from_file(filename): + with open(filename, 'r') as f: + ids = [line.strip() for line in f if line.strip()] + return ids + +# Process in chunks (API limits) +def batch_convert(ids, from_db, to_db, chunk_size=100): + u = UniProt() + all_results = {} + + for i in range(0, len(ids), chunk_size): + chunk = ids[i:i+chunk_size] + query = ",".join(chunk) + + try: + results = u.mapping(fr=from_db, to=to_db, query=query) + all_results.update(results) + print(f"Processed {min(i+chunk_size, len(ids))}/{len(ids)}") + except Exception as e: + print(f"Error processing chunk {i}: {e}") + + return all_results + +# Write results to CSV +def write_mapping_to_csv(mapping, output_file): + with open(output_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Source_ID', 'Target_IDs']) + + for source_id, target_ids in mapping.items(): + target_str = ";".join(target_ids) if target_ids else "No mapping" + writer.writerow([source_id, target_str]) + +# Example usage +input_ids = read_ids_from_file("uniprot_ids.txt") +mapping = batch_convert(input_ids, "UniProtKB_AC-ID", "KEGG", chunk_size=50) +write_mapping_to_csv(mapping, "uniprot_to_kegg_mapping.csv") +``` + +**Script:** `scripts/batch_id_converter.py` provides command-line batch conversion. + +--- + +## Gene Functional Annotation + +**Goal:** Retrieve comprehensive functional information for a gene. + +### Workflow + +```python +from bioservices import UniProt, KEGG, QuickGO + +# Gene of interest +gene_symbol = "TP53" + +# 1. Find UniProt entry +u = UniProt() +search_results = u.search(f"gene:{gene_symbol} AND organism:9606", + frmt="tab", + columns="id,genes,protein names") + +# Extract UniProt ID +lines = search_results.strip().split("\n") +if len(lines) > 1: + uniprot_id = lines[1].split("\t")[0] + protein_name = lines[1].split("\t")[2] + print(f"Protein: {protein_name}") + print(f"UniProt ID: {uniprot_id}") + +# 2. Get KEGG pathways +kegg_mapping = u.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id) +if uniprot_id in kegg_mapping: + kegg_id = kegg_mapping[uniprot_id][0] + + k = KEGG() + organism, gene_id = kegg_id.split(":") + pathways = k.get_pathway_by_gene(gene_id, organism) + + print(f"\nPathways ({len(pathways)}):") + for pathway_id in pathways[:5]: + print(f" {pathway_id}") + +# 3. Get GO annotations +g = QuickGO() +go_annotations = g.Annotation(protein=uniprot_id, format="tsv") + +if go_annotations: + lines = go_annotations.strip().split("\n") + print(f"\nGO Annotations ({len(lines)-1} total):") + + # Group by aspect + aspects = {"P": [], "F": [], "C": []} + for line in lines[1:]: + fields = line.split("\t") + go_aspect = fields[8] # P, F, or C + go_term = fields[7] + aspects[go_aspect].append(go_term) + + print(f" Biological Process: {len(aspects['P'])} terms") + print(f" Molecular Function: {len(aspects['F'])} terms") + print(f" Cellular Component: {len(aspects['C'])} terms") + +# 4. 
Get protein sequence features +full_entry = u.retrieve(uniprot_id, frmt="txt") +print("\nProtein Features:") +for line in full_entry.split("\n"): + if line.startswith("FT DOMAIN"): + print(f" {line}") +``` + +**Output:** Comprehensive annotation including name, pathways, GO terms, and features. + +--- + +## Protein Interaction Network Construction + +**Goal:** Build a protein-protein interaction network for a set of proteins. + +### Workflow + +```python +from bioservices import PSICQUIC +import networkx as nx + +# Proteins of interest +proteins = ["ZAP70", "LCK", "LAT", "SLP76", "PLCg1"] + +# Initialize PSICQUIC +p = PSICQUIC() + +# Build network +G = nx.Graph() + +for protein in proteins: + # Query for human interactions + query = f"{protein} AND species:9606" + + try: + results = p.query("intact", query) + + if results: + lines = results.strip().split("\n") + + for line in lines: + fields = line.split("\t") + # Extract protein names (simplified) + protein_a = fields[4].split(":")[1] if ":" in fields[4] else fields[4] + protein_b = fields[5].split(":")[1] if ":" in fields[5] else fields[5] + + # Add edge + G.add_edge(protein_a, protein_b) + + except Exception as e: + print(f"Error querying {protein}: {e}") + +print(f"Network: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges") + +# Analyze network +print("\nNode degrees:") +for node in proteins: + if node in G: + print(f" {node}: {G.degree(node)} interactions") + +# Export for visualization +nx.write_gml(G, "protein_network.gml") +print("\nNetwork exported to protein_network.gml") +``` + +**Output:** NetworkX graph exported in GML format for Cytoscape visualization. + +--- + +## Multi-Organism Comparative Analysis + +**Goal:** Compare pathway or gene presence across multiple organisms. + +### Workflow + +```python +from bioservices import KEGG + +k = KEGG() + +# Organisms to compare +organisms = ["hsa", "mmu", "dme", "sce"] # Human, mouse, fly, yeast +organism_names = { + "hsa": "Human", + "mmu": "Mouse", + "dme": "Fly", + "sce": "Yeast" +} + +# Pathway of interest +pathway_name = "cell cycle" + +print(f"Searching for '{pathway_name}' pathway across organisms:\n") + +for org in organisms: + k.organism = org + + # Search pathways + results = k.lookfor_pathway(pathway_name) + + print(f"{organism_names[org]} ({org}):") + if results: + for pathway in results[:3]: # Show first 3 + print(f" {pathway}") + else: + print(" No matches found") + print() +``` + +**Output:** Pathway presence/absence across organisms. + +--- + +## Best Practices for Workflows + +### 1. Error Handling + +Always wrap service calls: +```python +try: + result = service.method(params) + if result: + # Process + pass +except Exception as e: + print(f"Error: {e}") +``` + +### 2. Rate Limiting + +Add delays for batch processing: +```python +import time + +for item in items: + result = service.query(item) + time.sleep(0.5) # 500ms delay +``` + +### 3. Result Validation + +Check for empty or unexpected results: +```python +if result and len(result) > 0: + # Process + pass +else: + print("No results returned") +``` + +### 4. Progress Reporting + +For long workflows: +```python +total = len(items) +for i, item in enumerate(items): + # Process item + if (i + 1) % 10 == 0: + print(f"Processed {i+1}/{total}") +``` + +### 5. 
Data Export + +Save intermediate results: +```python +import json + +with open("results.json", "w") as f: + json.dump(results, f, indent=2) +``` + +--- + +## Integration with Other Tools + +### BioPython Integration + +```python +from bioservices import UniProt +from Bio import SeqIO +from io import StringIO + +u = UniProt() +fasta_data = u.retrieve("P43403", "fasta") + +# Parse with BioPython +fasta_io = StringIO(fasta_data) +record = SeqIO.read(fasta_io, "fasta") + +print(f"Sequence length: {len(record.seq)}") +print(f"Description: {record.description}") +``` + +### Pandas Integration + +```python +from bioservices import UniProt +import pandas as pd +from io import StringIO + +u = UniProt() +results = u.search("zap70", frmt="tab", columns="id,genes,length,organism") + +# Load into DataFrame +df = pd.read_csv(StringIO(results), sep="\t") +print(df.head()) +print(df.describe()) +``` + +### NetworkX Integration + +See Protein Interaction Network Construction above. + +--- + +For complete working examples, see the scripts in `scripts/` directory. diff --git a/scientific-packages/bioservices/scripts/batch_id_converter.py b/scientific-packages/bioservices/scripts/batch_id_converter.py new file mode 100755 index 0000000..14e3893 --- /dev/null +++ b/scientific-packages/bioservices/scripts/batch_id_converter.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +""" +Batch Identifier Converter + +This script converts multiple identifiers between biological databases +using UniProt's mapping service. Supports batch processing with +automatic chunking and error handling. + +Usage: + python batch_id_converter.py INPUT_FILE --from DB1 --to DB2 [options] + +Examples: + python batch_id_converter.py uniprot_ids.txt --from UniProtKB_AC-ID --to KEGG + python batch_id_converter.py gene_ids.txt --from GeneID --to UniProtKB --output mapping.csv + python batch_id_converter.py ids.txt --from UniProtKB_AC-ID --to Ensembl --chunk-size 50 + +Input file format: + One identifier per line (plain text) + +Common database codes: + UniProtKB_AC-ID - UniProt accession/ID + KEGG - KEGG gene IDs + GeneID - NCBI Gene (Entrez) IDs + Ensembl - Ensembl gene IDs + Ensembl_Protein - Ensembl protein IDs + RefSeq_Protein - RefSeq protein IDs + PDB - Protein Data Bank IDs + HGNC - Human gene symbols + GO - Gene Ontology IDs +""" + +import sys +import argparse +import csv +import time +from bioservices import UniProt + + +# Common database code mappings +DATABASE_CODES = { + 'uniprot': 'UniProtKB_AC-ID', + 'uniprotkb': 'UniProtKB_AC-ID', + 'kegg': 'KEGG', + 'geneid': 'GeneID', + 'entrez': 'GeneID', + 'ensembl': 'Ensembl', + 'ensembl_protein': 'Ensembl_Protein', + 'ensembl_transcript': 'Ensembl_Transcript', + 'refseq': 'RefSeq_Protein', + 'refseq_protein': 'RefSeq_Protein', + 'pdb': 'PDB', + 'hgnc': 'HGNC', + 'mgi': 'MGI', + 'go': 'GO', + 'pfam': 'Pfam', + 'interpro': 'InterPro', + 'reactome': 'Reactome', + 'string': 'STRING', + 'biogrid': 'BioGRID' +} + + +def normalize_database_code(code): + """Normalize database code to official format.""" + # Try exact match first + if code in DATABASE_CODES.values(): + return code + + # Try lowercase lookup + lowercase = code.lower() + if lowercase in DATABASE_CODES: + return DATABASE_CODES[lowercase] + + # Return as-is if not found (may still be valid) + return code + + +def read_ids_from_file(filename): + """Read identifiers from file (one per line).""" + print(f"Reading identifiers from {filename}...") + + ids = [] + with open(filename, 'r') as f: + for line in f: + line = line.strip() + if line and 
not line.startswith('#'): + ids.append(line) + + print(f"✓ Read {len(ids)} identifier(s)") + + return ids + + +def batch_convert(ids, from_db, to_db, chunk_size=100, delay=0.5): + """Convert IDs with automatic chunking and error handling.""" + print(f"\nConverting {len(ids)} IDs:") + print(f" From: {from_db}") + print(f" To: {to_db}") + print(f" Chunk size: {chunk_size}") + print() + + u = UniProt(verbose=False) + all_results = {} + failed_ids = [] + + total_chunks = (len(ids) + chunk_size - 1) // chunk_size + + for i in range(0, len(ids), chunk_size): + chunk = ids[i:i+chunk_size] + chunk_num = (i // chunk_size) + 1 + + query = ",".join(chunk) + + try: + print(f" [{chunk_num}/{total_chunks}] Processing {len(chunk)} IDs...", end=" ") + + results = u.mapping(fr=from_db, to=to_db, query=query) + + if results: + all_results.update(results) + mapped_count = len([v for v in results.values() if v]) + print(f"✓ Mapped: {mapped_count}/{len(chunk)}") + else: + print(f"✗ No mappings returned") + failed_ids.extend(chunk) + + # Rate limiting + if delay > 0 and i + chunk_size < len(ids): + time.sleep(delay) + + except Exception as e: + print(f"✗ Error: {e}") + + # Try individual IDs in failed chunk + print(f" Retrying individual IDs...") + for single_id in chunk: + try: + result = u.mapping(fr=from_db, to=to_db, query=single_id) + if result: + all_results.update(result) + print(f" ✓ {single_id}") + else: + failed_ids.append(single_id) + print(f" ✗ {single_id} - no mapping") + except Exception as e2: + failed_ids.append(single_id) + print(f" ✗ {single_id} - {e2}") + + time.sleep(0.2) + + # Add missing IDs to results (mark as failed) + for id_ in ids: + if id_ not in all_results: + all_results[id_] = None + + print(f"\n✓ Conversion complete:") + print(f" Total: {len(ids)}") + print(f" Mapped: {len([v for v in all_results.values() if v])}") + print(f" Failed: {len(failed_ids)}") + + return all_results, failed_ids + + +def save_mapping_csv(mapping, output_file, from_db, to_db): + """Save mapping results to CSV.""" + print(f"\nSaving results to {output_file}...") + + with open(output_file, 'w', newline='') as f: + writer = csv.writer(f) + + # Header + writer.writerow(['Source_ID', 'Source_DB', 'Target_IDs', 'Target_DB', 'Mapping_Status']) + + # Data + for source_id, target_ids in sorted(mapping.items()): + if target_ids: + target_str = ";".join(target_ids) + status = "Success" + else: + target_str = "" + status = "Failed" + + writer.writerow([source_id, from_db, target_str, to_db, status]) + + print(f"✓ Results saved") + + +def save_failed_ids(failed_ids, output_file): + """Save failed IDs to file.""" + if not failed_ids: + return + + print(f"\nSaving failed IDs to {output_file}...") + + with open(output_file, 'w') as f: + for id_ in failed_ids: + f.write(f"{id_}\n") + + print(f"✓ Saved {len(failed_ids)} failed ID(s)") + + +def print_mapping_summary(mapping, from_db, to_db): + """Print summary of mapping results.""" + print(f"\n{'='*70}") + print("MAPPING SUMMARY") + print(f"{'='*70}") + + total = len(mapping) + mapped = len([v for v in mapping.values() if v]) + failed = total - mapped + + print(f"\nSource database: {from_db}") + print(f"Target database: {to_db}") + print(f"\nTotal identifiers: {total}") + print(f"Successfully mapped: {mapped} ({mapped/total*100:.1f}%)") + print(f"Failed to map: {failed} ({failed/total*100:.1f}%)") + + # Show some examples + if mapped > 0: + print(f"\nExample mappings (first 5):") + count = 0 + for source_id, target_ids in mapping.items(): + if target_ids: + target_str = ", 
".join(target_ids[:3]) + if len(target_ids) > 3: + target_str += f" ... +{len(target_ids)-3} more" + print(f" {source_id} → {target_str}") + count += 1 + if count >= 5: + break + + # Show multiple mapping statistics + multiple_mappings = [v for v in mapping.values() if v and len(v) > 1] + if multiple_mappings: + print(f"\nMultiple target mappings: {len(multiple_mappings)} ID(s)") + print(f" (These source IDs map to multiple target IDs)") + + print(f"{'='*70}") + + +def list_common_databases(): + """Print list of common database codes.""" + print("\nCommon Database Codes:") + print("-" * 70) + print(f"{'Alias':<20} {'Official Code':<30}") + print("-" * 70) + + for alias, code in sorted(DATABASE_CODES.items()): + if alias != code.lower(): + print(f"{alias:<20} {code:<30}") + + print("-" * 70) + print("\nNote: Many other database codes are supported.") + print("See UniProt documentation for complete list.") + + +def main(): + """Main conversion workflow.""" + parser = argparse.ArgumentParser( + description="Batch convert biological identifiers between databases", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python batch_id_converter.py uniprot_ids.txt --from UniProtKB_AC-ID --to KEGG + python batch_id_converter.py ids.txt --from GeneID --to UniProtKB -o mapping.csv + python batch_id_converter.py ids.txt --from uniprot --to ensembl --chunk-size 50 + +Common database codes: + UniProtKB_AC-ID, KEGG, GeneID, Ensembl, Ensembl_Protein, + RefSeq_Protein, PDB, HGNC, GO, Pfam, InterPro, Reactome + +Use --list-databases to see all supported aliases. + """ + ) + parser.add_argument("input_file", help="Input file with IDs (one per line)") + parser.add_argument("--from", dest="from_db", required=True, + help="Source database code") + parser.add_argument("--to", dest="to_db", required=True, + help="Target database code") + parser.add_argument("-o", "--output", default=None, + help="Output CSV file (default: mapping_results.csv)") + parser.add_argument("--chunk-size", type=int, default=100, + help="Number of IDs per batch (default: 100)") + parser.add_argument("--delay", type=float, default=0.5, + help="Delay between batches in seconds (default: 0.5)") + parser.add_argument("--save-failed", action="store_true", + help="Save failed IDs to separate file") + parser.add_argument("--list-databases", action="store_true", + help="List common database codes and exit") + + args = parser.parse_args() + + # List databases and exit + if args.list_databases: + list_common_databases() + sys.exit(0) + + print("=" * 70) + print("BIOSERVICES: Batch Identifier Converter") + print("=" * 70) + + # Normalize database codes + from_db = normalize_database_code(args.from_db) + to_db = normalize_database_code(args.to_db) + + if from_db != args.from_db: + print(f"\nNote: Normalized '{args.from_db}' → '{from_db}'") + if to_db != args.to_db: + print(f"Note: Normalized '{args.to_db}' → '{to_db}'") + + # Read input IDs + try: + ids = read_ids_from_file(args.input_file) + except Exception as e: + print(f"\n✗ Error reading input file: {e}") + sys.exit(1) + + if not ids: + print("\n✗ No IDs found in input file") + sys.exit(1) + + # Perform conversion + mapping, failed_ids = batch_convert( + ids, + from_db, + to_db, + chunk_size=args.chunk_size, + delay=args.delay + ) + + # Print summary + print_mapping_summary(mapping, from_db, to_db) + + # Save results + output_file = args.output or "mapping_results.csv" + save_mapping_csv(mapping, output_file, from_db, to_db) + + # Save failed IDs if requested + if 
args.save_failed and failed_ids: + failed_file = output_file.replace(".csv", "_failed.txt") + save_failed_ids(failed_ids, failed_file) + + print(f"\n✓ Done!") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/bioservices/scripts/compound_cross_reference.py b/scientific-packages/bioservices/scripts/compound_cross_reference.py new file mode 100755 index 0000000..997ccf9 --- /dev/null +++ b/scientific-packages/bioservices/scripts/compound_cross_reference.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +""" +Compound Cross-Database Search + +This script searches for a compound by name and retrieves identifiers +from multiple databases: +- KEGG Compound +- ChEBI +- ChEMBL (via UniChem) +- Basic compound properties + +Usage: + python compound_cross_reference.py COMPOUND_NAME [--output FILE] + +Examples: + python compound_cross_reference.py Geldanamycin + python compound_cross_reference.py "Adenosine triphosphate" + python compound_cross_reference.py Aspirin --output aspirin_info.txt +""" + +import sys +import argparse +from bioservices import KEGG, UniChem, ChEBI, ChEMBL + + +def search_kegg_compound(compound_name): + """Search KEGG for compound by name.""" + print(f"\n{'='*70}") + print("STEP 1: KEGG Compound Search") + print(f"{'='*70}") + + k = KEGG() + + print(f"Searching KEGG for: {compound_name}") + + try: + results = k.find("compound", compound_name) + + if not results or not results.strip(): + print(f"✗ No results found in KEGG") + return k, None + + # Parse results + lines = results.strip().split("\n") + print(f"✓ Found {len(lines)} result(s):\n") + + for i, line in enumerate(lines[:5], 1): + parts = line.split("\t") + kegg_id = parts[0] + description = parts[1] if len(parts) > 1 else "No description" + print(f" {i}. {kegg_id}: {description}") + + # Use first result + first_result = lines[0].split("\t") + kegg_id = first_result[0].replace("cpd:", "") + + print(f"\nUsing: {kegg_id}") + + return k, kegg_id + + except Exception as e: + print(f"✗ Error: {e}") + return k, None + + +def get_kegg_info(kegg, kegg_id): + """Retrieve detailed KEGG compound information.""" + print(f"\n{'='*70}") + print("STEP 2: KEGG Compound Details") + print(f"{'='*70}") + + try: + print(f"Retrieving KEGG entry for {kegg_id}...") + + entry = kegg.get(f"cpd:{kegg_id}") + + if not entry: + print("✗ Failed to retrieve entry") + return None + + # Parse entry + compound_info = { + 'kegg_id': kegg_id, + 'name': None, + 'formula': None, + 'exact_mass': None, + 'mol_weight': None, + 'chebi_id': None, + 'pathways': [] + } + + current_section = None + + for line in entry.split("\n"): + if line.startswith("NAME"): + compound_info['name'] = line.replace("NAME", "").strip().rstrip(";") + + elif line.startswith("FORMULA"): + compound_info['formula'] = line.replace("FORMULA", "").strip() + + elif line.startswith("EXACT_MASS"): + compound_info['exact_mass'] = line.replace("EXACT_MASS", "").strip() + + elif line.startswith("MOL_WEIGHT"): + compound_info['mol_weight'] = line.replace("MOL_WEIGHT", "").strip() + + elif "ChEBI:" in line: + parts = line.split("ChEBI:") + if len(parts) > 1: + compound_info['chebi_id'] = parts[1].strip().split()[0] + + elif line.startswith("PATHWAY"): + current_section = "pathway" + pathway = line.replace("PATHWAY", "").strip() + if pathway: + compound_info['pathways'].append(pathway) + + elif current_section == "pathway" and line.startswith(" "): + pathway = line.strip() + if pathway: + compound_info['pathways'].append(pathway) + + elif line.startswith(" ") and not 
line.startswith(" "): + current_section = None + + # Display information + print(f"\n✓ KEGG Compound Information:") + print(f" ID: {compound_info['kegg_id']}") + print(f" Name: {compound_info['name']}") + print(f" Formula: {compound_info['formula']}") + print(f" Exact Mass: {compound_info['exact_mass']}") + print(f" Molecular Weight: {compound_info['mol_weight']}") + + if compound_info['chebi_id']: + print(f" ChEBI ID: {compound_info['chebi_id']}") + + if compound_info['pathways']: + print(f" Pathways: {len(compound_info['pathways'])} found") + + return compound_info + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def get_chembl_id(kegg_id): + """Map KEGG ID to ChEMBL via UniChem.""" + print(f"\n{'='*70}") + print("STEP 3: ChEMBL Mapping (via UniChem)") + print(f"{'='*70}") + + try: + u = UniChem() + + print(f"Mapping KEGG:{kegg_id} to ChEMBL...") + + chembl_id = u.get_compound_id_from_kegg(kegg_id) + + if chembl_id: + print(f"✓ ChEMBL ID: {chembl_id}") + return chembl_id + else: + print("✗ No ChEMBL mapping found") + return None + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def get_chebi_info(chebi_id): + """Retrieve ChEBI compound information.""" + print(f"\n{'='*70}") + print("STEP 4: ChEBI Details") + print(f"{'='*70}") + + if not chebi_id: + print("⊘ No ChEBI ID available") + return None + + try: + c = ChEBI() + + print(f"Retrieving ChEBI entry for {chebi_id}...") + + # Ensure proper format + if not chebi_id.startswith("CHEBI:"): + chebi_id = f"CHEBI:{chebi_id}" + + entity = c.getCompleteEntity(chebi_id) + + if entity: + print(f"\n✓ ChEBI Information:") + print(f" ID: {entity.chebiId}") + print(f" Name: {entity.chebiAsciiName}") + + if hasattr(entity, 'Formulae') and entity.Formulae: + print(f" Formula: {entity.Formulae}") + + if hasattr(entity, 'mass') and entity.mass: + print(f" Mass: {entity.mass}") + + if hasattr(entity, 'charge') and entity.charge: + print(f" Charge: {entity.charge}") + + return { + 'chebi_id': entity.chebiId, + 'name': entity.chebiAsciiName, + 'formula': entity.Formulae if hasattr(entity, 'Formulae') else None, + 'mass': entity.mass if hasattr(entity, 'mass') else None + } + else: + print("✗ Failed to retrieve ChEBI entry") + return None + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def get_chembl_info(chembl_id): + """Retrieve ChEMBL compound information.""" + print(f"\n{'='*70}") + print("STEP 5: ChEMBL Details") + print(f"{'='*70}") + + if not chembl_id: + print("⊘ No ChEMBL ID available") + return None + + try: + c = ChEMBL() + + print(f"Retrieving ChEMBL entry for {chembl_id}...") + + compound = c.get_compound_by_chemblId(chembl_id) + + if compound: + print(f"\n✓ ChEMBL Information:") + print(f" ID: {chembl_id}") + + if 'pref_name' in compound and compound['pref_name']: + print(f" Preferred Name: {compound['pref_name']}") + + if 'molecule_properties' in compound: + props = compound['molecule_properties'] + + if 'full_mwt' in props: + print(f" Molecular Weight: {props['full_mwt']}") + + if 'alogp' in props: + print(f" LogP: {props['alogp']}") + + if 'hba' in props: + print(f" H-Bond Acceptors: {props['hba']}") + + if 'hbd' in props: + print(f" H-Bond Donors: {props['hbd']}") + + if 'molecule_structures' in compound: + structs = compound['molecule_structures'] + + if 'canonical_smiles' in structs: + smiles = structs['canonical_smiles'] + print(f" SMILES: {smiles[:60]}{'...' 
if len(smiles) > 60 else ''}") + + return compound + else: + print("✗ Failed to retrieve ChEMBL entry") + return None + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def save_results(compound_name, kegg_info, chembl_id, output_file): + """Save results to file.""" + print(f"\n{'='*70}") + print(f"Saving results to {output_file}") + print(f"{'='*70}") + + with open(output_file, 'w') as f: + f.write("=" * 70 + "\n") + f.write(f"Compound Cross-Reference Report: {compound_name}\n") + f.write("=" * 70 + "\n\n") + + # KEGG information + if kegg_info: + f.write("KEGG Compound\n") + f.write("-" * 70 + "\n") + f.write(f"ID: {kegg_info['kegg_id']}\n") + f.write(f"Name: {kegg_info['name']}\n") + f.write(f"Formula: {kegg_info['formula']}\n") + f.write(f"Exact Mass: {kegg_info['exact_mass']}\n") + f.write(f"Molecular Weight: {kegg_info['mol_weight']}\n") + f.write(f"Pathways: {len(kegg_info['pathways'])} found\n") + f.write("\n") + + # Database IDs + f.write("Cross-Database Identifiers\n") + f.write("-" * 70 + "\n") + if kegg_info: + f.write(f"KEGG: {kegg_info['kegg_id']}\n") + if kegg_info['chebi_id']: + f.write(f"ChEBI: {kegg_info['chebi_id']}\n") + if chembl_id: + f.write(f"ChEMBL: {chembl_id}\n") + f.write("\n") + + print(f"✓ Results saved") + + +def main(): + """Main workflow.""" + parser = argparse.ArgumentParser( + description="Search compound across multiple databases", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python compound_cross_reference.py Geldanamycin + python compound_cross_reference.py "Adenosine triphosphate" + python compound_cross_reference.py Aspirin --output aspirin_info.txt + """ + ) + parser.add_argument("compound", help="Compound name to search") + parser.add_argument("--output", default=None, + help="Output file for results (optional)") + + args = parser.parse_args() + + print("=" * 70) + print("BIOSERVICES: Compound Cross-Database Search") + print("=" * 70) + + # Step 1: Search KEGG + kegg, kegg_id = search_kegg_compound(args.compound) + if not kegg_id: + print("\n✗ Failed to find compound. 
Exiting.") + sys.exit(1) + + # Step 2: Get KEGG details + kegg_info = get_kegg_info(kegg, kegg_id) + + # Step 3: Map to ChEMBL + chembl_id = get_chembl_id(kegg_id) + + # Step 4: Get ChEBI details + chebi_info = None + if kegg_info and kegg_info['chebi_id']: + chebi_info = get_chebi_info(kegg_info['chebi_id']) + + # Step 5: Get ChEMBL details + chembl_info = None + if chembl_id: + chembl_info = get_chembl_info(chembl_id) + + # Summary + print(f"\n{'='*70}") + print("SUMMARY") + print(f"{'='*70}") + print(f" Compound: {args.compound}") + if kegg_info: + print(f" KEGG ID: {kegg_info['kegg_id']}") + if kegg_info['chebi_id']: + print(f" ChEBI ID: {kegg_info['chebi_id']}") + if chembl_id: + print(f" ChEMBL ID: {chembl_id}") + print(f"{'='*70}") + + # Save to file if requested + if args.output: + save_results(args.compound, kegg_info, chembl_id, args.output) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/bioservices/scripts/pathway_analysis.py b/scientific-packages/bioservices/scripts/pathway_analysis.py new file mode 100755 index 0000000..5ec3322 --- /dev/null +++ b/scientific-packages/bioservices/scripts/pathway_analysis.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +KEGG Pathway Network Analysis + +This script analyzes all pathways for an organism and extracts: +- Pathway sizes (number of genes) +- Protein-protein interactions +- Interaction type distributions +- Network data in various formats (CSV, SIF) + +Usage: + python pathway_analysis.py ORGANISM OUTPUT_DIR [--limit N] + +Examples: + python pathway_analysis.py hsa ./human_pathways + python pathway_analysis.py mmu ./mouse_pathways --limit 50 + +Organism codes: + hsa = Homo sapiens (human) + mmu = Mus musculus (mouse) + dme = Drosophila melanogaster + sce = Saccharomyces cerevisiae (yeast) + eco = Escherichia coli +""" + +import sys +import os +import argparse +import csv +from collections import Counter +from bioservices import KEGG + + +def get_all_pathways(kegg, organism): + """Get all pathway IDs for organism.""" + print(f"\nRetrieving pathways for {organism}...") + + kegg.organism = organism + pathway_ids = kegg.pathwayIds + + print(f"✓ Found {len(pathway_ids)} pathways") + + return pathway_ids + + +def analyze_pathway(kegg, pathway_id): + """Analyze single pathway for size and interactions.""" + try: + # Parse KGML pathway + kgml = kegg.parse_kgml_pathway(pathway_id) + + entries = kgml.get('entries', []) + relations = kgml.get('relations', []) + + # Count relation types + relation_types = Counter() + for rel in relations: + rel_type = rel.get('name', 'unknown') + relation_types[rel_type] += 1 + + # Get pathway name + try: + entry = kegg.get(pathway_id) + pathway_name = "Unknown" + for line in entry.split("\n"): + if line.startswith("NAME"): + pathway_name = line.replace("NAME", "").strip() + break + except: + pathway_name = "Unknown" + + result = { + 'pathway_id': pathway_id, + 'pathway_name': pathway_name, + 'num_entries': len(entries), + 'num_relations': len(relations), + 'relation_types': dict(relation_types), + 'entries': entries, + 'relations': relations + } + + return result + + except Exception as e: + print(f" ✗ Error analyzing {pathway_id}: {e}") + return None + + +def analyze_all_pathways(kegg, pathway_ids, limit=None): + """Analyze all pathways.""" + if limit: + pathway_ids = pathway_ids[:limit] + print(f"\n⚠ Limiting analysis to first {limit} pathways") + + print(f"\nAnalyzing {len(pathway_ids)} pathways...") + + results = [] + for i, pathway_id in enumerate(pathway_ids, 1): + print(f" 
[{i}/{len(pathway_ids)}] {pathway_id}", end="\r") + + result = analyze_pathway(kegg, pathway_id) + if result: + results.append(result) + + print(f"\n✓ Successfully analyzed {len(results)}/{len(pathway_ids)} pathways") + + return results + + +def save_pathway_summary(results, output_file): + """Save pathway summary to CSV.""" + print(f"\nSaving pathway summary to {output_file}...") + + with open(output_file, 'w', newline='') as f: + writer = csv.writer(f) + + # Header + writer.writerow([ + 'Pathway_ID', + 'Pathway_Name', + 'Num_Genes', + 'Num_Interactions', + 'Activation', + 'Inhibition', + 'Phosphorylation', + 'Binding', + 'Other' + ]) + + # Data + for result in results: + rel_types = result['relation_types'] + + writer.writerow([ + result['pathway_id'], + result['pathway_name'], + result['num_entries'], + result['num_relations'], + rel_types.get('activation', 0), + rel_types.get('inhibition', 0), + rel_types.get('phosphorylation', 0), + rel_types.get('binding/association', 0), + sum(v for k, v in rel_types.items() + if k not in ['activation', 'inhibition', 'phosphorylation', 'binding/association']) + ]) + + print(f"✓ Summary saved") + + +def save_interactions_sif(results, output_file): + """Save all interactions in SIF format.""" + print(f"\nSaving interactions to {output_file}...") + + with open(output_file, 'w') as f: + for result in results: + pathway_id = result['pathway_id'] + + for rel in result['relations']: + entry1 = rel.get('entry1', '') + entry2 = rel.get('entry2', '') + interaction_type = rel.get('name', 'interaction') + + # Write SIF format: source\tinteraction\ttarget + f.write(f"{entry1}\t{interaction_type}\t{entry2}\n") + + print(f"✓ Interactions saved") + + +def save_detailed_pathway_info(results, output_dir): + """Save detailed information for each pathway.""" + print(f"\nSaving detailed pathway files to {output_dir}/pathways/...") + + pathway_dir = os.path.join(output_dir, "pathways") + os.makedirs(pathway_dir, exist_ok=True) + + for result in results: + pathway_id = result['pathway_id'].replace(":", "_") + filename = os.path.join(pathway_dir, f"{pathway_id}_interactions.csv") + + with open(filename, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Source', 'Target', 'Interaction_Type', 'Link_Type']) + + for rel in result['relations']: + writer.writerow([ + rel.get('entry1', ''), + rel.get('entry2', ''), + rel.get('name', 'unknown'), + rel.get('link', 'unknown') + ]) + + print(f"✓ Detailed files saved for {len(results)} pathways") + + +def print_statistics(results): + """Print analysis statistics.""" + print(f"\n{'='*70}") + print("PATHWAY ANALYSIS STATISTICS") + print(f"{'='*70}") + + # Total stats + total_pathways = len(results) + total_interactions = sum(r['num_relations'] for r in results) + total_genes = sum(r['num_entries'] for r in results) + + print(f"\nOverall:") + print(f" Total pathways: {total_pathways}") + print(f" Total genes/proteins: {total_genes}") + print(f" Total interactions: {total_interactions}") + + # Largest pathways + print(f"\nLargest pathways (by gene count):") + sorted_by_size = sorted(results, key=lambda x: x['num_entries'], reverse=True) + for i, result in enumerate(sorted_by_size[:10], 1): + print(f" {i}. 
{result['pathway_id']}: {result['num_entries']} genes") + print(f" {result['pathway_name']}") + + # Most connected pathways + print(f"\nMost connected pathways (by interactions):") + sorted_by_connections = sorted(results, key=lambda x: x['num_relations'], reverse=True) + for i, result in enumerate(sorted_by_connections[:10], 1): + print(f" {i}. {result['pathway_id']}: {result['num_relations']} interactions") + print(f" {result['pathway_name']}") + + # Interaction type distribution + print(f"\nInteraction type distribution:") + all_types = Counter() + for result in results: + for rel_type, count in result['relation_types'].items(): + all_types[rel_type] += count + + for rel_type, count in all_types.most_common(): + percentage = (count / total_interactions) * 100 if total_interactions > 0 else 0 + print(f" {rel_type}: {count} ({percentage:.1f}%)") + + +def main(): + """Main analysis workflow.""" + parser = argparse.ArgumentParser( + description="Analyze KEGG pathways for an organism", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python pathway_analysis.py hsa ./human_pathways + python pathway_analysis.py mmu ./mouse_pathways --limit 50 + +Organism codes: + hsa = Homo sapiens (human) + mmu = Mus musculus (mouse) + dme = Drosophila melanogaster + sce = Saccharomyces cerevisiae (yeast) + eco = Escherichia coli + """ + ) + parser.add_argument("organism", help="KEGG organism code (e.g., hsa, mmu)") + parser.add_argument("output_dir", help="Output directory for results") + parser.add_argument("--limit", type=int, default=None, + help="Limit analysis to first N pathways") + + args = parser.parse_args() + + print("=" * 70) + print("BIOSERVICES: KEGG Pathway Network Analysis") + print("=" * 70) + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize KEGG + kegg = KEGG() + + # Get all pathways + pathway_ids = get_all_pathways(kegg, args.organism) + + if not pathway_ids: + print(f"\n✗ No pathways found for {args.organism}") + sys.exit(1) + + # Analyze pathways + results = analyze_all_pathways(kegg, pathway_ids, args.limit) + + if not results: + print("\n✗ No pathways successfully analyzed") + sys.exit(1) + + # Print statistics + print_statistics(results) + + # Save results + summary_file = os.path.join(args.output_dir, "pathway_summary.csv") + save_pathway_summary(results, summary_file) + + sif_file = os.path.join(args.output_dir, "all_interactions.sif") + save_interactions_sif(results, sif_file) + + save_detailed_pathway_info(results, args.output_dir) + + # Final summary + print(f"\n{'='*70}") + print("OUTPUT FILES") + print(f"{'='*70}") + print(f" Summary: {summary_file}") + print(f" Interactions: {sif_file}") + print(f" Detailed: {args.output_dir}/pathways/") + print(f"{'='*70}") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/bioservices/scripts/protein_analysis_workflow.py b/scientific-packages/bioservices/scripts/protein_analysis_workflow.py new file mode 100755 index 0000000..7973423 --- /dev/null +++ b/scientific-packages/bioservices/scripts/protein_analysis_workflow.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +Complete Protein Analysis Workflow + +This script performs a comprehensive protein analysis pipeline: +1. UniProt search and identifier retrieval +2. FASTA sequence retrieval +3. BLAST similarity search +4. KEGG pathway discovery +5. PSICQUIC interaction mapping +6. 
GO annotation retrieval + +Usage: + python protein_analysis_workflow.py PROTEIN_NAME EMAIL [--skip-blast] + +Examples: + python protein_analysis_workflow.py ZAP70_HUMAN user@example.com + python protein_analysis_workflow.py P43403 user@example.com --skip-blast + +Note: BLAST searches can take several minutes. Use --skip-blast to skip this step. +""" + +import sys +import time +import argparse +from bioservices import UniProt, KEGG, NCBIblast, PSICQUIC, QuickGO + + +def search_protein(query): + """Search UniProt for protein and retrieve basic information.""" + print(f"\n{'='*70}") + print("STEP 1: UniProt Search") + print(f"{'='*70}") + + u = UniProt(verbose=False) + + print(f"Searching for: {query}") + + # Try direct retrieval first (if query looks like accession) + if len(query) == 6 and query[0] in "OPQ": + try: + entry = u.retrieve(query, frmt="tab") + if entry: + uniprot_id = query + print(f"✓ Found UniProt entry: {uniprot_id}") + return u, uniprot_id + except: + pass + + # Otherwise search + results = u.search(query, frmt="tab", columns="id,genes,organism,length,protein names", limit=5) + + if not results: + print("✗ No results found") + return u, None + + lines = results.strip().split("\n") + if len(lines) < 2: + print("✗ No entries found") + return u, None + + # Display results + print(f"\n✓ Found {len(lines)-1} result(s):") + for i, line in enumerate(lines[1:], 1): + fields = line.split("\t") + print(f" {i}. {fields[0]} - {fields[1]} ({fields[2]})") + + # Use first result + first_entry = lines[1].split("\t") + uniprot_id = first_entry[0] + gene_names = first_entry[1] if len(first_entry) > 1 else "N/A" + organism = first_entry[2] if len(first_entry) > 2 else "N/A" + length = first_entry[3] if len(first_entry) > 3 else "N/A" + protein_name = first_entry[4] if len(first_entry) > 4 else "N/A" + + print(f"\nUsing first result:") + print(f" UniProt ID: {uniprot_id}") + print(f" Gene names: {gene_names}") + print(f" Organism: {organism}") + print(f" Length: {length} aa") + print(f" Protein: {protein_name}") + + return u, uniprot_id + + +def retrieve_sequence(uniprot, uniprot_id): + """Retrieve FASTA sequence for protein.""" + print(f"\n{'='*70}") + print("STEP 2: FASTA Sequence Retrieval") + print(f"{'='*70}") + + try: + sequence = uniprot.retrieve(uniprot_id, frmt="fasta") + + if sequence: + # Extract sequence only (remove header) + lines = sequence.strip().split("\n") + header = lines[0] + seq_only = "".join(lines[1:]) + + print(f"✓ Retrieved sequence:") + print(f" Header: {header}") + print(f" Length: {len(seq_only)} residues") + print(f" First 60 residues: {seq_only[:60]}...") + + return seq_only + else: + print("✗ Failed to retrieve sequence") + return None + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def run_blast(sequence, email, skip=False): + """Run BLAST similarity search.""" + print(f"\n{'='*70}") + print("STEP 3: BLAST Similarity Search") + print(f"{'='*70}") + + if skip: + print("⊘ Skipped (--skip-blast flag)") + return None + + if not email or "@" not in email: + print("⊘ Skipped (valid email required for BLAST)") + return None + + try: + print(f"Submitting BLASTP job...") + print(f" Database: uniprotkb") + print(f" Sequence length: {len(sequence)} aa") + + s = NCBIblast(verbose=False) + + jobid = s.run( + program="blastp", + sequence=sequence, + stype="protein", + database="uniprotkb", + email=email + ) + + print(f"✓ Job submitted: {jobid}") + print(f" Waiting for completion...") + + # Poll for completion + max_wait = 300 # 5 minutes + start_time = 
time.time() + + while time.time() - start_time < max_wait: + status = s.getStatus(jobid) + elapsed = int(time.time() - start_time) + print(f" Status: {status} (elapsed: {elapsed}s)", end="\r") + + if status == "FINISHED": + print(f"\n✓ BLAST completed in {elapsed}s") + + # Retrieve results + results = s.getResult(jobid, "out") + + # Parse and display summary + lines = results.split("\n") + print(f"\n Results preview:") + for line in lines[:20]: + if line.strip(): + print(f" {line}") + + return results + + elif status == "ERROR": + print(f"\n✗ BLAST job failed") + return None + + time.sleep(5) + + print(f"\n✗ Timeout after {max_wait}s") + return None + + except Exception as e: + print(f"✗ Error: {e}") + return None + + +def discover_pathways(uniprot, kegg, uniprot_id): + """Discover KEGG pathways for protein.""" + print(f"\n{'='*70}") + print("STEP 4: KEGG Pathway Discovery") + print(f"{'='*70}") + + try: + # Map UniProt → KEGG + print(f"Mapping {uniprot_id} to KEGG...") + kegg_mapping = uniprot.mapping(fr="UniProtKB_AC-ID", to="KEGG", query=uniprot_id) + + if not kegg_mapping or uniprot_id not in kegg_mapping: + print("✗ No KEGG mapping found") + return [] + + kegg_ids = kegg_mapping[uniprot_id] + print(f"✓ KEGG ID(s): {kegg_ids}") + + # Get pathways for first KEGG ID + kegg_id = kegg_ids[0] + organism, gene_id = kegg_id.split(":") + + print(f"\nSearching pathways for {kegg_id}...") + pathways = kegg.get_pathway_by_gene(gene_id, organism) + + if not pathways: + print("✗ No pathways found") + return [] + + print(f"✓ Found {len(pathways)} pathway(s):\n") + + # Get pathway names + pathway_info = [] + for pathway_id in pathways: + try: + entry = kegg.get(pathway_id) + + # Extract pathway name + pathway_name = "Unknown" + for line in entry.split("\n"): + if line.startswith("NAME"): + pathway_name = line.replace("NAME", "").strip() + break + + pathway_info.append((pathway_id, pathway_name)) + print(f" • {pathway_id}: {pathway_name}") + + except Exception as e: + print(f" • {pathway_id}: [Error retrieving name]") + + return pathway_info + + except Exception as e: + print(f"✗ Error: {e}") + return [] + + +def find_interactions(protein_query): + """Find protein-protein interactions via PSICQUIC.""" + print(f"\n{'='*70}") + print("STEP 5: Protein-Protein Interactions") + print(f"{'='*70}") + + try: + p = PSICQUIC() + + # Try querying MINT database + query = f"{protein_query} AND species:9606" + print(f"Querying MINT database...") + print(f" Query: {query}") + + results = p.query("mint", query) + + if not results: + print("✗ No interactions found in MINT") + return [] + + # Parse PSI-MI TAB format + lines = results.strip().split("\n") + print(f"✓ Found {len(lines)} interaction(s):\n") + + # Display first 10 interactions + interactions = [] + for i, line in enumerate(lines[:10], 1): + fields = line.split("\t") + if len(fields) >= 12: + protein_a = fields[4].split(":")[1] if ":" in fields[4] else fields[4] + protein_b = fields[5].split(":")[1] if ":" in fields[5] else fields[5] + interaction_type = fields[11] + + interactions.append((protein_a, protein_b, interaction_type)) + print(f" {i}. {protein_a} ↔ {protein_b}") + + if len(lines) > 10: + print(f" ... 
and {len(lines)-10} more") + + return interactions + + except Exception as e: + print(f"✗ Error: {e}") + return [] + + +def get_go_annotations(uniprot_id): + """Retrieve GO annotations.""" + print(f"\n{'='*70}") + print("STEP 6: Gene Ontology Annotations") + print(f"{'='*70}") + + try: + g = QuickGO() + + print(f"Retrieving GO annotations for {uniprot_id}...") + annotations = g.Annotation(protein=uniprot_id, format="tsv") + + if not annotations: + print("✗ No GO annotations found") + return [] + + lines = annotations.strip().split("\n") + print(f"✓ Found {len(lines)-1} annotation(s)\n") + + # Group by aspect + aspects = {"P": [], "F": [], "C": []} + for line in lines[1:]: + fields = line.split("\t") + if len(fields) >= 9: + go_id = fields[6] + go_term = fields[7] + go_aspect = fields[8] + + if go_aspect in aspects: + aspects[go_aspect].append((go_id, go_term)) + + # Display summary + print(f" Biological Process (P): {len(aspects['P'])} terms") + for go_id, go_term in aspects['P'][:5]: + print(f" • {go_id}: {go_term}") + if len(aspects['P']) > 5: + print(f" ... and {len(aspects['P'])-5} more") + + print(f"\n Molecular Function (F): {len(aspects['F'])} terms") + for go_id, go_term in aspects['F'][:5]: + print(f" • {go_id}: {go_term}") + if len(aspects['F']) > 5: + print(f" ... and {len(aspects['F'])-5} more") + + print(f"\n Cellular Component (C): {len(aspects['C'])} terms") + for go_id, go_term in aspects['C'][:5]: + print(f" • {go_id}: {go_term}") + if len(aspects['C']) > 5: + print(f" ... and {len(aspects['C'])-5} more") + + return aspects + + except Exception as e: + print(f"✗ Error: {e}") + return {} + + +def main(): + """Main workflow.""" + parser = argparse.ArgumentParser( + description="Complete protein analysis workflow using BioServices", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python protein_analysis_workflow.py ZAP70_HUMAN user@example.com + python protein_analysis_workflow.py P43403 user@example.com --skip-blast + """ + ) + parser.add_argument("protein", help="Protein name or UniProt ID") + parser.add_argument("email", help="Email address (required for BLAST)") + parser.add_argument("--skip-blast", action="store_true", + help="Skip BLAST search (faster)") + + args = parser.parse_args() + + print("=" * 70) + print("BIOSERVICES: Complete Protein Analysis Workflow") + print("=" * 70) + + # Step 1: Search protein + uniprot, uniprot_id = search_protein(args.protein) + if not uniprot_id: + print("\n✗ Failed to find protein. 
Exiting.") + sys.exit(1) + + # Step 2: Retrieve sequence + sequence = retrieve_sequence(uniprot, uniprot_id) + if not sequence: + print("\n⚠ Warning: Could not retrieve sequence") + + # Step 3: BLAST search + if sequence: + blast_results = run_blast(sequence, args.email, args.skip_blast) + + # Step 4: Pathway discovery + kegg = KEGG() + pathways = discover_pathways(uniprot, kegg, uniprot_id) + + # Step 5: Interaction mapping + interactions = find_interactions(args.protein) + + # Step 6: GO annotations + go_terms = get_go_annotations(uniprot_id) + + # Summary + print(f"\n{'='*70}") + print("WORKFLOW SUMMARY") + print(f"{'='*70}") + print(f" Protein: {args.protein}") + print(f" UniProt ID: {uniprot_id}") + print(f" Sequence: {'✓' if sequence else '✗'}") + print(f" BLAST: {'✓' if not args.skip_blast and sequence else '⊘'}") + print(f" Pathways: {len(pathways)} found") + print(f" Interactions: {len(interactions)} found") + print(f" GO annotations: {sum(len(v) for v in go_terms.values())} found") + print(f"{'='*70}") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/cellxgene-census/SKILL.md b/scientific-packages/cellxgene-census/SKILL.md new file mode 100644 index 0000000..a394a5c --- /dev/null +++ b/scientific-packages/cellxgene-census/SKILL.md @@ -0,0 +1,505 @@ +--- +name: cellxgene-census +description: Access and analyze single-cell genomics data from the CZ CELLxGENE Census. This skill should be used when working with large-scale single-cell RNA-seq data, querying cell and gene metadata, training machine learning models on Census data, integrating multiple single-cell datasets, or performing cross-dataset analyses. It covers data exploration, expression queries, out-of-core processing, PyTorch integration, and scanpy workflows. +--- + +# CZ CELLxGENE Census + +## Overview + +The CZ CELLxGENE Census provides programmatic access to a comprehensive, versioned collection of standardized single-cell genomics data from CZ CELLxGENE Discover. This skill enables efficient querying and analysis of millions of cells across thousands of datasets. + +The Census includes: +- **61+ million cells** from human and mouse +- **Standardized metadata** (cell types, tissues, diseases, donors) +- **Raw gene expression** matrices +- **Pre-calculated embeddings** and statistics +- **Integration with PyTorch, scanpy, and other analysis tools** + +## When to Use This Skill + +Use this skill when tasks involve: +- Querying single-cell expression data by cell type, tissue, or disease +- Exploring available single-cell datasets and metadata +- Training machine learning models on single-cell data +- Performing large-scale cross-dataset analyses +- Integrating Census data with scanpy or other analysis frameworks +- Computing statistics across millions of cells +- Accessing pre-calculated embeddings or model predictions + +## Installation and Setup + +Install the Census API: +```bash +pip install cellxgene-census +``` + +For machine learning workflows, install additional dependencies: +```bash +pip install cellxgene-census[experimental] +``` + +## Core Workflow Patterns + +### 1. 
Opening the Census + +Always use the context manager to ensure proper resource cleanup: + +```python +import cellxgene_census + +# Open latest stable version +with cellxgene_census.open_soma() as census: + # Work with census data + +# Open specific version for reproducibility +with cellxgene_census.open_soma(census_version="2023-07-25") as census: + # Work with census data +``` + +**Key points:** +- Use context manager (`with` statement) for automatic cleanup +- Specify `census_version` for reproducible analyses +- Default opens latest "stable" release + +### 2. Exploring Census Information + +Before querying expression data, explore available datasets and metadata. + +**Access summary information:** +```python +# Get summary statistics +summary = census["census_info"]["summary"].read().concat().to_pandas() +print(f"Total cells: {summary['total_cell_count'][0]}") + +# Get all datasets +datasets = census["census_info"]["datasets"].read().concat().to_pandas() + +# Filter datasets by criteria +covid_datasets = datasets[datasets["disease"].str.contains("COVID", na=False)] +``` + +**Query cell metadata to understand available data:** +```python +# Get unique cell types in a tissue +cell_metadata = cellxgene_census.get_obs( + census, + "homo_sapiens", + value_filter="tissue_general == 'brain' and is_primary_data == True", + column_names=["cell_type"] +) +unique_cell_types = cell_metadata["cell_type"].unique() +print(f"Found {len(unique_cell_types)} cell types in brain") + +# Count cells by tissue +tissue_counts = cell_metadata.groupby("tissue_general").size() +``` + +**Important:** Always filter for `is_primary_data == True` to avoid counting duplicate cells unless specifically analyzing duplicates. + +### 3. Querying Expression Data (Small to Medium Scale) + +For queries returning < 100k cells that fit in memory, use `get_anndata()`: + +```python +# Basic query with cell type and tissue filters +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", # or "Mus musculus" + obs_value_filter="cell_type == 'B cell' and tissue_general == 'lung' and is_primary_data == True", + obs_column_names=["assay", "disease", "sex", "donor_id"], +) + +# Query specific genes with multiple filters +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + var_value_filter="feature_name in ['CD4', 'CD8A', 'CD19', 'FOXP3']", + obs_value_filter="cell_type == 'T cell' and disease == 'COVID-19' and is_primary_data == True", + obs_column_names=["cell_type", "tissue_general", "donor_id"], +) +``` + +**Filter syntax:** +- Use `obs_value_filter` for cell filtering +- Use `var_value_filter` for gene filtering +- Combine conditions with `and`, `or` +- Use `in` for multiple values: `tissue in ['lung', 'liver']` +- Select only needed columns with `obs_column_names` + +**Getting metadata separately:** +```python +# Query cell metadata +cell_metadata = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="disease == 'COVID-19' and is_primary_data == True", + column_names=["cell_type", "tissue_general", "donor_id"] +) + +# Query gene metadata +gene_metadata = cellxgene_census.get_var( + census, "homo_sapiens", + value_filter="feature_name in ['CD4', 'CD8A']", + column_names=["feature_id", "feature_name", "feature_length"] +) +``` + +### 4. 
Large-Scale Queries (Out-of-Core Processing) + +For queries exceeding available RAM, use `axis_query()` with iterative processing: + +```python +import tiledbsoma as soma + +# Create axis query +query = census["census_data"]["homo_sapiens"].axis_query( + measurement_name="RNA", + obs_query=soma.AxisQuery( + value_filter="tissue_general == 'brain' and is_primary_data == True" + ), + var_query=soma.AxisQuery( + value_filter="feature_name in ['FOXP2', 'TBR1', 'SATB2']" + ) +) + +# Iterate through expression matrix in chunks +iterator = query.X("raw").tables() +for batch in iterator: + # batch is a pyarrow.Table with columns: + # - soma_data: expression value + # - soma_dim_0: cell (obs) coordinate + # - soma_dim_1: gene (var) coordinate + process_batch(batch) +``` + +**Computing incremental statistics:** +```python +# Example: Calculate mean expression +n_observations = 0 +sum_values = 0.0 + +iterator = query.X("raw").tables() +for batch in iterator: + values = batch["soma_data"].to_numpy() + n_observations += len(values) + sum_values += values.sum() + +mean_expression = sum_values / n_observations +``` + +### 5. Machine Learning with PyTorch + +For training models, use the experimental PyTorch integration: + +```python +from cellxgene_census.experimental.ml import experiment_dataloader + +with cellxgene_census.open_soma() as census: + # Create dataloader + dataloader = experiment_dataloader( + census["census_data"]["homo_sapiens"], + measurement_name="RNA", + X_name="raw", + obs_value_filter="tissue_general == 'liver' and is_primary_data == True", + obs_column_names=["cell_type"], + batch_size=128, + shuffle=True, + ) + + # Training loop + for epoch in range(num_epochs): + for batch in dataloader: + X = batch["X"] # Gene expression tensor + labels = batch["obs"]["cell_type"] # Cell type labels + + # Forward pass + outputs = model(X) + loss = criterion(outputs, labels) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() +``` + +**Train/test splitting:** +```python +from cellxgene_census.experimental.ml import ExperimentDataset + +# Create dataset from experiment +dataset = ExperimentDataset( + experiment_axis_query, + layer_name="raw", + obs_column_names=["cell_type"], + batch_size=128, +) + +# Split into train and test +train_dataset, test_dataset = dataset.random_split( + split=[0.8, 0.2], + seed=42 +) +``` + +### 6. Integration with Scanpy + +Seamlessly integrate Census data with scanpy workflows: + +```python +import scanpy as sc + +# Load data from Census +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="cell_type == 'neuron' and tissue_general == 'cortex' and is_primary_data == True", +) + +# Standard scanpy workflow +sc.pp.normalize_total(adata, target_sum=1e4) +sc.pp.log1p(adata) +sc.pp.highly_variable_genes(adata, n_top_genes=2000) + +# Dimensionality reduction +sc.pp.pca(adata, n_comps=50) +sc.pp.neighbors(adata) +sc.tl.umap(adata) + +# Visualization +sc.pl.umap(adata, color=["cell_type", "tissue", "disease"]) +``` + +### 7. 
Multi-Dataset Integration + +Query and integrate multiple datasets: + +```python +# Strategy 1: Query multiple tissues separately +tissues = ["lung", "liver", "kidney"] +adatas = [] + +for tissue in tissues: + adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", + ) + adata.obs["tissue"] = tissue + adatas.append(adata) + +# Concatenate +combined = adatas[0].concatenate(adatas[1:]) + +# Strategy 2: Query multiple datasets directly +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="tissue_general in ['lung', 'liver', 'kidney'] and is_primary_data == True", +) +``` + +## Key Concepts and Best Practices + +### Always Filter for Primary Data +Unless analyzing duplicates, always include `is_primary_data == True` in queries to avoid counting cells multiple times: +```python +obs_value_filter="cell_type == 'B cell' and is_primary_data == True" +``` + +### Specify Census Version for Reproducibility +Always specify the Census version in production analyses: +```python +census = cellxgene_census.open_soma(census_version="2023-07-25") +``` + +### Estimate Query Size Before Loading +For large queries, first check the number of cells to avoid memory issues: +```python +# Get cell count +metadata = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="tissue_general == 'brain' and is_primary_data == True", + column_names=["soma_joinid"] +) +n_cells = len(metadata) +print(f"Query will return {n_cells:,} cells") + +# If too large (>100k), use out-of-core processing +``` + +### Use tissue_general for Broader Groupings +The `tissue_general` field provides coarser categories than `tissue`, useful for cross-tissue analyses: +```python +# Broader grouping +obs_value_filter="tissue_general == 'immune system'" + +# Specific tissue +obs_value_filter="tissue == 'peripheral blood mononuclear cell'" +``` + +### Select Only Needed Columns +Minimize data transfer by specifying only required metadata columns: +```python +obs_column_names=["cell_type", "tissue_general", "disease"] # Not all columns +``` + +### Check Dataset Presence for Gene-Specific Queries +When analyzing specific genes, verify which datasets measured them: +```python +presence = cellxgene_census.get_presence_matrix( + census, + "homo_sapiens", + var_value_filter="feature_name in ['CD4', 'CD8A']" +) +``` + +### Two-Step Workflow: Explore Then Query +First explore metadata to understand available data, then query expression: +```python +# Step 1: Explore what's available +metadata = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="disease == 'COVID-19' and is_primary_data == True", + column_names=["cell_type", "tissue_general"] +) +print(metadata.value_counts()) + +# Step 2: Query based on findings +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="disease == 'COVID-19' and cell_type == 'T cell' and is_primary_data == True", +) +``` + +## Available Metadata Fields + +### Cell Metadata (obs) +Key fields for filtering: +- `cell_type`, `cell_type_ontology_term_id` +- `tissue`, `tissue_general`, `tissue_ontology_term_id` +- `disease`, `disease_ontology_term_id` +- `assay`, `assay_ontology_term_id` +- `donor_id`, `sex`, `self_reported_ethnicity` +- `development_stage`, `development_stage_ontology_term_id` +- `dataset_id` +- `is_primary_data` (Boolean: True = unique cell) + +### Gene Metadata (var) +- `feature_id` 
(Ensembl gene ID, e.g., "ENSG00000161798") +- `feature_name` (Gene symbol, e.g., "FOXP2") +- `feature_length` (Gene length in base pairs) + +## Reference Documentation + +This skill includes detailed reference documentation: + +### references/census_schema.md +Comprehensive documentation of: +- Census data structure and organization +- All available metadata fields +- Value filter syntax and operators +- SOMA object types +- Data inclusion criteria + +**When to read:** When you need detailed schema information, full list of metadata fields, or complex filter syntax. + +### references/common_patterns.md +Examples and patterns for: +- Exploratory queries (metadata only) +- Small-to-medium queries (AnnData) +- Large queries (out-of-core processing) +- PyTorch integration +- Scanpy integration workflows +- Multi-dataset integration +- Best practices and common pitfalls + +**When to read:** When implementing specific query patterns, looking for code examples, or troubleshooting common issues. + +## Common Use Cases + +### Use Case 1: Explore Cell Types in a Tissue +```python +with cellxgene_census.open_soma() as census: + cells = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="tissue_general == 'lung' and is_primary_data == True", + column_names=["cell_type"] + ) + print(cells["cell_type"].value_counts()) +``` + +### Use Case 2: Query Marker Gene Expression +```python +with cellxgene_census.open_soma() as census: + adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + var_value_filter="feature_name in ['CD4', 'CD8A', 'CD19']", + obs_value_filter="cell_type in ['T cell', 'B cell'] and is_primary_data == True", + ) +``` + +### Use Case 3: Train Cell Type Classifier +```python +from cellxgene_census.experimental.ml import experiment_dataloader + +with cellxgene_census.open_soma() as census: + dataloader = experiment_dataloader( + census["census_data"]["homo_sapiens"], + measurement_name="RNA", + X_name="raw", + obs_value_filter="is_primary_data == True", + obs_column_names=["cell_type"], + batch_size=128, + shuffle=True, + ) + + # Train model + for epoch in range(epochs): + for batch in dataloader: + # Training logic + pass +``` + +### Use Case 4: Cross-Tissue Analysis +```python +with cellxgene_census.open_soma() as census: + adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="cell_type == 'macrophage' and tissue_general in ['lung', 'liver', 'brain'] and is_primary_data == True", + ) + + # Analyze macrophage differences across tissues + sc.tl.rank_genes_groups(adata, groupby="tissue_general") +``` + +## Troubleshooting + +### Query Returns Too Many Cells +- Add more specific filters to reduce scope +- Use `tissue` instead of `tissue_general` for finer granularity +- Filter by specific `dataset_id` if known +- Switch to out-of-core processing for large queries + +### Memory Errors +- Reduce query scope with more restrictive filters +- Select fewer genes with `var_value_filter` +- Use out-of-core processing with `axis_query()` +- Process data in batches + +### Duplicate Cells in Results +- Always include `is_primary_data == True` in filters +- Check if intentionally querying across multiple datasets + +### Gene Not Found +- Verify gene name spelling (case-sensitive) +- Try Ensembl ID with `feature_id` instead of `feature_name` +- Check dataset presence matrix to see if gene was measured +- Some genes may have been filtered during Census construction + +### Version Inconsistencies +- Always specify 
`census_version` explicitly +- Use same version across all analyses +- Check release notes for version-specific changes diff --git a/scientific-packages/cellxgene-census/references/census_schema.md b/scientific-packages/cellxgene-census/references/census_schema.md new file mode 100644 index 0000000..de38701 --- /dev/null +++ b/scientific-packages/cellxgene-census/references/census_schema.md @@ -0,0 +1,182 @@ +# CZ CELLxGENE Census Data Schema Reference + +## Overview + +The CZ CELLxGENE Census is a versioned collection of single-cell data built on the TileDB-SOMA framework. This reference documents the data structure, available metadata fields, and query syntax. + +## High-Level Structure + +The Census is organized as a `SOMACollection` with two main components: + +### 1. census_info +Summary information including: +- **summary**: Build date, cell counts, dataset statistics +- **datasets**: All datasets from CELLxGENE Discover with metadata +- **summary_cell_counts**: Cell counts stratified by metadata categories + +### 2. census_data +Organism-specific `SOMAExperiment` objects: +- **"homo_sapiens"**: Human single-cell data +- **"mus_musculus"**: Mouse single-cell data + +## Data Structure Per Organism + +Each organism experiment contains: + +### obs (Cell Metadata) +Cell-level annotations stored as a `SOMADataFrame`. Access via: +```python +census["census_data"]["homo_sapiens"].obs +``` + +### ms["RNA"] (Measurement) +RNA measurement data including: +- **X**: Data matrices with layers: + - `raw`: Raw count data + - `normalized`: (if available) Normalized counts +- **var**: Gene metadata +- **feature_dataset_presence_matrix**: Sparse boolean array showing which genes were measured in each dataset + +## Cell Metadata Fields (obs) + +### Required/Core Fields + +**Identity & Dataset:** +- `soma_joinid`: Unique integer identifier for joins +- `dataset_id`: Source dataset identifier +- `is_primary_data`: Boolean flag (True = unique cell, False = duplicate across datasets) + +**Cell Type:** +- `cell_type`: Human-readable cell type name +- `cell_type_ontology_term_id`: Standardized ontology term (e.g., "CL:0000236") + +**Tissue:** +- `tissue`: Specific tissue name +- `tissue_general`: Broader tissue category (useful for grouping) +- `tissue_ontology_term_id`: Standardized ontology term + +**Assay:** +- `assay`: Sequencing technology used +- `assay_ontology_term_id`: Standardized ontology term + +**Disease:** +- `disease`: Disease status or condition +- `disease_ontology_term_id`: Standardized ontology term + +**Donor:** +- `donor_id`: Unique donor identifier +- `sex`: Biological sex (male, female, unknown) +- `self_reported_ethnicity`: Ethnicity information +- `development_stage`: Life stage (adult, child, embryonic, etc.) +- `development_stage_ontology_term_id`: Standardized ontology term + +**Organism:** +- `organism`: Scientific name (Homo sapiens, Mus musculus) +- `organism_ontology_term_id`: Standardized ontology term + +**Technical:** +- `suspension_type`: Sample preparation type (cell, nucleus, na) + +## Gene Metadata Fields (var) + +Access via: +```python +census["census_data"]["homo_sapiens"].ms["RNA"].var +``` + +**Available Fields:** +- `soma_joinid`: Unique integer identifier for joins +- `feature_id`: Ensembl gene ID (e.g., "ENSG00000161798") +- `feature_name`: Gene symbol (e.g., "FOXP2") +- `feature_length`: Gene length in base pairs + +## Value Filter Syntax + +Queries use Python-like expressions for filtering. The syntax is processed by TileDB-SOMA. 
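+
+These filter strings are the same ones accepted by the convenience helpers used throughout this skill (`value_filter` in `get_obs`/`get_var`, `obs_value_filter`/`var_value_filter` in `get_anndata`). A minimal usage sketch:
+
+```python
+import cellxgene_census
+
+with cellxgene_census.open_soma() as census:
+    # The filter string is evaluated against the obs fields documented above
+    b_cells = cellxgene_census.get_obs(
+        census,
+        "homo_sapiens",
+        value_filter="cell_type == 'B cell' and is_primary_data == True",
+        column_names=["cell_type", "tissue_general", "dataset_id"],
+    )
+```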
+ +### Comparison Operators +- `==`: Equal to +- `!=`: Not equal to +- `<`, `>`, `<=`, `>=`: Numeric comparisons +- `in`: Membership test (e.g., `feature_id in ['ENSG00000161798', 'ENSG00000188229']`) + +### Logical Operators +- `and`, `&`: Logical AND +- `or`, `|`: Logical OR + +### Examples + +**Single condition:** +```python +value_filter="cell_type == 'B cell'" +``` + +**Multiple conditions with AND:** +```python +value_filter="cell_type == 'B cell' and tissue_general == 'lung' and is_primary_data == True" +``` + +**Using IN for multiple values:** +```python +value_filter="tissue in ['lung', 'liver', 'kidney']" +``` + +**Complex condition:** +```python +value_filter="(cell_type == 'neuron' or cell_type == 'astrocyte') and disease != 'normal'" +``` + +**Filtering genes:** +```python +var_value_filter="feature_name in ['CD4', 'CD8A', 'CD19']" +``` + +## Data Inclusion Criteria + +The Census includes all data from CZ CELLxGENE Discover meeting: + +1. **Species**: Human (*Homo sapiens*) or mouse (*Mus musculus*) +2. **Technology**: Approved sequencing technologies for RNA +3. **Count Type**: Raw counts only (no processed/normalized-only data) +4. **Metadata**: Standardized following CELLxGENE schema +5. **Both spatial and non-spatial data**: Includes traditional and spatial transcriptomics + +## Important Data Characteristics + +### Duplicate Cells +Cells may appear across multiple datasets. Use `is_primary_data == True` to filter for unique cells in most analyses. + +### Count Types +The Census includes: +- **Molecule counts**: From UMI-based methods +- **Full-gene sequencing read counts**: From non-UMI methods +These may need different normalization approaches. + +### Versioning +Census releases are versioned (e.g., "2023-07-25", "stable"). Always specify version for reproducible analysis: +```python +census = cellxgene_census.open_soma(census_version="2023-07-25") +``` + +## Dataset Presence Matrix + +Access which genes were measured in each dataset: +```python +presence_matrix = census["census_data"]["homo_sapiens"].ms["RNA"]["feature_dataset_presence_matrix"] +``` + +This sparse boolean matrix helps understand: +- Gene coverage across datasets +- Which datasets to include for specific gene analyses +- Technical batch effects related to gene coverage + +## SOMA Object Types + +Core TileDB-SOMA objects used: +- **DataFrame**: Tabular data (obs, var) +- **SparseNDArray**: Sparse matrices (X layers, presence matrix) +- **DenseNDArray**: Dense arrays (less common) +- **Collection**: Container for related objects +- **Experiment**: Top-level container for measurements +- **SOMAScene**: Spatial transcriptomics scenes +- **obs_spatial_presence**: Spatial data availability diff --git a/scientific-packages/cellxgene-census/references/common_patterns.md b/scientific-packages/cellxgene-census/references/common_patterns.md new file mode 100644 index 0000000..8ca9ff8 --- /dev/null +++ b/scientific-packages/cellxgene-census/references/common_patterns.md @@ -0,0 +1,351 @@ +# Common Query Patterns and Best Practices + +## Query Pattern Categories + +### 1. Exploratory Queries (Metadata Only) + +Use when exploring available data without loading expression matrices. 
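+
+**Pattern: Quick overview from pre-computed counts**
+
+A minimal sketch; it assumes the `summary_cell_counts` table described in the schema reference is read the same way as the other `census_info` tables:
+
+```python
+import cellxgene_census
+
+with cellxgene_census.open_soma() as census:
+    # Pre-computed cell counts stratified by metadata category,
+    # useful before issuing any per-cell metadata query
+    counts = (
+        census["census_info"]["summary_cell_counts"]
+        .read()
+        .concat()
+        .to_pandas()
+    )
+    print(counts.head())
+```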
+ +**Pattern: Get unique cell types in a tissue** +```python +import cellxgene_census + +with cellxgene_census.open_soma() as census: + cell_metadata = cellxgene_census.get_obs( + census, + "homo_sapiens", + value_filter="tissue_general == 'brain' and is_primary_data == True", + column_names=["cell_type"] + ) + unique_cell_types = cell_metadata["cell_type"].unique() + print(f"Found {len(unique_cell_types)} unique cell types") +``` + +**Pattern: Count cells by condition** +```python +cell_metadata = cellxgene_census.get_obs( + census, + "homo_sapiens", + value_filter="disease != 'normal' and is_primary_data == True", + column_names=["disease", "tissue_general"] +) +counts = cell_metadata.groupby(["disease", "tissue_general"]).size() +``` + +**Pattern: Explore dataset information** +```python +# Access datasets table +datasets = census["census_info"]["datasets"].read().concat().to_pandas() + +# Filter for specific criteria +covid_datasets = datasets[datasets["disease"].str.contains("COVID", na=False)] +``` + +### 2. Small-to-Medium Queries (AnnData) + +Use `get_anndata()` when results fit in memory (typically < 100k cells). + +**Pattern: Tissue-specific cell type query** +```python +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="cell_type == 'B cell' and tissue_general == 'lung' and is_primary_data == True", + obs_column_names=["assay", "disease", "sex", "donor_id"], +) +``` + +**Pattern: Gene-specific query with multiple genes** +```python +marker_genes = ["CD4", "CD8A", "CD19", "FOXP3"] + +# First get gene IDs +gene_metadata = cellxgene_census.get_var( + census, "homo_sapiens", + value_filter=f"feature_name in {marker_genes}", + column_names=["feature_id", "feature_name"] +) +gene_ids = gene_metadata["feature_id"].tolist() + +# Query with gene filter +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + var_value_filter=f"feature_id in {gene_ids}", + obs_value_filter="cell_type == 'T cell' and is_primary_data == True", +) +``` + +**Pattern: Multi-tissue query** +```python +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="tissue_general in ['lung', 'liver', 'kidney'] and is_primary_data == True", + obs_column_names=["cell_type", "tissue_general", "dataset_id"], +) +``` + +**Pattern: Disease-specific query** +```python +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="disease == 'COVID-19' and tissue_general == 'lung' and is_primary_data == True", +) +``` + +### 3. Large Queries (Out-of-Core Processing) + +Use `axis_query()` for queries that exceed available RAM. 
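+
+**Pattern: Size a query before materializing it**
+
+A sketch of a quick pre-check; it assumes `tiledbsoma` is imported as `soma` and that the axis-query object exposes `n_obs`/`n_vars` counts:
+
+```python
+import tiledbsoma as soma
+
+query = census["census_data"]["homo_sapiens"].axis_query(
+    measurement_name="RNA",
+    obs_query=soma.AxisQuery(
+        value_filter="tissue_general == 'brain' and is_primary_data == True"
+    ),
+)
+
+# Decide between get_anndata() and chunked iteration based on the cell count
+print(f"Query spans {query.n_obs:,} cells x {query.n_vars:,} genes")
+```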
+ +**Pattern: Iterative processing** +```python +import pyarrow as pa + +# Create query +query = census["census_data"]["homo_sapiens"].axis_query( + measurement_name="RNA", + obs_query=soma.AxisQuery( + value_filter="tissue_general == 'brain' and is_primary_data == True" + ), + var_query=soma.AxisQuery( + value_filter="feature_name in ['FOXP2', 'TBR1', 'SATB2']" + ) +) + +# Iterate through X matrix in chunks +iterator = query.X("raw").tables() +for batch in iterator: + # Process batch (a pyarrow.Table) + # batch has columns: soma_data, soma_dim_0, soma_dim_1 + process_batch(batch) +``` + +**Pattern: Incremental statistics (mean/variance)** +```python +# Using Welford's online algorithm +n = 0 +mean = 0 +M2 = 0 + +iterator = query.X("raw").tables() +for batch in iterator: + values = batch["soma_data"].to_numpy() + for x in values: + n += 1 + delta = x - mean + mean += delta / n + delta2 = x - mean + M2 += delta * delta2 + +variance = M2 / (n - 1) if n > 1 else 0 +``` + +### 4. PyTorch Integration (Machine Learning) + +Use `experiment_dataloader()` for training models. + +**Pattern: Create training dataloader** +```python +from cellxgene_census.experimental.ml import experiment_dataloader +import torch + +with cellxgene_census.open_soma() as census: + # Create dataloader + dataloader = experiment_dataloader( + census["census_data"]["homo_sapiens"], + measurement_name="RNA", + X_name="raw", + obs_value_filter="tissue_general == 'liver' and is_primary_data == True", + obs_column_names=["cell_type"], + batch_size=128, + shuffle=True, + ) + + # Training loop + for epoch in range(num_epochs): + for batch in dataloader: + X = batch["X"] # Gene expression + labels = batch["obs"]["cell_type"] # Cell type labels + # Train model... +``` + +**Pattern: Train/test split** +```python +from cellxgene_census.experimental.ml import ExperimentDataset + +# Create dataset from query +dataset = ExperimentDataset( + experiment_axis_query, + layer_name="raw", + obs_column_names=["cell_type"], + batch_size=128, +) + +# Split data +train_dataset, test_dataset = dataset.random_split( + split=[0.8, 0.2], + seed=42 +) + +# Create loaders +train_loader = experiment_dataloader(train_dataset) +test_loader = experiment_dataloader(test_dataset) +``` + +### 5. Integration Workflows + +**Pattern: Scanpy integration** +```python +import scanpy as sc + +# Load data +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="cell_type == 'neuron' and is_primary_data == True", +) + +# Standard scanpy workflow +sc.pp.normalize_total(adata, target_sum=1e4) +sc.pp.log1p(adata) +sc.pp.highly_variable_genes(adata) +sc.pp.pca(adata) +sc.pp.neighbors(adata) +sc.tl.umap(adata) +sc.pl.umap(adata, color=["cell_type", "tissue_general"]) +``` + +**Pattern: Multi-dataset integration** +```python +# Query multiple datasets separately +datasets_to_integrate = ["dataset_id_1", "dataset_id_2", "dataset_id_3"] + +adatas = [] +for dataset_id in datasets_to_integrate: + adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter=f"dataset_id == '{dataset_id}' and is_primary_data == True", + ) + adatas.append(adata) + +# Integrate using scanorama, harmony, or other tools +import scanpy.external as sce +sce.pp.scanorama_integrate(adatas) +``` + +## Best Practices + +### 1. 
Always Filter for Primary Data +Unless specifically analyzing duplicates, always include `is_primary_data == True`: +```python +obs_value_filter="cell_type == 'B cell' and is_primary_data == True" +``` + +### 2. Specify Census Version +For reproducible analysis, always specify the Census version: +```python +census = cellxgene_census.open_soma(census_version="2023-07-25") +``` + +### 3. Use Context Manager +Always use the context manager to ensure proper cleanup: +```python +with cellxgene_census.open_soma() as census: + # Your code here +``` + +### 4. Select Only Needed Columns +Minimize data transfer by selecting only required metadata columns: +```python +obs_column_names=["cell_type", "tissue_general", "disease"] # Not all columns +``` + +### 5. Check Dataset Presence for Gene Queries +When analyzing specific genes, check which datasets measured them: +```python +presence = cellxgene_census.get_presence_matrix( + census, + "homo_sapiens", + var_value_filter="feature_name in ['CD4', 'CD8A']" +) +``` + +### 6. Use tissue_general for Broader Queries +`tissue_general` provides coarser groupings than `tissue`, useful for cross-tissue analyses: +```python +# Better for broad queries +obs_value_filter="tissue_general == 'immune system'" + +# Use specific tissue when needed +obs_value_filter="tissue == 'peripheral blood mononuclear cell'" +``` + +### 7. Combine Metadata Exploration with Expression Queries +First explore metadata to understand available data, then query expression: +```python +# Step 1: Explore +metadata = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="disease == 'COVID-19'", + column_names=["cell_type", "tissue_general"] +) +print(metadata.value_counts()) + +# Step 2: Query based on findings +adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter="disease == 'COVID-19' and cell_type == 'T cell' and is_primary_data == True", +) +``` + +### 8. Memory Management for Large Queries +For large queries, check estimated size before loading: +```python +# Get cell count first +metadata = cellxgene_census.get_obs( + census, "homo_sapiens", + value_filter="tissue_general == 'brain' and is_primary_data == True", + column_names=["soma_joinid"] +) +n_cells = len(metadata) +print(f"Query will return {n_cells} cells") + +# If too large, use out-of-core processing or further filtering +``` + +### 9. Leverage Ontology Terms for Consistency +When possible, use ontology term IDs instead of free text: +```python +# More reliable than cell_type == 'B cell' across datasets +obs_value_filter="cell_type_ontology_term_id == 'CL:0000236'" +``` + +### 10. Batch Processing Pattern +For systematic analyses across multiple conditions: +```python +tissues = ["lung", "liver", "kidney", "heart"] +results = {} + +for tissue in tissues: + adata = cellxgene_census.get_anndata( + census=census, + organism="Homo sapiens", + obs_value_filter=f"tissue_general == '{tissue}' and is_primary_data == True", + ) + # Perform analysis + results[tissue] = analyze(adata) +``` + +## Common Pitfalls to Avoid + +1. **Not filtering for is_primary_data**: Leads to counting duplicate cells +2. **Loading too much data**: Use metadata queries to estimate size first +3. **Not using context manager**: Can cause resource leaks +4. **Inconsistent versioning**: Results not reproducible without specifying version +5. **Overly broad queries**: Start with focused queries, expand as needed +6. **Ignoring dataset presence**: Some genes not measured in all datasets +7. 
**Wrong count normalization**: Be aware of UMI vs read count differences diff --git a/scientific-packages/cobrapy/SKILL.md b/scientific-packages/cobrapy/SKILL.md new file mode 100644 index 0000000..37b2d19 --- /dev/null +++ b/scientific-packages/cobrapy/SKILL.md @@ -0,0 +1,457 @@ +--- +name: cobrapy +description: Comprehensive toolkit for constraint-based reconstruction and analysis (COBRA) of metabolic models. Use when working with genome-scale metabolic models, performing flux balance analysis (FBA), simulating cellular metabolism, conducting gene/reaction knockout studies, gapfilling metabolic networks, analyzing flux distributions, calculating minimal media requirements, or any systems biology task involving computational modeling of cellular metabolism. Supports SBML, JSON, YAML, and MATLAB formats. +--- + +# COBRApy - Constraint-Based Reconstruction and Analysis + +## Overview + +COBRApy is a Python library for constraint-based reconstruction and analysis (COBRA) of metabolic models, essential for systems biology research. Use this skill to work with genome-scale metabolic models, perform computational simulations of cellular metabolism, conduct metabolic engineering analyses, and predict phenotypic behaviors. + +## Core Capabilities + +COBRApy provides comprehensive tools organized into several key areas: + +### 1. Model Management + +Load existing models from repositories or files: +```python +from cobra.io import load_model + +# Load bundled test models +model = load_model("textbook") # E. coli core model +model = load_model("ecoli") # Full E. coli model +model = load_model("salmonella") + +# Load from files +from cobra.io import read_sbml_model, load_json_model, load_yaml_model +model = read_sbml_model("path/to/model.xml") +model = load_json_model("path/to/model.json") +model = load_yaml_model("path/to/model.yml") +``` + +Save models in various formats: +```python +from cobra.io import write_sbml_model, save_json_model, save_yaml_model +write_sbml_model(model, "output.xml") # Preferred format +save_json_model(model, "output.json") # For Escher compatibility +save_yaml_model(model, "output.yml") # Human-readable +``` + +### 2. Model Structure and Components + +Access and inspect model components: +```python +# Access components +model.reactions # DictList of all reactions +model.metabolites # DictList of all metabolites +model.genes # DictList of all genes + +# Get specific items by ID or index +reaction = model.reactions.get_by_id("PFK") +metabolite = model.metabolites[0] + +# Inspect properties +print(reaction.reaction) # Stoichiometric equation +print(reaction.bounds) # Flux constraints +print(reaction.gene_reaction_rule) # GPR logic +print(metabolite.formula) # Chemical formula +print(metabolite.compartment) # Cellular location +``` + +### 3. Flux Balance Analysis (FBA) + +Perform standard FBA simulation: +```python +# Basic optimization +solution = model.optimize() +print(f"Objective value: {solution.objective_value}") +print(f"Status: {solution.status}") + +# Access fluxes +print(solution.fluxes["PFK"]) +print(solution.fluxes.head()) + +# Fast optimization (objective value only) +objective_value = model.slim_optimize() + +# Change objective +model.objective = "ATPM" +solution = model.optimize() +``` + +Parsimonious FBA (minimize total flux): +```python +from cobra.flux_analysis import pfba +solution = pfba(model) +``` + +Geometric FBA (find central solution): +```python +from cobra.flux_analysis import geometric_fba +solution = geometric_fba(model) +``` + +### 4. 
Flux Variability Analysis (FVA) + +Determine flux ranges for all reactions: +```python +from cobra.flux_analysis import flux_variability_analysis + +# Standard FVA +fva_result = flux_variability_analysis(model) + +# FVA at 90% optimality +fva_result = flux_variability_analysis(model, fraction_of_optimum=0.9) + +# Loopless FVA (eliminates thermodynamically infeasible loops) +fva_result = flux_variability_analysis(model, loopless=True) + +# FVA for specific reactions +fva_result = flux_variability_analysis( + model, + reaction_list=["PFK", "FBA", "PGI"] +) +``` + +### 5. Gene and Reaction Deletion Studies + +Perform knockout analyses: +```python +from cobra.flux_analysis import ( + single_gene_deletion, + single_reaction_deletion, + double_gene_deletion, + double_reaction_deletion +) + +# Single deletions +gene_results = single_gene_deletion(model) +reaction_results = single_reaction_deletion(model) + +# Double deletions (uses multiprocessing) +double_gene_results = double_gene_deletion( + model, + processes=4 # Number of CPU cores +) + +# Manual knockout using context manager +with model: + model.genes.get_by_id("b0008").knock_out() + solution = model.optimize() + print(f"Growth after knockout: {solution.objective_value}") +# Model automatically reverts after context exit +``` + +### 6. Growth Media and Minimal Media + +Manage growth medium: +```python +# View current medium +print(model.medium) + +# Modify medium (must reassign entire dict) +medium = model.medium +medium["EX_glc__D_e"] = 10.0 # Set glucose uptake +medium["EX_o2_e"] = 0.0 # Anaerobic conditions +model.medium = medium + +# Calculate minimal media +from cobra.medium import minimal_medium + +# Minimize total import flux +min_medium = minimal_medium(model, minimize_components=False) + +# Minimize number of components (uses MILP, slower) +min_medium = minimal_medium( + model, + minimize_components=True, + open_exchanges=True +) +``` + +### 7. Flux Sampling + +Sample the feasible flux space: +```python +from cobra.sampling import sample + +# Sample using OptGP (default, supports parallel processing) +samples = sample(model, n=1000, method="optgp", processes=4) + +# Sample using ACHR +samples = sample(model, n=1000, method="achr") + +# Validate samples +from cobra.sampling import OptGPSampler +sampler = OptGPSampler(model, processes=4) +sampler.sample(1000) +validation = sampler.validate(sampler.samples) +print(validation.value_counts()) # Should be all 'v' for valid +``` + +### 8. Production Envelopes + +Calculate phenotype phase planes: +```python +from cobra.flux_analysis import production_envelope + +# Standard production envelope +envelope = production_envelope( + model, + reactions=["EX_glc__D_e", "EX_o2_e"], + objective="EX_ac_e" # Acetate production +) + +# With carbon yield +envelope = production_envelope( + model, + reactions=["EX_glc__D_e", "EX_o2_e"], + carbon_sources="EX_glc__D_e" +) + +# Visualize (use matplotlib or pandas plotting) +import matplotlib.pyplot as plt +envelope.plot(x="EX_glc__D_e", y="EX_o2_e", kind="scatter") +plt.show() +``` + +### 9. Gapfilling + +Add reactions to make models feasible: +```python +from cobra.flux_analysis import gapfill + +# Prepare universal model with candidate reactions +universal = load_model("universal") + +# Perform gapfilling +with model: + # Remove reactions to create gaps for demonstration + model.remove_reactions([model.reactions.PGI]) + + # Find reactions needed + solution = gapfill(model, universal) + print(f"Reactions to add: {solution}") +``` + +### 10. 
Model Building + +Build models from scratch: +```python +from cobra import Model, Reaction, Metabolite + +# Create model +model = Model("my_model") + +# Create metabolites +atp_c = Metabolite("atp_c", formula="C10H12N5O13P3", + name="ATP", compartment="c") +adp_c = Metabolite("adp_c", formula="C10H12N5O10P2", + name="ADP", compartment="c") +pi_c = Metabolite("pi_c", formula="HO4P", + name="Phosphate", compartment="c") + +# Create reaction +reaction = Reaction("ATPASE") +reaction.name = "ATP hydrolysis" +reaction.subsystem = "Energy" +reaction.lower_bound = 0.0 +reaction.upper_bound = 1000.0 + +# Add metabolites with stoichiometry +reaction.add_metabolites({ + atp_c: -1.0, + adp_c: 1.0, + pi_c: 1.0 +}) + +# Add gene-reaction rule +reaction.gene_reaction_rule = "(gene1 and gene2) or gene3" + +# Add to model +model.add_reactions([reaction]) + +# Add boundary reactions +model.add_boundary(atp_c, type="exchange") +model.add_boundary(adp_c, type="demand") + +# Set objective +model.objective = "ATPASE" +``` + +## Common Workflows + +### Workflow 1: Load Model and Predict Growth + +```python +from cobra.io import load_model + +# Load model +model = load_model("ecoli") + +# Run FBA +solution = model.optimize() +print(f"Growth rate: {solution.objective_value:.3f} /h") + +# Show active pathways +print(solution.fluxes[solution.fluxes.abs() > 1e-6]) +``` + +### Workflow 2: Gene Knockout Screen + +```python +from cobra.io import load_model +from cobra.flux_analysis import single_gene_deletion + +# Load model +model = load_model("ecoli") + +# Perform single gene deletions +results = single_gene_deletion(model) + +# Find essential genes (growth < threshold) +essential_genes = results[results["growth"] < 0.01] +print(f"Found {len(essential_genes)} essential genes") + +# Find genes with minimal impact +neutral_genes = results[results["growth"] > 0.9 * solution.objective_value] +``` + +### Workflow 3: Media Optimization + +```python +from cobra.io import load_model +from cobra.medium import minimal_medium + +# Load model +model = load_model("ecoli") + +# Calculate minimal medium for 50% of max growth +target_growth = model.slim_optimize() * 0.5 +min_medium = minimal_medium( + model, + target_growth, + minimize_components=True +) + +print(f"Minimal medium components: {len(min_medium)}") +print(min_medium) +``` + +### Workflow 4: Flux Uncertainty Analysis + +```python +from cobra.io import load_model +from cobra.flux_analysis import flux_variability_analysis +from cobra.sampling import sample + +# Load model +model = load_model("ecoli") + +# First check flux ranges at optimality +fva = flux_variability_analysis(model, fraction_of_optimum=1.0) + +# For reactions with large ranges, sample to understand distribution +samples = sample(model, n=1000) + +# Analyze specific reaction +reaction_id = "PFK" +import matplotlib.pyplot as plt +samples[reaction_id].hist(bins=50) +plt.xlabel(f"Flux through {reaction_id}") +plt.ylabel("Frequency") +plt.show() +``` + +### Workflow 5: Context Manager for Temporary Changes + +Use context managers to make temporary modifications: +```python +# Model remains unchanged outside context +with model: + # Temporarily change objective + model.objective = "ATPM" + + # Temporarily modify bounds + model.reactions.EX_glc__D_e.lower_bound = -5.0 + + # Temporarily knock out genes + model.genes.b0008.knock_out() + + # Optimize with changes + solution = model.optimize() + print(f"Modified growth: {solution.objective_value}") + +# All changes automatically reverted +solution = model.optimize() 
+print(f"Original growth: {solution.objective_value}") +``` + +## Key Concepts + +### DictList Objects +Models use `DictList` objects for reactions, metabolites, and genes - behaving like both lists and dictionaries: +```python +# Access by index +first_reaction = model.reactions[0] + +# Access by ID +pfk = model.reactions.get_by_id("PFK") + +# Query methods +atp_reactions = model.reactions.query("atp") +``` + +### Flux Constraints +Reaction bounds define feasible flux ranges: +- **Irreversible**: `lower_bound = 0, upper_bound > 0` +- **Reversible**: `lower_bound < 0, upper_bound > 0` +- Set both bounds simultaneously with `.bounds` to avoid inconsistencies + +### Gene-Reaction Rules (GPR) +Boolean logic linking genes to reactions: +```python +# AND logic (both required) +reaction.gene_reaction_rule = "gene1 and gene2" + +# OR logic (either sufficient) +reaction.gene_reaction_rule = "gene1 or gene2" + +# Complex logic +reaction.gene_reaction_rule = "(gene1 and gene2) or (gene3 and gene4)" +``` + +### Exchange Reactions +Special reactions representing metabolite import/export: +- Named with prefix `EX_` by convention +- Positive flux = secretion, negative flux = uptake +- Managed through `model.medium` dictionary + +## Best Practices + +1. **Use context managers** for temporary modifications to avoid state management issues +2. **Validate models** before analysis using `model.slim_optimize()` to ensure feasibility +3. **Check solution status** after optimization - `optimal` indicates successful solve +4. **Use loopless FVA** when thermodynamic feasibility matters +5. **Set fraction_of_optimum** appropriately in FVA to explore suboptimal space +6. **Parallelize** computationally expensive operations (sampling, double deletions) +7. **Prefer SBML format** for model exchange and long-term storage +8. **Use slim_optimize()** when only objective value needed for performance +9. **Validate flux samples** to ensure numerical stability + +## Troubleshooting + +**Infeasible solutions**: Check medium constraints, reaction bounds, and model consistency +**Slow optimization**: Try different solvers (GLPK, CPLEX, Gurobi) via `model.solver` +**Unbounded solutions**: Verify exchange reactions have appropriate upper bounds +**Import errors**: Ensure correct file format and valid SBML identifiers + +## References + +For detailed workflows and API patterns, refer to: +- `references/workflows.md` - Comprehensive step-by-step workflow examples +- `references/api_quick_reference.md` - Common function signatures and patterns + +Official documentation: https://cobrapy.readthedocs.io/en/latest/ diff --git a/scientific-packages/cobrapy/references/api_quick_reference.md b/scientific-packages/cobrapy/references/api_quick_reference.md new file mode 100644 index 0000000..c7b4922 --- /dev/null +++ b/scientific-packages/cobrapy/references/api_quick_reference.md @@ -0,0 +1,655 @@ +# COBRApy API Quick Reference + +This document provides quick reference for common COBRApy functions, signatures, and usage patterns. + +## Model I/O + +### Loading Models + +```python +from cobra.io import load_model, read_sbml_model, load_json_model, load_yaml_model, load_matlab_model + +# Bundled test models +model = load_model("textbook") # E. coli core metabolism +model = load_model("ecoli") # Full E. 
coli iJO1366 +model = load_model("salmonella") # Salmonella LT2 + +# From files +model = read_sbml_model(filename, f_replace={}, **kwargs) +model = load_json_model(filename) +model = load_yaml_model(filename) +model = load_matlab_model(filename, variable_name=None) +``` + +### Saving Models + +```python +from cobra.io import write_sbml_model, save_json_model, save_yaml_model, save_matlab_model + +write_sbml_model(model, filename, f_replace={}, **kwargs) +save_json_model(model, filename, pretty=False, **kwargs) +save_yaml_model(model, filename, **kwargs) +save_matlab_model(model, filename, **kwargs) +``` + +## Model Structure + +### Core Classes + +```python +from cobra import Model, Reaction, Metabolite, Gene + +# Create model +model = Model(id_or_model=None, name=None) + +# Create metabolite +metabolite = Metabolite( + id=None, + formula=None, + name="", + charge=None, + compartment=None +) + +# Create reaction +reaction = Reaction( + id=None, + name="", + subsystem="", + lower_bound=0.0, + upper_bound=None +) + +# Create gene +gene = Gene(id=None, name="", functional=True) +``` + +### Model Attributes + +```python +# Component access (DictList objects) +model.reactions # DictList of Reaction objects +model.metabolites # DictList of Metabolite objects +model.genes # DictList of Gene objects + +# Special reaction lists +model.exchanges # Exchange reactions (external transport) +model.demands # Demand reactions (metabolite sinks) +model.sinks # Sink reactions +model.boundary # All boundary reactions + +# Model properties +model.objective # Current objective (read/write) +model.objective_direction # "max" or "min" +model.medium # Growth medium (dict of exchange: bound) +model.solver # Optimization solver +``` + +### DictList Methods + +```python +# Access by index +item = model.reactions[0] + +# Access by ID +item = model.reactions.get_by_id("PFK") + +# Query by string (substring match) +items = model.reactions.query("atp") # Case-insensitive search +items = model.reactions.query(lambda x: x.subsystem == "Glycolysis") + +# List comprehension +items = [r for r in model.reactions if r.lower_bound < 0] + +# Check membership +"PFK" in model.reactions +``` + +## Optimization + +### Basic Optimization + +```python +# Full optimization (returns Solution object) +solution = model.optimize() + +# Attributes of Solution +solution.objective_value # Objective function value +solution.status # Optimization status ("optimal", "infeasible", etc.) +solution.fluxes # Pandas Series of reaction fluxes +solution.shadow_prices # Pandas Series of metabolite shadow prices +solution.reduced_costs # Pandas Series of reduced costs + +# Fast optimization (returns float only) +objective_value = model.slim_optimize() + +# Change objective +model.objective = "ATPM" +model.objective = model.reactions.ATPM +model.objective = {model.reactions.ATPM: 1.0} + +# Change optimization direction +model.objective_direction = "max" # or "min" +``` + +### Solver Configuration + +```python +# Check available solvers +from cobra.util.solver import solvers +print(solvers) + +# Change solver +model.solver = "glpk" # or "cplex", "gurobi", etc. 
+ +# Solver-specific configuration +model.solver.configuration.timeout = 60 # seconds +model.solver.configuration.verbosity = 1 +model.solver.configuration.tolerances.feasibility = 1e-9 +``` + +## Flux Analysis + +### Flux Balance Analysis (FBA) + +```python +from cobra.flux_analysis import pfba, geometric_fba + +# Parsimonious FBA +solution = pfba(model, fraction_of_optimum=1.0, **kwargs) + +# Geometric FBA +solution = geometric_fba(model, epsilon=1e-06, max_tries=200) +``` + +### Flux Variability Analysis (FVA) + +```python +from cobra.flux_analysis import flux_variability_analysis + +fva_result = flux_variability_analysis( + model, + reaction_list=None, # List of reaction IDs or None for all + loopless=False, # Eliminate thermodynamically infeasible loops + fraction_of_optimum=1.0, # Optimality fraction (0.0-1.0) + pfba_factor=None, # Optional pFBA constraint + processes=1 # Number of parallel processes +) + +# Returns DataFrame with columns: minimum, maximum +``` + +### Gene and Reaction Deletions + +```python +from cobra.flux_analysis import ( + single_gene_deletion, + single_reaction_deletion, + double_gene_deletion, + double_reaction_deletion +) + +# Single deletions +results = single_gene_deletion( + model, + gene_list=None, # None for all genes + processes=1, + **kwargs +) + +results = single_reaction_deletion( + model, + reaction_list=None, # None for all reactions + processes=1, + **kwargs +) + +# Double deletions +results = double_gene_deletion( + model, + gene_list1=None, + gene_list2=None, + processes=1, + **kwargs +) + +results = double_reaction_deletion( + model, + reaction_list1=None, + reaction_list2=None, + processes=1, + **kwargs +) + +# Returns DataFrame with columns: ids, growth, status +# For double deletions, index is MultiIndex of gene/reaction pairs +``` + +### Flux Sampling + +```python +from cobra.sampling import sample, OptGPSampler, ACHRSampler + +# Simple interface +samples = sample( + model, + n, # Number of samples + method="optgp", # or "achr" + thinning=100, # Thinning factor (sample every n iterations) + processes=1, # Parallel processes (OptGP only) + seed=None # Random seed +) + +# Advanced interface with sampler objects +sampler = OptGPSampler(model, processes=4, thinning=100) +sampler = ACHRSampler(model, thinning=100) + +# Generate samples +samples = sampler.sample(n) + +# Validate samples +validation = sampler.validate(sampler.samples) +# Returns array of 'v' (valid), 'l' (lower bound violation), +# 'u' (upper bound violation), 'e' (equality violation) + +# Batch sampling +sampler.batch(n_samples, n_batches) +``` + +### Production Envelopes + +```python +from cobra.flux_analysis import production_envelope + +envelope = production_envelope( + model, + reactions, # List of 1-2 reaction IDs + objective=None, # Objective reaction ID (None uses model objective) + carbon_sources=None, # Carbon source for yield calculation + points=20, # Number of points to calculate + threshold=0.01 # Minimum objective value threshold +) + +# Returns DataFrame with columns: +# - First reaction flux +# - Second reaction flux (if provided) +# - objective_minimum, objective_maximum +# - carbon_yield_minimum, carbon_yield_maximum (if carbon source specified) +# - mass_yield_minimum, mass_yield_maximum +``` + +### Gapfilling + +```python +from cobra.flux_analysis import gapfill + +# Basic gapfilling +solution = gapfill( + model, + universal=None, # Universal model with candidate reactions + lower_bound=0.05, # Minimum objective flux + penalties=None, # Dict of reaction: 
penalty + demand_reactions=True, # Add demand reactions if needed + exchange_reactions=False, + iterations=1 +) + +# Returns list of Reaction objects to add + +# Multiple solutions +solutions = [] +for i in range(5): + sol = gapfill(model, universal, iterations=1) + solutions.append(sol) + # Prevent finding same solution by increasing penalties +``` + +### Other Analysis Methods + +```python +from cobra.flux_analysis import ( + find_blocked_reactions, + find_essential_genes, + find_essential_reactions +) + +# Blocked reactions (cannot carry flux) +blocked = find_blocked_reactions( + model, + reaction_list=None, + zero_cutoff=1e-9, + open_exchanges=False +) + +# Essential genes/reactions +essential_genes = find_essential_genes(model, threshold=0.01) +essential_reactions = find_essential_reactions(model, threshold=0.01) +``` + +## Media and Boundary Conditions + +### Medium Management + +```python +# Get current medium (returns dict) +medium = model.medium + +# Set medium (must reassign entire dict) +medium = model.medium +medium["EX_glc__D_e"] = 10.0 +medium["EX_o2_e"] = 20.0 +model.medium = medium + +# Alternative: individual modification +with model: + model.reactions.EX_glc__D_e.lower_bound = -10.0 +``` + +### Minimal Media + +```python +from cobra.medium import minimal_medium + +min_medium = minimal_medium( + model, + min_objective_value=0.1, # Minimum growth rate + minimize_components=False, # If True, uses MILP (slower) + open_exchanges=False, # Open all exchanges before optimization + exports=False, # Allow metabolite export + penalties=None # Dict of exchange: penalty +) + +# Returns Series of exchange reactions with fluxes +``` + +### Boundary Reactions + +```python +# Add boundary reaction +model.add_boundary( + metabolite, + type="exchange", # or "demand", "sink" + reaction_id=None, # Auto-generated if None + lb=None, + ub=None, + sbo_term=None +) + +# Access boundary reactions +exchanges = model.exchanges # System boundary +demands = model.demands # Intracellular removal +sinks = model.sinks # Intracellular exchange +boundaries = model.boundary # All boundary reactions +``` + +## Model Manipulation + +### Adding Components + +```python +# Add reactions +model.add_reactions([reaction1, reaction2, ...]) +model.add_reaction(reaction) + +# Add metabolites +reaction.add_metabolites({ + metabolite1: -1.0, # Consumed (negative stoichiometry) + metabolite2: 1.0 # Produced (positive stoichiometry) +}) + +# Add metabolites to model +model.add_metabolites([metabolite1, metabolite2, ...]) + +# Add genes (usually automatic via gene_reaction_rule) +model.genes += [gene1, gene2, ...] 
+``` + +### Removing Components + +```python +# Remove reactions +model.remove_reactions([reaction1, reaction2, ...]) +model.remove_reactions(["PFK", "FBA"]) + +# Remove metabolites (removes from reactions too) +model.remove_metabolites([metabolite1, metabolite2, ...]) + +# Remove genes (usually via gene_reaction_rule) +model.genes.remove(gene) +``` + +### Modifying Reactions + +```python +# Set bounds +reaction.bounds = (lower, upper) +reaction.lower_bound = 0.0 +reaction.upper_bound = 1000.0 + +# Modify stoichiometry +reaction.add_metabolites({metabolite: 1.0}) +reaction.subtract_metabolites({metabolite: 1.0}) + +# Change gene-reaction rule +reaction.gene_reaction_rule = "(gene1 and gene2) or gene3" + +# Knock out +reaction.knock_out() +gene.knock_out() +``` + +### Model Copying + +```python +# Deep copy (independent model) +model_copy = model.copy() + +# Copy specific reactions +new_model = Model("subset") +reactions_to_copy = [model.reactions.PFK, model.reactions.FBA] +new_model.add_reactions(reactions_to_copy) +``` + +## Context Management + +Use context managers for temporary modifications: + +```python +# Changes automatically revert after with block +with model: + model.objective = "ATPM" + model.reactions.EX_glc__D_e.lower_bound = -5.0 + model.genes.b0008.knock_out() + solution = model.optimize() + +# Model state restored here + +# Multiple nested contexts +with model: + model.objective = "ATPM" + with model: + model.genes.b0008.knock_out() + # Both modifications active + # Only objective change active + +# Context management with reactions +with model: + model.reactions.PFK.knock_out() + # Equivalent to: reaction.lower_bound = reaction.upper_bound = 0 +``` + +## Reaction and Metabolite Properties + +### Reaction Attributes + +```python +reaction.id # Unique identifier +reaction.name # Human-readable name +reaction.subsystem # Pathway/subsystem +reaction.bounds # (lower_bound, upper_bound) +reaction.lower_bound +reaction.upper_bound +reaction.reversibility # Boolean (lower_bound < 0) +reaction.gene_reaction_rule # GPR string +reaction.genes # Set of associated Gene objects +reaction.metabolites # Dict of {metabolite: stoichiometry} + +# Methods +reaction.reaction # Stoichiometric equation string +reaction.build_reaction_string() # Same as above +reaction.check_mass_balance() # Returns imbalances or empty dict +reaction.get_coefficient(metabolite_id) +reaction.add_metabolites({metabolite: coeff}) +reaction.subtract_metabolites({metabolite: coeff}) +reaction.knock_out() +``` + +### Metabolite Attributes + +```python +metabolite.id # Unique identifier +metabolite.name # Human-readable name +metabolite.formula # Chemical formula +metabolite.charge # Charge +metabolite.compartment # Compartment ID +metabolite.reactions # FrozenSet of associated reactions + +# Methods +metabolite.summary() # Print production/consumption +metabolite.copy() +``` + +### Gene Attributes + +```python +gene.id # Unique identifier +gene.name # Human-readable name +gene.functional # Boolean activity status +gene.reactions # FrozenSet of associated reactions + +# Methods +gene.knock_out() +``` + +## Model Validation + +### Consistency Checking + +```python +from cobra.manipulation import check_mass_balance, check_metabolite_compartment_formula + +# Check all reactions for mass balance +unbalanced = {} +for reaction in model.reactions: + balance = reaction.check_mass_balance() + if balance: + unbalanced[reaction.id] = balance + +# Check metabolite formulas are valid +check_metabolite_compartment_formula(model) 
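+
+# Example follow-up: summarize the imbalances collected in the loop above
+print(f"Reactions with mass-balance problems: {len(unbalanced)}")
+for rxn_id, imbalance in unbalanced.items():
+    print(f"  {rxn_id}: {imbalance}")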
+``` + +### Model Statistics + +```python +# Basic stats +print(f"Reactions: {len(model.reactions)}") +print(f"Metabolites: {len(model.metabolites)}") +print(f"Genes: {len(model.genes)}") + +# Advanced stats +print(f"Exchanges: {len(model.exchanges)}") +print(f"Demands: {len(model.demands)}") + +# Blocked reactions +from cobra.flux_analysis import find_blocked_reactions +blocked = find_blocked_reactions(model) +print(f"Blocked reactions: {len(blocked)}") + +# Essential genes +from cobra.flux_analysis import find_essential_genes +essential = find_essential_genes(model) +print(f"Essential genes: {len(essential)}") +``` + +## Summary Methods + +```python +# Model summary +model.summary() # Overall model info + +# Metabolite summary +model.metabolites.atp_c.summary() + +# Reaction summary +model.reactions.PFK.summary() + +# Summary with FVA +model.summary(fva=0.95) # Include FVA at 95% optimality +``` + +## Common Patterns + +### Batch Analysis Pattern + +```python +results = [] +for condition in conditions: + with model: + # Apply condition + setup_condition(model, condition) + + # Analyze + solution = model.optimize() + + # Store result + results.append({ + "condition": condition, + "growth": solution.objective_value, + "status": solution.status + }) + +df = pd.DataFrame(results) +``` + +### Systematic Knockout Pattern + +```python +knockout_results = [] +for gene in model.genes: + with model: + gene.knock_out() + + solution = model.optimize() + + knockout_results.append({ + "gene": gene.id, + "growth": solution.objective_value if solution.status == "optimal" else 0, + "status": solution.status + }) + +df = pd.DataFrame(knockout_results) +``` + +### Parameter Scan Pattern + +```python +parameter_values = np.linspace(0, 20, 21) +results = [] + +for value in parameter_values: + with model: + model.reactions.EX_glc__D_e.lower_bound = -value + + solution = model.optimize() + + results.append({ + "glucose_uptake": value, + "growth": solution.objective_value, + "acetate_secretion": solution.fluxes["EX_ac_e"] + }) + +df = pd.DataFrame(results) +``` + +This quick reference covers the most commonly used COBRApy functions and patterns. For complete API documentation, see https://cobrapy.readthedocs.io/ diff --git a/scientific-packages/cobrapy/references/workflows.md b/scientific-packages/cobrapy/references/workflows.md new file mode 100644 index 0000000..08d5e5a --- /dev/null +++ b/scientific-packages/cobrapy/references/workflows.md @@ -0,0 +1,593 @@ +# COBRApy Comprehensive Workflows + +This document provides detailed step-by-step workflows for common COBRApy tasks in metabolic modeling. + +## Workflow 1: Complete Knockout Study with Visualization + +This workflow demonstrates how to perform a comprehensive gene knockout study and visualize the results. 
+ +```python +import pandas as pd +import matplotlib.pyplot as plt +from cobra.io import load_model +from cobra.flux_analysis import single_gene_deletion, double_gene_deletion + +# Step 1: Load model +model = load_model("ecoli") +print(f"Loaded model: {model.id}") +print(f"Model contains {len(model.reactions)} reactions, {len(model.metabolites)} metabolites, {len(model.genes)} genes") + +# Step 2: Get baseline growth rate +baseline = model.slim_optimize() +print(f"Baseline growth rate: {baseline:.3f} /h") + +# Step 3: Perform single gene deletions +print("Performing single gene deletions...") +single_results = single_gene_deletion(model) + +# Step 4: Classify genes by impact +essential_genes = single_results[single_results["growth"] < 0.01] +severely_impaired = single_results[(single_results["growth"] >= 0.01) & + (single_results["growth"] < 0.5 * baseline)] +moderately_impaired = single_results[(single_results["growth"] >= 0.5 * baseline) & + (single_results["growth"] < 0.9 * baseline)] +neutral_genes = single_results[single_results["growth"] >= 0.9 * baseline] + +print(f"\nSingle Deletion Results:") +print(f" Essential genes: {len(essential_genes)}") +print(f" Severely impaired: {len(severely_impaired)}") +print(f" Moderately impaired: {len(moderately_impaired)}") +print(f" Neutral genes: {len(neutral_genes)}") + +# Step 5: Visualize distribution +fig, ax = plt.subplots(figsize=(10, 6)) +single_results["growth"].hist(bins=50, ax=ax) +ax.axvline(baseline, color='r', linestyle='--', label='Baseline') +ax.set_xlabel("Growth rate (/h)") +ax.set_ylabel("Number of genes") +ax.set_title("Distribution of Growth Rates After Single Gene Deletions") +ax.legend() +plt.tight_layout() +plt.savefig("single_deletion_distribution.png", dpi=300) + +# Step 6: Identify gene pairs for double deletions +# Focus on non-essential genes to find synthetic lethals +target_genes = single_results[single_results["growth"] >= 0.5 * baseline].index.tolist() +target_genes = [list(gene)[0] for gene in target_genes[:50]] # Limit for performance + +print(f"\nPerforming double deletions on {len(target_genes)} genes...") +double_results = double_gene_deletion( + model, + gene_list1=target_genes, + processes=4 +) + +# Step 7: Find synthetic lethal pairs +synthetic_lethals = double_results[ + (double_results["growth"] < 0.01) & + (single_results.loc[double_results.index.get_level_values(0)]["growth"].values >= 0.5 * baseline) & + (single_results.loc[double_results.index.get_level_values(1)]["growth"].values >= 0.5 * baseline) +] + +print(f"Found {len(synthetic_lethals)} synthetic lethal gene pairs") +print("\nTop 10 synthetic lethal pairs:") +print(synthetic_lethals.head(10)) + +# Step 8: Export results +single_results.to_csv("single_gene_deletions.csv") +double_results.to_csv("double_gene_deletions.csv") +synthetic_lethals.to_csv("synthetic_lethals.csv") +``` + +## Workflow 2: Media Design and Optimization + +This workflow shows how to systematically design growth media and find minimal media compositions. 
+ +```python +from cobra.io import load_model +from cobra.medium import minimal_medium +import pandas as pd + +# Step 1: Load model and check current medium +model = load_model("ecoli") +current_medium = model.medium +print("Current medium composition:") +for exchange, bound in current_medium.items(): + metabolite_id = exchange.replace("EX_", "").replace("_e", "") + print(f" {metabolite_id}: {bound:.2f} mmol/gDW/h") + +# Step 2: Get baseline growth +baseline_growth = model.slim_optimize() +print(f"\nBaseline growth rate: {baseline_growth:.3f} /h") + +# Step 3: Calculate minimal medium for different growth targets +growth_targets = [0.25, 0.5, 0.75, 1.0] +minimal_media = {} + +for fraction in growth_targets: + target_growth = baseline_growth * fraction + print(f"\nCalculating minimal medium for {fraction*100:.0f}% growth ({target_growth:.3f} /h)...") + + min_medium = minimal_medium( + model, + target_growth, + minimize_components=True, + open_exchanges=True + ) + + minimal_media[fraction] = min_medium + print(f" Required components: {len(min_medium)}") + print(f" Components: {list(min_medium.index)}") + +# Step 4: Compare media compositions +media_df = pd.DataFrame(minimal_media).fillna(0) +media_df.to_csv("minimal_media_comparison.csv") + +# Step 5: Test aerobic vs anaerobic conditions +print("\n--- Aerobic vs Anaerobic Comparison ---") + +# Aerobic +model_aerobic = model.copy() +aerobic_growth = model_aerobic.slim_optimize() +aerobic_medium = minimal_medium(model_aerobic, aerobic_growth * 0.9, minimize_components=True) + +# Anaerobic +model_anaerobic = model.copy() +medium_anaerobic = model_anaerobic.medium +medium_anaerobic["EX_o2_e"] = 0.0 +model_anaerobic.medium = medium_anaerobic +anaerobic_growth = model_anaerobic.slim_optimize() +anaerobic_medium = minimal_medium(model_anaerobic, anaerobic_growth * 0.9, minimize_components=True) + +print(f"Aerobic growth: {aerobic_growth:.3f} /h (requires {len(aerobic_medium)} components)") +print(f"Anaerobic growth: {anaerobic_growth:.3f} /h (requires {len(anaerobic_medium)} components)") + +# Step 6: Identify unique requirements +aerobic_only = set(aerobic_medium.index) - set(anaerobic_medium.index) +anaerobic_only = set(anaerobic_medium.index) - set(aerobic_medium.index) +shared = set(aerobic_medium.index) & set(anaerobic_medium.index) + +print(f"\nShared components: {len(shared)}") +print(f"Aerobic-only: {aerobic_only}") +print(f"Anaerobic-only: {anaerobic_only}") + +# Step 7: Test custom medium +print("\n--- Testing Custom Medium ---") +custom_medium = { + "EX_glc__D_e": 10.0, # Glucose + "EX_o2_e": 20.0, # Oxygen + "EX_nh4_e": 5.0, # Ammonium + "EX_pi_e": 5.0, # Phosphate + "EX_so4_e": 1.0, # Sulfate +} + +with model: + model.medium = custom_medium + custom_growth = model.optimize().objective_value + print(f"Growth on custom medium: {custom_growth:.3f} /h") + + # Check which nutrients are limiting + for exchange in custom_medium: + with model: + # Double the uptake rate + medium_test = model.medium + medium_test[exchange] *= 2 + model.medium = medium_test + test_growth = model.optimize().objective_value + improvement = (test_growth - custom_growth) / custom_growth * 100 + if improvement > 1: + print(f" {exchange}: +{improvement:.1f}% growth when doubled (LIMITING)") +``` + +## Workflow 3: Flux Space Exploration with Sampling + +This workflow demonstrates comprehensive flux space analysis using FVA and sampling. 
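+
+Flux sampling is the most expensive step in this workflow, so it can be worth confirming that the sampler runs at all with a handful of samples before requesting the full 1000. An optional sketch using the same model:
+
+```python
+from cobra.io import load_model
+from cobra.sampling import sample
+
+model = load_model("ecoli")
+trial = sample(model, n=10, method="optgp")   # small trial run only
+print(trial.shape)                            # rows = samples, columns = reactions
+```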
+ +```python +from cobra.io import load_model +from cobra.flux_analysis import flux_variability_analysis +from cobra.sampling import sample +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +# Step 1: Load model +model = load_model("ecoli") +baseline = model.slim_optimize() +print(f"Baseline growth: {baseline:.3f} /h") + +# Step 2: Perform FVA at optimal growth +print("\nPerforming FVA at optimal growth...") +fva_optimal = flux_variability_analysis(model, fraction_of_optimum=1.0) + +# Step 3: Identify reactions with flexibility +fva_optimal["range"] = fva_optimal["maximum"] - fva_optimal["minimum"] +fva_optimal["relative_range"] = fva_optimal["range"] / (fva_optimal["maximum"].abs() + 1e-9) + +flexible_reactions = fva_optimal[fva_optimal["range"] > 1.0].sort_values("range", ascending=False) +print(f"\nFound {len(flexible_reactions)} reactions with >1.0 mmol/gDW/h flexibility") +print("\nTop 10 most flexible reactions:") +print(flexible_reactions.head(10)[["minimum", "maximum", "range"]]) + +# Step 4: Perform FVA at suboptimal growth (90%) +print("\nPerforming FVA at 90% optimal growth...") +fva_suboptimal = flux_variability_analysis(model, fraction_of_optimum=0.9) +fva_suboptimal["range"] = fva_suboptimal["maximum"] - fva_suboptimal["minimum"] + +# Step 5: Compare flexibility at different optimality levels +comparison = pd.DataFrame({ + "range_100": fva_optimal["range"], + "range_90": fva_suboptimal["range"] +}) +comparison["range_increase"] = comparison["range_90"] - comparison["range_100"] + +print("\nReactions with largest increase in flexibility at suboptimality:") +print(comparison.sort_values("range_increase", ascending=False).head(10)) + +# Step 6: Perform flux sampling +print("\nPerforming flux sampling (1000 samples)...") +samples = sample(model, n=1000, method="optgp", processes=4) + +# Step 7: Analyze sampling results for key reactions +key_reactions = ["PFK", "FBA", "TPI", "GAPD", "PGK", "PGM", "ENO", "PYK"] +available_key_reactions = [r for r in key_reactions if r in samples.columns] + +if available_key_reactions: + fig, axes = plt.subplots(2, 4, figsize=(16, 8)) + axes = axes.flatten() + + for idx, reaction_id in enumerate(available_key_reactions[:8]): + ax = axes[idx] + samples[reaction_id].hist(bins=30, ax=ax, alpha=0.7) + + # Overlay FVA bounds + fva_min = fva_optimal.loc[reaction_id, "minimum"] + fva_max = fva_optimal.loc[reaction_id, "maximum"] + ax.axvline(fva_min, color='r', linestyle='--', label='FVA min') + ax.axvline(fva_max, color='r', linestyle='--', label='FVA max') + + ax.set_xlabel("Flux (mmol/gDW/h)") + ax.set_ylabel("Frequency") + ax.set_title(reaction_id) + if idx == 0: + ax.legend() + + plt.tight_layout() + plt.savefig("flux_distributions.png", dpi=300) + +# Step 8: Calculate correlation between reactions +print("\nCalculating flux correlations...") +correlation_matrix = samples[available_key_reactions].corr() + +fig, ax = plt.subplots(figsize=(10, 8)) +sns.heatmap(correlation_matrix, annot=True, fmt=".2f", cmap="coolwarm", + center=0, ax=ax, square=True) +ax.set_title("Flux Correlations Between Key Glycolysis Reactions") +plt.tight_layout() +plt.savefig("flux_correlations.png", dpi=300) + +# Step 9: Identify reaction modules (highly correlated groups) +print("\nHighly correlated reaction pairs (|r| > 0.9):") +for i in range(len(correlation_matrix)): + for j in range(i+1, len(correlation_matrix)): + corr = correlation_matrix.iloc[i, j] + if abs(corr) > 0.9: + print(f" {correlation_matrix.index[i]} <-> 
{correlation_matrix.columns[j]}: {corr:.3f}") + +# Step 10: Export all results +fva_optimal.to_csv("fva_optimal.csv") +fva_suboptimal.to_csv("fva_suboptimal.csv") +samples.to_csv("flux_samples.csv") +correlation_matrix.to_csv("flux_correlations.csv") +``` + +## Workflow 4: Production Strain Design + +This workflow demonstrates how to design a production strain for a target metabolite. + +```python +from cobra.io import load_model +from cobra.flux_analysis import ( + production_envelope, + flux_variability_analysis, + single_gene_deletion +) +import pandas as pd +import matplotlib.pyplot as plt + +# Step 1: Define production target +TARGET_METABOLITE = "EX_ac_e" # Acetate production +CARBON_SOURCE = "EX_glc__D_e" # Glucose uptake + +# Step 2: Load model +model = load_model("ecoli") +print(f"Designing strain for {TARGET_METABOLITE} production") + +# Step 3: Calculate baseline production envelope +print("\nCalculating production envelope...") +envelope = production_envelope( + model, + reactions=[CARBON_SOURCE, TARGET_METABOLITE], + carbon_sources=CARBON_SOURCE +) + +# Visualize production envelope +fig, ax = plt.subplots(figsize=(10, 6)) +ax.plot(envelope[CARBON_SOURCE], envelope["mass_yield_maximum"], 'b-', label='Max yield') +ax.plot(envelope[CARBON_SOURCE], envelope["mass_yield_minimum"], 'r-', label='Min yield') +ax.set_xlabel(f"Glucose uptake (mmol/gDW/h)") +ax.set_ylabel(f"Acetate yield") +ax.set_title("Wild-type Production Envelope") +ax.legend() +ax.grid(True, alpha=0.3) +plt.tight_layout() +plt.savefig("production_envelope_wildtype.png", dpi=300) + +# Step 4: Maximize production while maintaining growth +print("\nOptimizing for production...") + +# Set minimum growth constraint +MIN_GROWTH = 0.1 # Maintain at least 10% of max growth + +with model: + # Change objective to product formation + model.objective = TARGET_METABOLITE + model.objective_direction = "max" + + # Add growth constraint + growth_reaction = model.reactions.get_by_id(model.objective.name) if hasattr(model.objective, 'name') else list(model.objective.variables.keys())[0].name + max_growth = model.slim_optimize() + +model.reactions.BIOMASS_Ecoli_core_w_GAM.lower_bound = MIN_GROWTH + +with model: + model.objective = TARGET_METABOLITE + model.objective_direction = "max" + production_solution = model.optimize() + + max_production = production_solution.objective_value + print(f"Maximum production: {max_production:.3f} mmol/gDW/h") + print(f"Growth rate: {production_solution.fluxes['BIOMASS_Ecoli_core_w_GAM']:.3f} /h") + +# Step 5: Identify beneficial gene knockouts +print("\nScreening for beneficial knockouts...") + +# Reset model +model.reactions.BIOMASS_Ecoli_core_w_GAM.lower_bound = MIN_GROWTH +model.objective = TARGET_METABOLITE +model.objective_direction = "max" + +knockout_results = [] +for gene in model.genes: + with model: + gene.knock_out() + try: + solution = model.optimize() + if solution.status == "optimal": + production = solution.objective_value + growth = solution.fluxes["BIOMASS_Ecoli_core_w_GAM"] + + if production > max_production * 1.05: # >5% improvement + knockout_results.append({ + "gene": gene.id, + "production": production, + "growth": growth, + "improvement": (production / max_production - 1) * 100 + }) + except: + continue + +knockout_df = pd.DataFrame(knockout_results) +if len(knockout_df) > 0: + knockout_df = knockout_df.sort_values("improvement", ascending=False) + print(f"\nFound {len(knockout_df)} beneficial knockouts:") + print(knockout_df.head(10)) + 
knockout_df.to_csv("beneficial_knockouts.csv", index=False) +else: + print("No beneficial single knockouts found") + +# Step 6: Test combination of best knockouts +if len(knockout_df) > 0: + print("\nTesting knockout combinations...") + top_genes = knockout_df.head(3)["gene"].tolist() + + with model: + for gene_id in top_genes: + model.genes.get_by_id(gene_id).knock_out() + + solution = model.optimize() + if solution.status == "optimal": + combined_production = solution.objective_value + combined_growth = solution.fluxes["BIOMASS_Ecoli_core_w_GAM"] + combined_improvement = (combined_production / max_production - 1) * 100 + + print(f"\nCombined knockout results:") + print(f" Genes: {', '.join(top_genes)}") + print(f" Production: {combined_production:.3f} mmol/gDW/h") + print(f" Growth: {combined_growth:.3f} /h") + print(f" Improvement: {combined_improvement:.1f}%") + +# Step 7: Analyze flux distribution in production strain +if len(knockout_df) > 0: + best_gene = knockout_df.iloc[0]["gene"] + + with model: + model.genes.get_by_id(best_gene).knock_out() + solution = model.optimize() + + # Get active pathways + active_fluxes = solution.fluxes[solution.fluxes.abs() > 0.1] + active_fluxes.to_csv(f"production_strain_fluxes_{best_gene}_knockout.csv") + + print(f"\nActive reactions in production strain: {len(active_fluxes)}") +``` + +## Workflow 5: Model Validation and Debugging + +This workflow shows systematic approaches to validate and debug metabolic models. + +```python +from cobra.io import load_model, read_sbml_model +from cobra.flux_analysis import flux_variability_analysis +import pandas as pd + +# Step 1: Load model +model = load_model("ecoli") # Or read_sbml_model("your_model.xml") +print(f"Model: {model.id}") +print(f"Reactions: {len(model.reactions)}") +print(f"Metabolites: {len(model.metabolites)}") +print(f"Genes: {len(model.genes)}") + +# Step 2: Check model feasibility +print("\n--- Feasibility Check ---") +try: + objective_value = model.slim_optimize() + print(f"Model is feasible (objective: {objective_value:.3f})") +except: + print("Model is INFEASIBLE") + print("Troubleshooting steps:") + + # Check for blocked reactions + from cobra.flux_analysis import find_blocked_reactions + blocked = find_blocked_reactions(model) + print(f" Blocked reactions: {len(blocked)}") + if len(blocked) > 0: + print(f" First 10 blocked: {list(blocked)[:10]}") + + # Check medium + print(f"\n Current medium: {model.medium}") + + # Try opening all exchanges + for reaction in model.exchanges: + reaction.lower_bound = -1000 + + try: + objective_value = model.slim_optimize() + print(f"\n Model feasible with open exchanges (objective: {objective_value:.3f})") + print(" Issue: Medium constraints too restrictive") + except: + print("\n Model still infeasible with open exchanges") + print(" Issue: Structural problem (missing reactions, mass imbalance, etc.)") + +# Step 3: Check mass and charge balance +print("\n--- Mass and Charge Balance Check ---") +unbalanced_reactions = [] +for reaction in model.reactions: + try: + balance = reaction.check_mass_balance() + if balance: + unbalanced_reactions.append({ + "reaction": reaction.id, + "imbalance": balance + }) + except: + pass + +if unbalanced_reactions: + print(f"Found {len(unbalanced_reactions)} unbalanced reactions:") + for item in unbalanced_reactions[:10]: + print(f" {item['reaction']}: {item['imbalance']}") +else: + print("All reactions are mass balanced") + +# Step 4: Identify dead-end metabolites +print("\n--- Dead-end Metabolite Check ---") 
+dead_end_metabolites = [] +for metabolite in model.metabolites: + producing_reactions = [r for r in metabolite.reactions + if r.metabolites[metabolite] > 0] + consuming_reactions = [r for r in metabolite.reactions + if r.metabolites[metabolite] < 0] + + if len(producing_reactions) == 0 or len(consuming_reactions) == 0: + dead_end_metabolites.append({ + "metabolite": metabolite.id, + "producers": len(producing_reactions), + "consumers": len(consuming_reactions) + }) + +if dead_end_metabolites: + print(f"Found {len(dead_end_metabolites)} dead-end metabolites:") + for item in dead_end_metabolites[:10]: + print(f" {item['metabolite']}: {item['producers']} producers, {item['consumers']} consumers") +else: + print("No dead-end metabolites found") + +# Step 5: Check for duplicate reactions +print("\n--- Duplicate Reaction Check ---") +reaction_equations = {} +duplicates = [] + +for reaction in model.reactions: + equation = reaction.build_reaction_string() + if equation in reaction_equations: + duplicates.append({ + "reaction1": reaction_equations[equation], + "reaction2": reaction.id, + "equation": equation + }) + else: + reaction_equations[equation] = reaction.id + +if duplicates: + print(f"Found {len(duplicates)} duplicate reaction pairs:") + for item in duplicates[:10]: + print(f" {item['reaction1']} == {item['reaction2']}") +else: + print("No duplicate reactions found") + +# Step 6: Identify orphan genes +print("\n--- Orphan Gene Check ---") +orphan_genes = [gene for gene in model.genes if len(gene.reactions) == 0] + +if orphan_genes: + print(f"Found {len(orphan_genes)} orphan genes (not associated with reactions):") + print(f" First 10: {[g.id for g in orphan_genes[:10]]}") +else: + print("No orphan genes found") + +# Step 7: Check for thermodynamically infeasible loops +print("\n--- Thermodynamic Loop Check ---") +fva_loopless = flux_variability_analysis(model, loopless=True) +fva_standard = flux_variability_analysis(model) + +loop_reactions = [] +for reaction_id in fva_standard.index: + standard_range = fva_standard.loc[reaction_id, "maximum"] - fva_standard.loc[reaction_id, "minimum"] + loopless_range = fva_loopless.loc[reaction_id, "maximum"] - fva_loopless.loc[reaction_id, "minimum"] + + if standard_range > loopless_range + 0.1: + loop_reactions.append({ + "reaction": reaction_id, + "standard_range": standard_range, + "loopless_range": loopless_range + }) + +if loop_reactions: + print(f"Found {len(loop_reactions)} reactions potentially involved in loops:") + loop_df = pd.DataFrame(loop_reactions).sort_values("standard_range", ascending=False) + print(loop_df.head(10)) +else: + print("No thermodynamically infeasible loops detected") + +# Step 8: Generate validation report +print("\n--- Generating Validation Report ---") +validation_report = { + "model_id": model.id, + "feasible": objective_value if 'objective_value' in locals() else None, + "n_reactions": len(model.reactions), + "n_metabolites": len(model.metabolites), + "n_genes": len(model.genes), + "n_unbalanced": len(unbalanced_reactions), + "n_dead_ends": len(dead_end_metabolites), + "n_duplicates": len(duplicates), + "n_orphan_genes": len(orphan_genes), + "n_loop_reactions": len(loop_reactions) +} + +validation_df = pd.DataFrame([validation_report]) +validation_df.to_csv("model_validation_report.csv", index=False) +print("Validation report saved to model_validation_report.csv") +``` + +These workflows provide comprehensive templates for common COBRApy tasks. Adapt them as needed for specific research questions and models. 
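+
+The checks in Workflow 5 can also be folded into a small reusable helper so the same summary can be produced for any model. The sketch below reuses a few of those checks; it treats a NaN return from `slim_optimize` (its default behaviour when the problem cannot be solved) as infeasible, and it skips boundary reactions in the mass-balance count since exchanges are unbalanced by construction:
+
+```python
+import math
+
+def _is_unbalanced(reaction):
+    # Mirror Workflow 5: skip reactions whose metabolites lack formulas.
+    try:
+        return bool(reaction.check_mass_balance())
+    except Exception:
+        return False
+
+def quick_model_report(model):
+    """Bundle a few of the Workflow 5 checks into one dictionary (minimal sketch)."""
+    growth = model.slim_optimize()   # returns NaN by default if the problem cannot be solved
+    unbalanced = [r.id for r in model.reactions if not r.boundary and _is_unbalanced(r)]
+    orphans = [g.id for g in model.genes if len(g.reactions) == 0]
+    return {
+        "feasible": not math.isnan(growth),
+        "growth": growth,
+        "n_unbalanced": len(unbalanced),
+        "n_orphan_genes": len(orphans),
+    }
+
+# Usage:
+# from cobra.io import load_model
+# print(quick_model_report(load_model("ecoli")))
+```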
diff --git a/scientific-packages/datamol/SKILL.md b/scientific-packages/datamol/SKILL.md new file mode 100644 index 0000000..095736f --- /dev/null +++ b/scientific-packages/datamol/SKILL.md @@ -0,0 +1,704 @@ +--- +name: datamol +description: Comprehensive toolkit for molecular cheminformatics using datamol, a Pythonic layer built on RDKit. Use this skill when working with molecular structures, SMILES strings, chemical reactions, molecular descriptors, conformer generation, molecular clustering, scaffold analysis, or any cheminformatics tasks. This skill should be applied when users need to process molecules, analyze chemical properties, visualize molecular structures, fragment compounds, or perform molecular similarity calculations. +--- + +# Datamol Cheminformatics Skill + +## Overview + +Datamol is a Python library that provides a lightweight, Pythonic abstraction layer over RDKit for molecular cheminformatics. It simplifies complex molecular operations with sensible defaults, efficient parallelization, and modern I/O capabilities. All molecular objects are native `rdkit.Chem.Mol` instances, ensuring full compatibility with the RDKit ecosystem. + +**Key capabilities**: +- Molecular format conversion (SMILES, SELFIES, InChI) +- Structure standardization and sanitization +- Molecular descriptors and fingerprints +- 3D conformer generation and analysis +- Clustering and diversity selection +- Scaffold and fragment analysis +- Chemical reaction application +- Visualization and alignment +- Batch processing with parallelization +- Cloud storage support via fsspec + +## Installation and Setup + +Guide users to install datamol: + +```bash +# Via conda/mamba (recommended) +conda install -c conda-forge datamol + +# Via pip +pip install datamol +``` + +**Import convention**: +```python +import datamol as dm +``` + +## Core Workflows + +### 1. Basic Molecule Handling + +**Creating molecules from SMILES**: +```python +import datamol as dm + +# Single molecule +mol = dm.to_mol("CCO") # Ethanol + +# From list of SMILES +smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"] +mols = [dm.to_mol(smi) for smi in smiles_list] + +# Error handling +mol = dm.to_mol("invalid_smiles") # Returns None +if mol is None: + print("Failed to parse SMILES") +``` + +**Converting molecules to SMILES**: +```python +# Canonical SMILES +smiles = dm.to_smiles(mol) + +# Isomeric SMILES (includes stereochemistry) +smiles = dm.to_smiles(mol, isomeric=True) + +# Other formats +inchi = dm.to_inchi(mol) +inchikey = dm.to_inchikey(mol) +selfies = dm.to_selfies(mol) +``` + +**Standardization and sanitization** (always recommend for user-provided molecules): +```python +# Sanitize molecule +mol = dm.sanitize_mol(mol) + +# Full standardization (recommended for datasets) +mol = dm.standardize_mol( + mol, + disconnect_metals=True, + normalize=True, + reionize=True +) + +# For SMILES strings directly +clean_smiles = dm.standardize_smiles(smiles) +``` + +### 2. Reading and Writing Molecular Files + +Refer to `references/io_module.md` for comprehensive I/O documentation. 
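+
+A typical first step is to pair the readers below with the standardization shown in section 1, dropping records that fail to parse as soon as they are loaded. A minimal sketch (the filename is a placeholder):
+
+```python
+import datamol as dm
+
+df = dm.read_sdf("compounds.sdf", as_df=True, mol_column="mol")   # placeholder path
+df["mol"] = df["mol"].apply(lambda m: dm.standardize_mol(m) if m is not None else None)
+df = df[df["mol"].notna()].reset_index(drop=True)
+print(f"Kept {len(df)} valid molecules")
+```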
+ +**Reading files**: +```python +# SDF files (most common in chemistry) +df = dm.read_sdf("compounds.sdf", mol_column='mol') + +# SMILES files +df = dm.read_smi("molecules.smi", smiles_column='smiles', mol_column='mol') + +# CSV with SMILES column +df = dm.read_csv("data.csv", smiles_column="SMILES", mol_column="mol") + +# Excel files +df = dm.read_excel("compounds.xlsx", sheet_name=0, mol_column="mol") + +# Universal reader (auto-detects format) +df = dm.open_df("file.sdf") # Works with .sdf, .csv, .xlsx, .parquet, .json +``` + +**Writing files**: +```python +# Save as SDF +dm.to_sdf(mols, "output.sdf") +# Or from DataFrame +dm.to_sdf(df, "output.sdf", mol_column="mol") + +# Save as SMILES file +dm.to_smi(mols, "output.smi") + +# Excel with rendered molecule images +dm.to_xlsx(df, "output.xlsx", mol_columns=["mol"]) +``` + +**Remote file support** (S3, GCS, HTTP): +```python +# Read from cloud storage +df = dm.read_sdf("s3://bucket/compounds.sdf") +df = dm.read_csv("https://example.com/data.csv") + +# Write to cloud storage +dm.to_sdf(mols, "s3://bucket/output.sdf") +``` + +### 3. Molecular Descriptors and Properties + +Refer to `references/descriptors_viz.md` for detailed descriptor documentation. + +**Computing descriptors for a single molecule**: +```python +# Get standard descriptor set +descriptors = dm.descriptors.compute_many_descriptors(mol) +# Returns: {'mw': 46.07, 'logp': -0.03, 'hbd': 1, 'hba': 1, +# 'tpsa': 20.23, 'n_aromatic_atoms': 0, ...} +``` + +**Batch descriptor computation** (recommended for datasets): +```python +# Compute for all molecules in parallel +desc_df = dm.descriptors.batch_compute_many_descriptors( + mols, + n_jobs=-1, # Use all CPU cores + progress=True # Show progress bar +) +``` + +**Specific descriptors**: +```python +# Aromaticity +n_aromatic = dm.descriptors.n_aromatic_atoms(mol) +aromatic_ratio = dm.descriptors.n_aromatic_atoms_proportion(mol) + +# Stereochemistry +n_stereo = dm.descriptors.n_stereo_centers(mol) +n_unspec = dm.descriptors.n_stereo_centers_unspecified(mol) + +# Flexibility +n_rigid = dm.descriptors.n_rigid_bonds(mol) +``` + +**Drug-likeness filtering (Lipinski's Rule of Five)**: +```python +# Filter compounds +def is_druglike(mol): + desc = dm.descriptors.compute_many_descriptors(mol) + return ( + desc['mw'] <= 500 and + desc['logp'] <= 5 and + desc['hbd'] <= 5 and + desc['hba'] <= 10 + ) + +druglike_mols = [mol for mol in mols if is_druglike(mol)] +``` + +### 4. Molecular Fingerprints and Similarity + +**Generating fingerprints**: +```python +# ECFP (Extended Connectivity Fingerprint, default) +fp = dm.to_fp(mol, fp_type='ecfp', radius=2, n_bits=2048) + +# Other fingerprint types +fp_maccs = dm.to_fp(mol, fp_type='maccs') +fp_topological = dm.to_fp(mol, fp_type='topological') +fp_atompair = dm.to_fp(mol, fp_type='atompair') +``` + +**Similarity calculations**: +```python +# Pairwise distances within a set +distance_matrix = dm.pdist(mols, n_jobs=-1) + +# Distances between two sets +distances = dm.cdist(query_mols, library_mols, n_jobs=-1) + +# Find most similar molecules +from scipy.spatial.distance import squareform +dist_matrix = squareform(dm.pdist(mols)) +# Lower distance = higher similarity (Tanimoto distance = 1 - Tanimoto similarity) +``` + +### 5. Clustering and Diversity Selection + +Refer to `references/core_api.md` for clustering details. 
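+
+Before clustering, it can help to inspect the pairwise Tanimoto distance distribution (via `dm.pdist` from the previous section) so the Butina `cutoff` used below is chosen against the data rather than by default. A small sketch that tolerates either a condensed or a square return shape:
+
+```python
+import numpy as np
+import datamol as dm
+
+# Toy library; substitute your own molecule list.
+mols = [dm.to_mol(s) for s in ["CCO", "CCN", "CCOC", "c1ccccc1", "c1ccccc1O", "CC(=O)O"]]
+
+dists = np.asarray(dm.pdist(mols, n_jobs=1))
+if dists.ndim == 2:                                   # square matrix -> upper triangle only
+    dists = dists[np.triu_indices_from(dists, k=1)]
+
+print(f"Median distance: {np.median(dists):.2f}")
+print(f"10th percentile: {np.percentile(dists, 10):.2f}")
+# Cutoffs near the lower tail group only close analogues; larger cutoffs merge
+# progressively more distant structures into the same cluster.
+```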
+ +**Butina clustering**: +```python +# Cluster molecules by structural similarity +clusters = dm.cluster_mols( + mols, + cutoff=0.2, # Tanimoto distance threshold (0=identical, 1=completely different) + n_jobs=-1 # Parallel processing +) + +# Each cluster is a list of molecule indices +for i, cluster in enumerate(clusters): + print(f"Cluster {i}: {len(cluster)} molecules") + cluster_mols = [mols[idx] for idx in cluster] +``` + +**Important**: Butina clustering builds a full distance matrix - suitable for ~1000 molecules, not for 10,000+. + +**Diversity selection**: +```python +# Pick diverse subset +diverse_mols = dm.pick_diverse( + mols, + npick=100 # Select 100 diverse molecules +) + +# Pick cluster centroids +centroids = dm.pick_centroids( + mols, + npick=50 # Select 50 representative molecules +) +``` + +### 6. Scaffold Analysis + +Refer to `references/fragments_scaffolds.md` for complete scaffold documentation. + +**Extracting Murcko scaffolds**: +```python +# Get Bemis-Murcko scaffold (core structure) +scaffold = dm.to_scaffold_murcko(mol) +scaffold_smiles = dm.to_smiles(scaffold) +``` + +**Scaffold-based analysis**: +```python +# Group compounds by scaffold +from collections import Counter + +scaffolds = [dm.to_scaffold_murcko(mol) for mol in mols] +scaffold_smiles = [dm.to_smiles(s) for s in scaffolds] + +# Count scaffold frequency +scaffold_counts = Counter(scaffold_smiles) +most_common = scaffold_counts.most_common(10) + +# Create scaffold-to-molecules mapping +scaffold_groups = {} +for mol, scaf_smi in zip(mols, scaffold_smiles): + if scaf_smi not in scaffold_groups: + scaffold_groups[scaf_smi] = [] + scaffold_groups[scaf_smi].append(mol) +``` + +**Scaffold-based train/test splitting** (for ML): +```python +# Ensure train and test sets have different scaffolds +scaffold_to_mols = {} +for mol, scaf in zip(mols, scaffold_smiles): + if scaf not in scaffold_to_mols: + scaffold_to_mols[scaf] = [] + scaffold_to_mols[scaf].append(mol) + +# Split scaffolds into train/test +import random +scaffolds = list(scaffold_to_mols.keys()) +random.shuffle(scaffolds) +split_idx = int(0.8 * len(scaffolds)) +train_scaffolds = scaffolds[:split_idx] +test_scaffolds = scaffolds[split_idx:] + +# Get molecules for each split +train_mols = [mol for scaf in train_scaffolds for mol in scaffold_to_mols[scaf]] +test_mols = [mol for scaf in test_scaffolds for mol in scaffold_to_mols[scaf]] +``` + +### 7. Molecular Fragmentation + +Refer to `references/fragments_scaffolds.md` for fragmentation details. + +**BRICS fragmentation** (16 bond types): +```python +# Fragment molecule +fragments = dm.fragment.brics(mol) +# Returns: set of fragment SMILES with attachment points like '[1*]CCN' +``` + +**RECAP fragmentation** (11 bond types): +```python +fragments = dm.fragment.recap(mol) +``` + +**Fragment analysis**: +```python +# Find common fragments across compound library +from collections import Counter + +all_fragments = [] +for mol in mols: + frags = dm.fragment.brics(mol) + all_fragments.extend(frags) + +fragment_counts = Counter(all_fragments) +common_frags = fragment_counts.most_common(20) + +# Fragment-based scoring +def fragment_score(mol, reference_fragments): + mol_frags = dm.fragment.brics(mol) + overlap = mol_frags.intersection(reference_fragments) + return len(overlap) / len(mol_frags) if mol_frags else 0 +``` + +### 8. 3D Conformer Generation + +Refer to `references/conformers_module.md` for detailed conformer documentation. 
+ +**Generating conformers**: +```python +# Generate 3D conformers +mol_3d = dm.conformers.generate( + mol, + n_confs=50, # Number to generate (auto if None) + rms_cutoff=0.5, # Filter similar conformers (Ångströms) + minimize_energy=True, # Minimize with UFF force field + method='ETKDGv3' # Embedding method (recommended) +) + +# Access conformers +n_conformers = mol_3d.GetNumConformers() +conf = mol_3d.GetConformer(0) # Get first conformer +positions = conf.GetPositions() # Nx3 array of atom coordinates +``` + +**Conformer clustering**: +```python +# Cluster conformers by RMSD +clusters = dm.conformers.cluster( + mol_3d, + rms_cutoff=1.0, + centroids=False +) + +# Get representative conformers +centroids = dm.conformers.return_centroids(mol_3d, clusters) +``` + +**SASA calculation**: +```python +# Calculate solvent accessible surface area +sasa_values = dm.conformers.sasa(mol_3d, n_jobs=-1) + +# Access SASA from conformer properties +conf = mol_3d.GetConformer(0) +sasa = conf.GetDoubleProp('rdkit_free_sasa') +``` + +### 9. Visualization + +Refer to `references/descriptors_viz.md` for visualization documentation. + +**Basic molecule grid**: +```python +# Visualize molecules +dm.viz.to_image( + mols[:20], + legends=[dm.to_smiles(m) for m in mols[:20]], + n_cols=5, + mol_size=(300, 300) +) + +# Save to file +dm.viz.to_image(mols, outfile="molecules.png") + +# SVG for publications +dm.viz.to_image(mols, outfile="molecules.svg", use_svg=True) +``` + +**Aligned visualization** (for SAR analysis): +```python +# Align molecules by common substructure +dm.viz.to_image( + similar_mols, + align=True, # Enable MCS alignment + legends=activity_labels, + n_cols=4 +) +``` + +**Highlighting substructures**: +```python +# Highlight specific atoms and bonds +dm.viz.to_image( + mol, + highlight_atom=[0, 1, 2, 3], # Atom indices + highlight_bond=[0, 1, 2] # Bond indices +) +``` + +**Conformer visualization**: +```python +# Display multiple conformers +dm.viz.conformers( + mol_3d, + n_confs=10, + align_conf=True, + n_cols=3 +) +``` + +### 10. Chemical Reactions + +Refer to `references/reactions_data.md` for reactions documentation. + +**Applying reactions**: +```python +from rdkit.Chem import rdChemReactions + +# Define reaction from SMARTS +rxn_smarts = '[C:1](=[O:2])[OH:3]>>[C:1](=[O:2])[Cl:3]' +rxn = rdChemReactions.ReactionFromSmarts(rxn_smarts) + +# Apply to molecule +reactant = dm.to_mol("CC(=O)O") # Acetic acid +product = dm.reactions.apply_reaction( + rxn, + (reactant,), + sanitize=True +) + +# Convert to SMILES +product_smiles = dm.to_smiles(product) +``` + +**Batch reaction application**: +```python +# Apply reaction to library +products = [] +for mol in reactant_mols: + try: + prod = dm.reactions.apply_reaction(rxn, (mol,)) + if prod is not None: + products.append(prod) + except Exception as e: + print(f"Reaction failed: {e}") +``` + +## Parallelization + +Datamol includes built-in parallelization for many operations. Use `n_jobs` parameter: +- `n_jobs=1`: Sequential (no parallelization) +- `n_jobs=-1`: Use all available CPU cores +- `n_jobs=4`: Use 4 cores + +**Functions supporting parallelization**: +- `dm.read_sdf(..., n_jobs=-1)` +- `dm.descriptors.batch_compute_many_descriptors(..., n_jobs=-1)` +- `dm.cluster_mols(..., n_jobs=-1)` +- `dm.pdist(..., n_jobs=-1)` +- `dm.conformers.sasa(..., n_jobs=-1)` + +**Progress bars**: Many batch operations support `progress=True` parameter. 
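+
+Whether parallelization pays off depends on dataset size and per-molecule cost; timing one of the batch calls at different `n_jobs` settings is the quickest way to check. A small sketch with a toy repeated molecule:
+
+```python
+import time
+import datamol as dm
+
+mols = [dm.to_mol("CCO")] * 500   # toy set; real libraries show larger differences
+
+for n_jobs in (1, -1):
+    start = time.perf_counter()
+    dm.descriptors.batch_compute_many_descriptors(mols, n_jobs=n_jobs, progress=False)
+    print(f"n_jobs={n_jobs}: {time.perf_counter() - start:.2f} s")
+```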
+ +## Common Workflows and Patterns + +### Complete Pipeline: Data Loading → Filtering → Analysis + +```python +import datamol as dm +import pandas as pd + +# 1. Load molecules +df = dm.read_sdf("compounds.sdf") + +# 2. Standardize +df['mol'] = df['mol'].apply(lambda m: dm.standardize_mol(m) if m else None) +df = df[df['mol'].notna()] # Remove failed molecules + +# 3. Compute descriptors +desc_df = dm.descriptors.batch_compute_many_descriptors( + df['mol'].tolist(), + n_jobs=-1, + progress=True +) + +# 4. Filter by drug-likeness +druglike = ( + (desc_df['mw'] <= 500) & + (desc_df['logp'] <= 5) & + (desc_df['hbd'] <= 5) & + (desc_df['hba'] <= 10) +) +filtered_df = df[druglike] + +# 5. Cluster and select diverse subset +diverse_mols = dm.pick_diverse( + filtered_df['mol'].tolist(), + npick=100 +) + +# 6. Visualize results +dm.viz.to_image( + diverse_mols, + legends=[dm.to_smiles(m) for m in diverse_mols], + outfile="diverse_compounds.png", + n_cols=10 +) +``` + +### Structure-Activity Relationship (SAR) Analysis + +```python +# Group by scaffold +scaffolds = [dm.to_scaffold_murcko(mol) for mol in mols] +scaffold_smiles = [dm.to_smiles(s) for s in scaffolds] + +# Create DataFrame with activities +sar_df = pd.DataFrame({ + 'mol': mols, + 'scaffold': scaffold_smiles, + 'activity': activities # User-provided activity data +}) + +# Analyze each scaffold series +for scaffold, group in sar_df.groupby('scaffold'): + if len(group) >= 3: # Need multiple examples + print(f"\nScaffold: {scaffold}") + print(f"Count: {len(group)}") + print(f"Activity range: {group['activity'].min():.2f} - {group['activity'].max():.2f}") + + # Visualize with activities as legends + dm.viz.to_image( + group['mol'].tolist(), + legends=[f"Activity: {act:.2f}" for act in group['activity']], + align=True # Align by common substructure + ) +``` + +### Virtual Screening Pipeline + +```python +# 1. Generate fingerprints for query and library +query_fps = [dm.to_fp(mol) for mol in query_actives] +library_fps = [dm.to_fp(mol) for mol in library_mols] + +# 2. Calculate similarities +from scipy.spatial.distance import cdist +import numpy as np + +distances = dm.cdist(query_actives, library_mols, n_jobs=-1) + +# 3. Find closest matches (min distance to any query) +min_distances = distances.min(axis=0) +similarities = 1 - min_distances # Convert distance to similarity + +# 4. Rank and select top hits +top_indices = np.argsort(similarities)[::-1][:100] # Top 100 +top_hits = [library_mols[i] for i in top_indices] +top_scores = [similarities[i] for i in top_indices] + +# 5. Visualize hits +dm.viz.to_image( + top_hits[:20], + legends=[f"Sim: {score:.3f}" for score in top_scores[:20]], + outfile="screening_hits.png" +) +``` + +## Reference Documentation + +For detailed API documentation, consult these reference files: + +- **`references/core_api.md`**: Core namespace functions (conversions, standardization, fingerprints, clustering) +- **`references/io_module.md`**: File I/O operations (read/write SDF, CSV, Excel, remote files) +- **`references/conformers_module.md`**: 3D conformer generation, clustering, SASA calculations +- **`references/descriptors_viz.md`**: Molecular descriptors and visualization functions +- **`references/fragments_scaffolds.md`**: Scaffold extraction, BRICS/RECAP fragmentation +- **`references/reactions_data.md`**: Chemical reactions and toy datasets + +## Best Practices + +1. 
**Always standardize molecules** from external sources: + ```python + mol = dm.standardize_mol(mol, disconnect_metals=True, normalize=True, reionize=True) + ``` + +2. **Check for None values** after molecule parsing: + ```python + mol = dm.to_mol(smiles) + if mol is None: + # Handle invalid SMILES + ``` + +3. **Use parallel processing** for large datasets: + ```python + result = dm.operation(..., n_jobs=-1, progress=True) + ``` + +4. **Leverage fsspec** for cloud storage: + ```python + df = dm.read_sdf("s3://bucket/compounds.sdf") + ``` + +5. **Use appropriate fingerprints** for similarity: + - ECFP (Morgan): General purpose, structural similarity + - MACCS: Fast, smaller feature space + - Atom pairs: Considers atom pairs and distances + +6. **Consider scale limitations**: + - Butina clustering: ~1,000 molecules (full distance matrix) + - For larger datasets: Use diversity selection or hierarchical methods + +7. **Scaffold splitting for ML**: Ensure proper train/test separation by scaffold + +8. **Align molecules** when visualizing SAR series + +## Error Handling + +```python +# Safe molecule creation +def safe_to_mol(smiles): + try: + mol = dm.to_mol(smiles) + if mol is not None: + mol = dm.standardize_mol(mol) + return mol + except Exception as e: + print(f"Failed to process {smiles}: {e}") + return None + +# Safe batch processing +valid_mols = [] +for smiles in smiles_list: + mol = safe_to_mol(smiles) + if mol is not None: + valid_mols.append(mol) +``` + +## Integration with Machine Learning + +```python +# Feature generation +X = np.array([dm.to_fp(mol) for mol in mols]) + +# Or descriptors +desc_df = dm.descriptors.batch_compute_many_descriptors(mols, n_jobs=-1) +X = desc_df.values + +# Train model +from sklearn.ensemble import RandomForestRegressor +model = RandomForestRegressor() +model.fit(X, y_target) + +# Predict +predictions = model.predict(X_test) +``` + +## Troubleshooting + +**Issue**: Molecule parsing fails +- **Solution**: Use `dm.standardize_smiles()` first or try `dm.fix_mol()` + +**Issue**: Memory errors with clustering +- **Solution**: Use `dm.pick_diverse()` instead of full clustering for large sets + +**Issue**: Slow conformer generation +- **Solution**: Reduce `n_confs` or increase `rms_cutoff` to generate fewer conformers + +**Issue**: Remote file access fails +- **Solution**: Ensure fsspec and appropriate cloud provider libraries are installed (s3fs, gcsfs, etc.) + +## Additional Resources + +- **Datamol Documentation**: https://docs.datamol.io/ +- **RDKit Documentation**: https://www.rdkit.org/docs/ +- **GitHub Repository**: https://github.com/datamol-io/datamol diff --git a/scientific-packages/datamol/references/conformers_module.md b/scientific-packages/datamol/references/conformers_module.md new file mode 100644 index 0000000..06fa2e0 --- /dev/null +++ b/scientific-packages/datamol/references/conformers_module.md @@ -0,0 +1,131 @@ +# Datamol Conformers Module Reference + +The `datamol.conformers` module provides tools for generating and analyzing 3D molecular conformations. + +## Conformer Generation + +### `dm.conformers.generate(mol, n_confs=None, rms_cutoff=None, minimize_energy=True, method='ETKDGv3', add_hs=True, ...)` +Generate 3D molecular conformers. 
+- **Parameters**: + - `mol`: Input molecule + - `n_confs`: Number of conformers to generate (auto-determined based on rotatable bonds if None) + - `rms_cutoff`: RMS threshold in Ångströms for filtering similar conformers (removes duplicates) + - `minimize_energy`: Apply UFF energy minimization (default: True) + - `method`: Embedding method - options: + - `'ETDG'` - Experimental Torsion Distance Geometry + - `'ETKDG'` - ETDG with additional basic knowledge + - `'ETKDGv2'` - Enhanced version 2 + - `'ETKDGv3'` - Enhanced version 3 (default, recommended) + - `add_hs`: Add hydrogens before embedding (default: True, critical for quality) + - `random_seed`: Set for reproducibility +- **Returns**: Molecule with embedded conformers +- **Example**: + ```python + mol = dm.to_mol("CCO") + mol_3d = dm.conformers.generate(mol, n_confs=10, rms_cutoff=0.5) + conformers = mol_3d.GetConformers() # Access all conformers + ``` + +## Conformer Clustering + +### `dm.conformers.cluster(mol, rms_cutoff=1.0, already_aligned=False, centroids=False)` +Group conformers by RMS distance. +- **Parameters**: + - `rms_cutoff`: Clustering threshold in Ångströms (default: 1.0) + - `already_aligned`: Whether conformers are pre-aligned + - `centroids`: Return centroid conformers (True) or cluster groups (False) +- **Returns**: Cluster information or centroid conformers +- **Use case**: Identify distinct conformational families + +### `dm.conformers.return_centroids(mol, conf_clusters, centroids=True)` +Extract representative conformers from clusters. +- **Parameters**: + - `conf_clusters`: Sequence of cluster indices from `cluster()` + - `centroids`: Return single molecule (True) or list of molecules (False) +- **Returns**: Centroid conformer(s) + +## Conformer Analysis + +### `dm.conformers.rmsd(mol)` +Calculate pairwise RMSD matrix across all conformers. +- **Requirements**: Minimum 2 conformers +- **Returns**: NxN matrix of RMSD values +- **Use case**: Quantify conformer diversity + +### `dm.conformers.sasa(mol, n_jobs=1, ...)` +Calculate Solvent Accessible Surface Area (SASA) using FreeSASA. +- **Parameters**: + - `n_jobs`: Parallelization for multiple conformers +- **Returns**: Array of SASA values (one per conformer) +- **Storage**: Values stored in each conformer as property `'rdkit_free_sasa'` +- **Example**: + ```python + sasa_values = dm.conformers.sasa(mol_3d) + # Or access from conformer properties + conf = mol_3d.GetConformer(0) + sasa = conf.GetDoubleProp('rdkit_free_sasa') + ``` + +## Low-Level Conformer Manipulation + +### `dm.conformers.center_of_mass(mol, conf_id=-1, use_atoms=True, round_coord=None)` +Calculate molecular center. +- **Parameters**: + - `conf_id`: Conformer index (-1 for first conformer) + - `use_atoms`: Use atomic masses (True) or geometric center (False) + - `round_coord`: Decimal precision for rounding +- **Returns**: 3D coordinates of center +- **Use case**: Centering molecules for visualization or alignment + +### `dm.conformers.get_coords(mol, conf_id=-1)` +Retrieve atomic coordinates from a conformer. +- **Returns**: Nx3 numpy array of atomic positions +- **Example**: + ```python + positions = dm.conformers.get_coords(mol_3d, conf_id=0) + # positions.shape: (num_atoms, 3) + ``` + +### `dm.conformers.translate(mol, conf_id=-1, transform_matrix=None)` +Reposition conformer using transformation matrix. +- **Modification**: Operates in-place +- **Use case**: Aligning or repositioning molecules + +## Workflow Example + +```python +import datamol as dm + +# 1. 
Create molecule and generate conformers +mol = dm.to_mol("CC(C)CCO") # Isopentanol +mol_3d = dm.conformers.generate( + mol, + n_confs=50, # Generate 50 initial conformers + rms_cutoff=0.5, # Filter similar conformers + minimize_energy=True # Minimize energy +) + +# 2. Analyze conformers +n_conformers = mol_3d.GetNumConformers() +print(f"Generated {n_conformers} unique conformers") + +# 3. Calculate SASA +sasa_values = dm.conformers.sasa(mol_3d) + +# 4. Cluster conformers +clusters = dm.conformers.cluster(mol_3d, rms_cutoff=1.0, centroids=False) + +# 5. Get representative conformers +centroids = dm.conformers.return_centroids(mol_3d, clusters) + +# 6. Access 3D coordinates +coords = dm.conformers.get_coords(mol_3d, conf_id=0) +``` + +## Key Concepts + +- **Distance Geometry**: Method for generating 3D structures from connectivity information +- **ETKDG**: Uses experimental torsion angle preferences and additional chemical knowledge +- **RMS Cutoff**: Lower values = more unique conformers; higher values = fewer, more distinct conformers +- **Energy Minimization**: Relaxes structures to nearest local energy minimum +- **Hydrogens**: Critical for accurate 3D geometry - always include during embedding diff --git a/scientific-packages/datamol/references/core_api.md b/scientific-packages/datamol/references/core_api.md new file mode 100644 index 0000000..12f4627 --- /dev/null +++ b/scientific-packages/datamol/references/core_api.md @@ -0,0 +1,130 @@ +# Datamol Core API Reference + +This document covers the main functions available in the datamol namespace. + +## Molecule Creation and Conversion + +### `to_mol(mol, ...)` +Convert SMILES string or other molecular representations to RDKit molecule objects. +- **Parameters**: Accepts SMILES strings, InChI, or other molecular formats +- **Returns**: `rdkit.Chem.Mol` object +- **Common usage**: `mol = dm.to_mol("CCO")` + +### `from_inchi(inchi)` +Convert InChI string to molecule object. + +### `from_smarts(smarts)` +Convert SMARTS pattern to molecule object. + +### `from_selfies(selfies)` +Convert SELFIES string to molecule object. + +### `copy_mol(mol)` +Create a copy of a molecule object to avoid modifying the original. + +## Molecule Export + +### `to_smiles(mol, ...)` +Convert molecule object to SMILES string. +- **Common parameters**: `canonical=True`, `isomeric=True` + +### `to_inchi(mol, ...)` +Convert molecule to InChI string representation. + +### `to_inchikey(mol)` +Convert molecule to InChI key (fixed-length hash). + +### `to_smarts(mol)` +Convert molecule to SMARTS pattern. + +### `to_selfies(mol)` +Convert molecule to SELFIES (Self-Referencing Embedded Strings) format. + +## Sanitization and Standardization + +### `sanitize_mol(mol, ...)` +Enhanced version of RDKit's sanitize operation using mol→SMILES→mol conversion and aromatic nitrogen fixing. +- **Purpose**: Fix common molecular structure issues +- **Returns**: Sanitized molecule or None if sanitization fails + +### `standardize_mol(mol, disconnect_metals=False, normalize=True, reionize=True, ...)` +Apply comprehensive standardization procedures including: +- Metal disconnection +- Normalization (charge corrections) +- Reionization +- Fragment handling (largest fragment selection) + +### `standardize_smiles(smiles, ...)` +Apply SMILES standardization procedures directly to a SMILES string. + +### `fix_mol(mol)` +Attempt to fix molecular structure issues automatically. + +### `fix_valence(mol)` +Correct valence errors in molecular structures. 
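+
+In practice these functions are chained into a small cleanup cascade. The following is a minimal sketch; the fallback order here is a convention, not a datamol requirement:
+
+```python
+import datamol as dm
+
+def clean_molecule(smiles):
+    """Parse, sanitize, and standardize a SMILES string; return None if it cannot be repaired."""
+    mol = dm.to_mol(smiles)
+    if mol is None:
+        return None
+    mol = dm.sanitize_mol(mol)
+    if mol is None:
+        return None
+    try:
+        return dm.standardize_mol(mol, disconnect_metals=True, normalize=True, reionize=True)
+    except Exception:
+        return dm.fix_mol(mol)   # last resort before giving up
+
+print(dm.to_smiles(clean_molecule("CC(=O)O")))
+```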
+ +## Molecular Properties + +### `reorder_atoms(mol, ...)` +Ensure consistent atom ordering for the same molecule regardless of original SMILES representation. +- **Purpose**: Maintain reproducible feature generation + +### `remove_hs(mol, ...)` +Remove hydrogen atoms from molecular structure. + +### `add_hs(mol, ...)` +Add explicit hydrogen atoms to molecular structure. + +## Fingerprints and Similarity + +### `to_fp(mol, fp_type='ecfp', ...)` +Generate molecular fingerprints for similarity calculations. +- **Fingerprint types**: + - `'ecfp'` - Extended Connectivity Fingerprints (Morgan) + - `'fcfp'` - Functional Connectivity Fingerprints + - `'maccs'` - MACCS keys + - `'topological'` - Topological fingerprints + - `'atompair'` - Atom pair fingerprints +- **Common parameters**: `n_bits`, `radius` +- **Returns**: Numpy array or RDKit fingerprint object + +### `pdist(mols, ...)` +Calculate pairwise Tanimoto distances between all molecules in a list. +- **Supports**: Parallel processing via `n_jobs` parameter +- **Returns**: Distance matrix + +### `cdist(mols1, mols2, ...)` +Calculate Tanimoto distances between two sets of molecules. + +## Clustering and Diversity + +### `cluster_mols(mols, cutoff=0.2, feature_fn=None, n_jobs=1)` +Cluster molecules using Butina clustering algorithm. +- **Parameters**: + - `cutoff`: Distance threshold (default 0.2) + - `feature_fn`: Custom function for molecular features + - `n_jobs`: Parallelization (-1 for all cores) +- **Important**: Builds full distance matrix - suitable for ~1000 structures, not for 10,000+ +- **Returns**: List of clusters (each cluster is a list of molecule indices) + +### `pick_diverse(mols, npick, ...)` +Select diverse subset of molecules based on fingerprint diversity. + +### `pick_centroids(mols, npick, ...)` +Select centroid molecules representing clusters. + +## Graph Operations + +### `to_graph(mol)` +Convert molecule to graph representation for graph-based analysis. + +### `get_all_path_between(mol, start, end)` +Find all paths between two atoms in molecular structure. + +## DataFrame Integration + +### `to_df(mols, smiles_column='smiles', mol_column='mol')` +Convert list of molecules to pandas DataFrame. + +### `from_df(df, smiles_column='smiles', mol_column='mol')` +Convert pandas DataFrame to list of molecules. diff --git a/scientific-packages/datamol/references/descriptors_viz.md b/scientific-packages/datamol/references/descriptors_viz.md new file mode 100644 index 0000000..87b60d3 --- /dev/null +++ b/scientific-packages/datamol/references/descriptors_viz.md @@ -0,0 +1,195 @@ +# Datamol Descriptors and Visualization Reference + +## Descriptors Module (`datamol.descriptors`) + +The descriptors module provides tools for computing molecular properties and descriptors. + +### Specialized Descriptor Functions + +#### `dm.descriptors.n_aromatic_atoms(mol)` +Calculate the number of aromatic atoms. +- **Returns**: Integer count +- **Use case**: Aromaticity analysis + +#### `dm.descriptors.n_aromatic_atoms_proportion(mol)` +Calculate ratio of aromatic atoms to total heavy atoms. +- **Returns**: Float between 0 and 1 +- **Use case**: Quantifying aromatic character + +#### `dm.descriptors.n_charged_atoms(mol)` +Count atoms with nonzero formal charge. +- **Returns**: Integer count +- **Use case**: Charge distribution analysis + +#### `dm.descriptors.n_rigid_bonds(mol)` +Count non-rotatable bonds (neither single bonds nor ring bonds). 
+- **Returns**: Integer count +- **Use case**: Molecular flexibility assessment + +#### `dm.descriptors.n_stereo_centers(mol)` +Count stereogenic centers (chiral centers). +- **Returns**: Integer count +- **Use case**: Stereochemistry analysis + +#### `dm.descriptors.n_stereo_centers_unspecified(mol)` +Count stereocenters lacking stereochemical specification. +- **Returns**: Integer count +- **Use case**: Identifying incomplete stereochemistry + +### Batch Descriptor Computation + +#### `dm.descriptors.compute_many_descriptors(mol, properties_fn=None, add_properties=True)` +Compute multiple molecular properties for a single molecule. +- **Parameters**: + - `properties_fn`: Custom list of descriptor functions + - `add_properties`: Include additional computed properties +- **Returns**: Dictionary of descriptor name → value pairs +- **Default descriptors include**: + - Molecular weight, LogP, number of H-bond donors/acceptors + - Aromatic atoms, stereocenters, rotatable bonds + - TPSA (Topological Polar Surface Area) + - Ring count, heteroatom count +- **Example**: + ```python + mol = dm.to_mol("CCO") + descriptors = dm.descriptors.compute_many_descriptors(mol) + # Returns: {'mw': 46.07, 'logp': -0.03, 'hbd': 1, 'hba': 1, ...} + ``` + +#### `dm.descriptors.batch_compute_many_descriptors(mols, properties_fn=None, add_properties=True, n_jobs=1, batch_size=None, progress=False)` +Compute descriptors for multiple molecules in parallel. +- **Parameters**: + - `mols`: List of molecules + - `n_jobs`: Number of parallel jobs (-1 for all cores) + - `batch_size`: Chunk size for parallel processing + - `progress`: Show progress bar +- **Returns**: Pandas DataFrame with one row per molecule +- **Example**: + ```python + mols = [dm.to_mol(smi) for smi in smiles_list] + df = dm.descriptors.batch_compute_many_descriptors( + mols, + n_jobs=-1, + progress=True + ) + ``` + +### RDKit Descriptor Access + +#### `dm.descriptors.any_rdkit_descriptor(name)` +Retrieve any descriptor function from RDKit by name. +- **Parameters**: `name` - Descriptor function name (e.g., 'MolWt', 'TPSA') +- **Returns**: RDKit descriptor function +- **Available descriptors**: From `rdkit.Chem.Descriptors` and `rdkit.Chem.rdMolDescriptors` +- **Example**: + ```python + tpsa_fn = dm.descriptors.any_rdkit_descriptor('TPSA') + tpsa_value = tpsa_fn(mol) + ``` + +### Common Use Cases + +**Drug-likeness Filtering (Lipinski's Rule of Five)**: +```python +descriptors = dm.descriptors.compute_many_descriptors(mol) +is_druglike = ( + descriptors['mw'] <= 500 and + descriptors['logp'] <= 5 and + descriptors['hbd'] <= 5 and + descriptors['hba'] <= 10 +) +``` + +**ADME Property Analysis**: +```python +df = dm.descriptors.batch_compute_many_descriptors(compound_library) +# Filter by TPSA for blood-brain barrier penetration +bbb_candidates = df[df['tpsa'] < 90] +``` + +--- + +## Visualization Module (`datamol.viz`) + +The viz module provides tools for rendering molecules and conformers as images. + +### Main Visualization Function + +#### `dm.viz.to_image(mols, legends=None, n_cols=4, use_svg=False, mol_size=(200, 200), highlight_atom=None, highlight_bond=None, outfile=None, max_mols=None, copy=True, indices=False, ...)` +Generate image grid from molecules. 
+- **Parameters**: + - `mols`: Single molecule or list of molecules + - `legends`: String or list of strings as labels (one per molecule) + - `n_cols`: Number of molecules per row (default: 4) + - `use_svg`: Output SVG format (True) or PNG (False, default) + - `mol_size`: Tuple (width, height) or single int for square images + - `highlight_atom`: Atom indices to highlight (list or dict) + - `highlight_bond`: Bond indices to highlight (list or dict) + - `outfile`: Save path (local or remote, supports fsspec) + - `max_mols`: Maximum number of molecules to display + - `indices`: Draw atom indices on structures (default: False) + - `align`: Align molecules using MCS (Maximum Common Substructure) +- **Returns**: Image object (can be displayed in Jupyter) or saves to file +- **Example**: + ```python + # Basic grid + dm.viz.to_image(mols[:10], legends=[dm.to_smiles(m) for m in mols[:10]]) + + # Save to file + dm.viz.to_image(mols, outfile="molecules.png", n_cols=5) + + # Highlight substructure + dm.viz.to_image(mol, highlight_atom=[0, 1, 2], highlight_bond=[0, 1]) + + # Aligned visualization + dm.viz.to_image(mols, align=True, legends=activity_labels) + ``` + +### Conformer Visualization + +#### `dm.viz.conformers(mol, n_confs=None, align_conf=True, n_cols=3, sync_views=True, remove_hs=True, ...)` +Display multiple conformers in grid layout. +- **Parameters**: + - `mol`: Molecule with embedded conformers + - `n_confs`: Number or list of conformer indices to display (None = all) + - `align_conf`: Align conformers for comparison (default: True) + - `n_cols`: Grid columns (default: 3) + - `sync_views`: Synchronize 3D views when interactive (default: True) + - `remove_hs`: Remove hydrogens for clarity (default: True) +- **Returns**: Grid of conformer visualizations +- **Use case**: Comparing conformational diversity +- **Example**: + ```python + mol_3d = dm.conformers.generate(mol, n_confs=20) + dm.viz.conformers(mol_3d, n_confs=10, align_conf=True) + ``` + +### Circle Grid Visualization + +#### `dm.viz.circle_grid(center_mol, circle_mols, mol_size=200, circle_margin=50, act_mapper=None, ...)` +Create concentric ring visualization with central molecule. +- **Parameters**: + - `center_mol`: Molecule at center + - `circle_mols`: List of molecule lists (one list per ring) + - `mol_size`: Image size per molecule + - `circle_margin`: Spacing between rings (default: 50) + - `act_mapper`: Activity mapping dictionary for color-coding +- **Returns**: Circular grid image +- **Use case**: Visualizing molecular neighborhoods, SAR analysis, similarity networks +- **Example**: + ```python + # Show a reference molecule surrounded by similar compounds + dm.viz.circle_grid( + center_mol=reference, + circle_mols=[nearest_neighbors, second_tier] + ) + ``` + +### Visualization Best Practices + +1. **Use legends for clarity**: Always label molecules with SMILES, IDs, or activity values +2. **Align related molecules**: Use `align=True` in `to_image()` for SAR analysis +3. **Adjust grid size**: Set `n_cols` based on molecule count and display width +4. **Use SVG for publications**: Set `use_svg=True` for scalable vector graphics +5. **Highlight substructures**: Use `highlight_atom` and `highlight_bond` to emphasize features +6. 
**Save large grids**: Use `outfile` parameter to save rather than display in memory diff --git a/scientific-packages/datamol/references/fragments_scaffolds.md b/scientific-packages/datamol/references/fragments_scaffolds.md new file mode 100644 index 0000000..77dc19f --- /dev/null +++ b/scientific-packages/datamol/references/fragments_scaffolds.md @@ -0,0 +1,174 @@ +# Datamol Fragments and Scaffolds Reference + +## Scaffolds Module (`datamol.scaffold`) + +Scaffolds represent the core structure of molecules, useful for identifying structural families and analyzing structure-activity relationships (SAR). + +### Murcko Scaffolds + +#### `dm.to_scaffold_murcko(mol)` +Extract Bemis-Murcko scaffold (molecular framework). +- **Method**: Removes side chains, retaining ring systems and linkers +- **Returns**: Molecule object representing the scaffold +- **Use case**: Identify core structures across compound series +- **Example**: + ```python + mol = dm.to_mol("c1ccc(cc1)CCN") # Phenethylamine + scaffold = dm.to_scaffold_murcko(mol) + scaffold_smiles = dm.to_smiles(scaffold) + # Returns: 'c1ccccc1CC' (benzene ring + ethyl linker) + ``` + +**Workflow for scaffold analysis**: +```python +# Extract scaffolds from compound library +scaffolds = [dm.to_scaffold_murcko(mol) for mol in mols] +scaffold_smiles = [dm.to_smiles(s) for s in scaffolds] + +# Count scaffold frequency +from collections import Counter +scaffold_counts = Counter(scaffold_smiles) +most_common = scaffold_counts.most_common(10) +``` + +### Fuzzy Scaffolds + +#### `dm.scaffold.fuzzy_scaffolding(mol, ...)` +Generate fuzzy scaffolds with enforceable groups that must appear in the core. +- **Purpose**: More flexible scaffold definition allowing specified functional groups +- **Use case**: Custom scaffold definitions beyond Murcko rules + +### Applications + +**Scaffold-based splitting** (for ML model validation): +```python +# Group compounds by scaffold +scaffold_to_mols = {} +for mol, scaffold in zip(mols, scaffolds): + smi = dm.to_smiles(scaffold) + if smi not in scaffold_to_mols: + scaffold_to_mols[smi] = [] + scaffold_to_mols[smi].append(mol) + +# Ensure train/test sets have different scaffolds +``` + +**SAR analysis**: +```python +# Group by scaffold and analyze activity +for scaffold_smi, molecules in scaffold_to_mols.items(): + activities = [get_activity(mol) for mol in molecules] + print(f"Scaffold: {scaffold_smi}, Mean activity: {np.mean(activities)}") +``` + +--- + +## Fragments Module (`datamol.fragment`) + +Molecular fragmentation breaks molecules into smaller pieces based on chemical rules, useful for fragment-based drug design and substructure analysis. + +### BRICS Fragmentation + +#### `dm.fragment.brics(mol, ...)` +Fragment molecule using BRICS (Breaking Retrosynthetically Interesting Chemical Substructures). +- **Method**: Dissects based on 16 chemically meaningful bond types +- **Consideration**: Considers chemical environment and surrounding substructures +- **Returns**: Set of fragment SMILES strings +- **Use case**: Retrosynthetic analysis, fragment-based design +- **Example**: + ```python + mol = dm.to_mol("c1ccccc1CCN") + fragments = dm.fragment.brics(mol) + # Returns fragments like: '[1*]CCN', '[1*]c1ccccc1', etc. + # [1*] represents attachment points + ``` + +### RECAP Fragmentation + +#### `dm.fragment.recap(mol, ...)` +Fragment molecule using RECAP (Retrosynthetic Combinatorial Analysis Procedure). 
+- **Method**: Dissects based on 11 predefined bond types +- **Rules**: + - Leaves alkyl groups smaller than 5 carbons intact + - Preserves cyclic bonds +- **Returns**: Set of fragment SMILES strings +- **Use case**: Combinatorial library design +- **Example**: + ```python + mol = dm.to_mol("CCCCCc1ccccc1") + fragments = dm.fragment.recap(mol) + ``` + +### MMPA Fragmentation + +#### `dm.fragment.mmpa_frag(mol, ...)` +Fragment for Matched Molecular Pair Analysis. +- **Purpose**: Generate fragments suitable for identifying molecular pairs +- **Use case**: Analyzing how small structural changes affect properties +- **Example**: + ```python + fragments = dm.fragment.mmpa_frag(mol) + # Used to find pairs of molecules differing by single transformation + ``` + +### Comparison of Methods + +| Method | Bond Types | Preserves Cycles | Best For | +|--------|-----------|------------------|----------| +| BRICS | 16 | Yes | Retrosynthetic analysis, fragment recombination | +| RECAP | 11 | Yes | Combinatorial library design | +| MMPA | Variable | Depends | Structure-activity relationship analysis | + +### Fragmentation Workflow + +```python +import datamol as dm + +# 1. Fragment a molecule +mol = dm.to_mol("CC(=O)Oc1ccccc1C(=O)O") # Aspirin +brics_frags = dm.fragment.brics(mol) +recap_frags = dm.fragment.recap(mol) + +# 2. Analyze fragment frequency across library +all_fragments = [] +for mol in molecule_library: + frags = dm.fragment.brics(mol) + all_fragments.extend(frags) + +# 3. Identify common fragments +from collections import Counter +fragment_counts = Counter(all_fragments) +common_fragments = fragment_counts.most_common(20) + +# 4. Convert fragments back to molecules (remove attachment points) +def clean_fragment(frag_smiles): + # Remove [1*], [2*], etc. attachment point markers + clean = frag_smiles.replace('[1*]', '[H]') + return dm.to_mol(clean) +``` + +### Advanced: Fragment-Based Virtual Screening + +```python +# Build fragment library from known actives +active_fragments = set() +for active_mol in active_compounds: + frags = dm.fragment.brics(active_mol) + active_fragments.update(frags) + +# Screen compounds for presence of active fragments +def score_by_fragments(mol, fragment_set): + mol_frags = dm.fragment.brics(mol) + overlap = mol_frags.intersection(fragment_set) + return len(overlap) / len(mol_frags) + +# Score screening library +scores = [score_by_fragments(mol, active_fragments) for mol in screening_lib] +``` + +### Key Concepts + +- **Attachment Points**: Marked with [1*], [2*], etc. in fragment SMILES +- **Retrosynthetic**: Fragmentation mimics synthetic disconnections +- **Chemically Meaningful**: Breaks occur at typical synthetic bonds +- **Recombination**: Fragments can theoretically be recombined into valid molecules diff --git a/scientific-packages/datamol/references/io_module.md b/scientific-packages/datamol/references/io_module.md new file mode 100644 index 0000000..71e5027 --- /dev/null +++ b/scientific-packages/datamol/references/io_module.md @@ -0,0 +1,109 @@ +# Datamol I/O Module Reference + +The `datamol.io` module provides comprehensive file handling for molecular data across multiple formats. + +## Reading Molecular Files + +### `dm.read_sdf(filename, sanitize=True, remove_hs=True, as_df=True, mol_column='mol', ...)` +Read Structure-Data File (SDF) format. 
+- **Parameters**: + - `filename`: Path to SDF file (supports local and remote paths via fsspec) + - `sanitize`: Apply sanitization to molecules + - `remove_hs`: Remove explicit hydrogens + - `as_df`: Return as DataFrame (True) or list of molecules (False) + - `mol_column`: Name of molecule column in DataFrame + - `n_jobs`: Enable parallel processing +- **Returns**: DataFrame or list of molecules +- **Example**: `df = dm.read_sdf("compounds.sdf")` + +### `dm.read_smi(filename, smiles_column='smiles', mol_column='mol', as_df=True, ...)` +Read SMILES file (space-delimited by default). +- **Common format**: SMILES followed by molecule ID/name +- **Example**: `df = dm.read_smi("molecules.smi")` + +### `dm.read_csv(filename, smiles_column='smiles', mol_column=None, ...)` +Read CSV file with optional automatic SMILES-to-molecule conversion. +- **Parameters**: + - `smiles_column`: Column containing SMILES strings + - `mol_column`: If specified, creates molecule objects from SMILES column +- **Example**: `df = dm.read_csv("data.csv", smiles_column="SMILES", mol_column="mol")` + +### `dm.read_excel(filename, sheet_name=0, smiles_column='smiles', mol_column=None, ...)` +Read Excel files with molecule handling. +- **Parameters**: + - `sheet_name`: Sheet to read (index or name) + - Other parameters similar to `read_csv` +- **Example**: `df = dm.read_excel("compounds.xlsx", sheet_name="Sheet1")` + +### `dm.read_molblock(molblock, sanitize=True, remove_hs=True)` +Parse MOL block string (molecular structure text representation). + +### `dm.read_mol2file(filename, sanitize=True, remove_hs=True, cleanupSubstructures=True)` +Read Mol2 format files. + +### `dm.read_pdbfile(filename, sanitize=True, remove_hs=True, proximityBonding=True)` +Read Protein Data Bank (PDB) format files. + +### `dm.read_pdbblock(pdbblock, sanitize=True, remove_hs=True, proximityBonding=True)` +Parse PDB block string. + +### `dm.open_df(filename, ...)` +Universal DataFrame reader - automatically detects format. +- **Supported formats**: CSV, Excel, Parquet, JSON, SDF +- **Example**: `df = dm.open_df("data.csv")` or `df = dm.open_df("molecules.sdf")` + +## Writing Molecular Files + +### `dm.to_sdf(mols, filename, mol_column=None, ...)` +Write molecules to SDF file. +- **Input types**: + - List of molecules + - DataFrame with molecule column + - Sequence of molecules +- **Parameters**: + - `mol_column`: Column name if input is DataFrame +- **Example**: + ```python + dm.to_sdf(mols, "output.sdf") + # or from DataFrame + dm.to_sdf(df, "output.sdf", mol_column="mol") + ``` + +### `dm.to_smi(mols, filename, mol_column=None, ...)` +Write molecules to SMILES file with optional validation. +- **Format**: SMILES strings with optional molecule names/IDs + +### `dm.to_xlsx(df, filename, mol_columns=None, ...)` +Export DataFrame to Excel with rendered molecular images. +- **Parameters**: + - `mol_columns`: Columns containing molecules to render as images +- **Special feature**: Automatically renders molecules as images in Excel cells +- **Example**: `dm.to_xlsx(df, "molecules.xlsx", mol_columns=["mol"])` + +### `dm.to_molblock(mol, ...)` +Convert molecule to MOL block string. + +### `dm.to_pdbblock(mol, ...)` +Convert molecule to PDB block string. + +### `dm.save_df(df, filename, ...)` +Save DataFrame in multiple formats (CSV, Excel, Parquet, JSON). 
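+
+A minimal end-to-end sketch of the writers and readers above (the file names and SMILES here are illustrative, not part of the library):
+
+```python
+import datamol as dm
+import pandas as pd
+
+# Build a small DataFrame with a molecule column
+df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1", "CC(=O)O"]})
+df["mol"] = df["smiles"].apply(dm.to_mol)
+
+dm.to_sdf(df, "example.sdf", mol_column="mol")       # structures + properties to SDF
+dm.save_df(df.drop(columns=["mol"]), "example.csv")  # tabular export without mol objects
+
+round_trip = dm.read_sdf("example.sdf")              # back as a DataFrame
+table = dm.open_df("example.csv")                    # format inferred from the extension
+```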
+ +## Remote File Support + +All I/O functions support remote file paths through fsspec integration: +- **Supported protocols**: S3 (AWS), GCS (Google Cloud), Azure, HTTP/HTTPS +- **Example**: + ```python + dm.read_sdf("s3://bucket/compounds.sdf") + dm.read_csv("https://example.com/data.csv") + ``` + +## Key Parameters Across Functions + +- **`sanitize`**: Apply molecule sanitization (default: True) +- **`remove_hs`**: Remove explicit hydrogens (default: True) +- **`as_df`**: Return DataFrame vs list (default: True for most functions) +- **`n_jobs`**: Enable parallel processing (None = all cores, 1 = sequential) +- **`mol_column`**: Name of molecule column in DataFrames +- **`smiles_column`**: Name of SMILES column in DataFrames diff --git a/scientific-packages/datamol/references/reactions_data.md b/scientific-packages/datamol/references/reactions_data.md new file mode 100644 index 0000000..e351d2f --- /dev/null +++ b/scientific-packages/datamol/references/reactions_data.md @@ -0,0 +1,218 @@ +# Datamol Reactions and Data Modules Reference + +## Reactions Module (`datamol.reactions`) + +The reactions module enables programmatic application of chemical transformations using SMARTS reaction patterns. + +### Applying Chemical Reactions + +#### `dm.reactions.apply_reaction(rxn, reactants, as_smiles=False, sanitize=True, single_product_group=True, rm_attach=True, product_index=0)` +Apply a chemical reaction to reactant molecules. +- **Parameters**: + - `rxn`: Reaction object (from SMARTS pattern) + - `reactants`: Tuple of reactant molecules + - `as_smiles`: Return SMILES strings (True) or molecule objects (False) + - `sanitize`: Sanitize product molecules + - `single_product_group`: Return single product (True) or all product groups (False) + - `rm_attach`: Remove attachment point markers + - `product_index`: Which product to return from reaction +- **Returns**: Product molecule(s) or SMILES +- **Example**: + ```python + from rdkit import Chem + + # Define reaction: alcohol + carboxylic acid → ester + rxn = Chem.rdChemReactions.ReactionFromSmarts( + '[C:1][OH:2].[C:3](=[O:4])[OH:5]>>[C:1][O:2][C:3](=[O:4])' + ) + + # Apply to reactants + alcohol = dm.to_mol("CCO") + acid = dm.to_mol("CC(=O)O") + product = dm.reactions.apply_reaction(rxn, (alcohol, acid)) + ``` + +### Creating Reactions + +Reactions are typically created from SMARTS patterns using RDKit: +```python +from rdkit.Chem import rdChemReactions + +# Reaction pattern: [reactant1].[reactant2]>>[product] +rxn = rdChemReactions.ReactionFromSmarts( + '[1*][*:1].[1*][*:2]>>[*:1][*:2]' +) +``` + +### Validation Functions + +The module includes functions to: +- **Check if molecule is reactant**: Verify if molecule matches reactant pattern +- **Validate reaction**: Check if reaction is synthetically reasonable +- **Process reaction files**: Load reactions from files or databases + +### Common Reaction Patterns + +**Amide formation**: +```python +# Amine + carboxylic acid → amide +amide_rxn = rdChemReactions.ReactionFromSmarts( + '[N:1].[C:2](=[O:3])[OH]>>[N:1][C:2](=[O:3])' +) +``` + +**Suzuki coupling**: +```python +# Aryl halide + boronic acid → biaryl +suzuki_rxn = rdChemReactions.ReactionFromSmarts( + '[c:1][Br].[c:2][B]([OH])[OH]>>[c:1][c:2]' +) +``` + +**Functional group transformations**: +```python +# Alcohol → ester +esterification = rdChemReactions.ReactionFromSmarts( + '[C:1][OH:2].[C:3](=[O:4])[Cl]>>[C:1][O:2][C:3](=[O:4])' +) +``` + +### Workflow Example + +```python +import datamol as dm +from rdkit.Chem import rdChemReactions + +# 
1. Define reaction +rxn_smarts = '[C:1](=[O:2])[OH:3]>>[C:1](=[O:2])[Cl:3]' # Acid → acid chloride +rxn = rdChemReactions.ReactionFromSmarts(rxn_smarts) + +# 2. Apply to molecule library +acids = [dm.to_mol(smi) for smi in acid_smiles_list] +acid_chlorides = [] + +for acid in acids: + try: + product = dm.reactions.apply_reaction( + rxn, + (acid,), # Single reactant as tuple + sanitize=True + ) + acid_chlorides.append(product) + except Exception as e: + print(f"Reaction failed: {e}") + +# 3. Validate products +valid_products = [p for p in acid_chlorides if p is not None] +``` + +### Key Concepts + +- **SMARTS**: SMiles ARbitrary Target Specification - pattern language for reactions +- **Atom Mapping**: Numbers like [C:1] preserve atom identity through reaction +- **Attachment Points**: [1*] represents generic connection points +- **Reaction Validation**: Not all SMARTS reactions are chemically reasonable + +--- + +## Data Module (`datamol.data`) + +The data module provides convenient access to curated molecular datasets for testing and learning. + +### Available Datasets + +#### `dm.data.cdk2(as_df=True, mol_column='mol')` +RDKit CDK2 dataset - kinase inhibitor data. +- **Parameters**: + - `as_df`: Return as DataFrame (True) or list of molecules (False) + - `mol_column`: Name for molecule column +- **Returns**: Dataset with molecular structures and activity data +- **Use case**: Small dataset for algorithm testing +- **Example**: + ```python + cdk2_df = dm.data.cdk2(as_df=True) + print(cdk2_df.shape) + print(cdk2_df.columns) + ``` + +#### `dm.data.freesolv()` +FreeSolv dataset - experimental and calculated hydration free energies. +- **Contents**: 642 molecules with: + - IUPAC names + - SMILES strings + - Experimental hydration free energy values + - Calculated values +- **Warning**: "Only meant to be used as a toy dataset for pedagogic and testing purposes" +- **Not suitable for**: Benchmarking or production model training +- **Example**: + ```python + freesolv_df = dm.data.freesolv() + # Columns: iupac, smiles, expt (kcal/mol), calc (kcal/mol) + ``` + +#### `dm.data.solubility(as_df=True, mol_column='mol')` +RDKit solubility dataset with train/test splits. 
+- **Contents**: Aqueous solubility data with pre-defined splits +- **Columns**: Includes 'split' column with 'train' or 'test' values +- **Use case**: Testing ML workflows with proper train/test separation +- **Example**: + ```python + sol_df = dm.data.solubility(as_df=True) + + # Split into train/test + train_df = sol_df[sol_df['split'] == 'train'] + test_df = sol_df[sol_df['split'] == 'test'] + + # Use for model development + X_train = dm.to_fp(train_df[mol_column]) + y_train = train_df['solubility'] + ``` + +### Usage Guidelines + +**For testing and tutorials**: +```python +# Quick dataset for testing code +df = dm.data.cdk2() +mols = df['mol'].tolist() + +# Test descriptor calculation +descriptors_df = dm.descriptors.batch_compute_many_descriptors(mols) + +# Test clustering +clusters = dm.cluster_mols(mols, cutoff=0.3) +``` + +**For learning workflows**: +```python +# Complete ML pipeline example +sol_df = dm.data.solubility() + +# Preprocessing +train = sol_df[sol_df['split'] == 'train'] +test = sol_df[sol_df['split'] == 'test'] + +# Featurization +X_train = dm.to_fp(train['mol']) +X_test = dm.to_fp(test['mol']) + +# Model training (example) +from sklearn.ensemble import RandomForestRegressor +model = RandomForestRegressor() +model.fit(X_train, train['solubility']) +predictions = model.predict(X_test) +``` + +### Important Notes + +- **Toy Datasets**: Designed for pedagogical purposes, not production use +- **Small Size**: Limited number of compounds suitable for quick tests +- **Pre-processed**: Data already cleaned and formatted +- **Citations**: Check dataset documentation for proper attribution if publishing + +### Best Practices + +1. **Use for development only**: Don't draw scientific conclusions from toy datasets +2. **Validate on real data**: Always test production code on actual project data +3. **Proper attribution**: Cite original data sources if using in publications +4. **Understand limitations**: Know the scope and quality of each dataset diff --git a/scientific-packages/deepchem/SKILL.md b/scientific-packages/deepchem/SKILL.md new file mode 100644 index 0000000..7c058b1 --- /dev/null +++ b/scientific-packages/deepchem/SKILL.md @@ -0,0 +1,591 @@ +--- +name: deepchem +description: Comprehensive toolkit for molecular machine learning, drug discovery, and materials science using DeepChem. Use this skill when working with molecular data (SMILES, SDF files), predicting molecular properties (solubility, toxicity, binding affinity), training graph neural networks on molecules, using MoleculeNet benchmarks, performing molecular featurization, or applying transfer learning with pretrained chemical models (ChemBERTa, GROVER). Also applicable for materials science (crystal structures, bandgap prediction) and protein/DNA sequence analysis. +--- + +# DeepChem + +## Overview + +DeepChem is a comprehensive Python library for applying machine learning to chemistry, materials science, and biology. Enable molecular property prediction, drug discovery, materials design, and biomolecule analysis through specialized neural networks, molecular featurization methods, and pretrained models. + +## When to Use This Skill + +Apply this skill when: +- Loading and processing molecular data (SMILES strings, SDF files, protein sequences) +- Predicting molecular properties (solubility, toxicity, binding affinity, ADMET properties) +- Training models on chemical/biological datasets +- Using MoleculeNet benchmark datasets (Tox21, BBBP, Delaney, etc.) 
+- Converting molecules to ML-ready features (fingerprints, graph representations, descriptors) +- Implementing graph neural networks for molecules (GCN, GAT, MPNN, AttentiveFP) +- Applying transfer learning with pretrained models (ChemBERTa, GROVER, MolFormer) +- Predicting crystal/materials properties (bandgap, formation energy) +- Analyzing protein or DNA sequences + +## Core Capabilities + +### 1. Molecular Data Loading and Processing + +DeepChem provides specialized loaders for various chemical data formats: + +```python +import deepchem as dc + +# Load CSV with SMILES +featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) +loader = dc.data.CSVLoader( + tasks=['solubility', 'toxicity'], + feature_field='smiles', + featurizer=featurizer +) +dataset = loader.create_dataset('molecules.csv') + +# Load SDF files +loader = dc.data.SDFLoader(tasks=['activity'], featurizer=featurizer) +dataset = loader.create_dataset('compounds.sdf') + +# Load protein sequences +loader = dc.data.FASTALoader() +dataset = loader.create_dataset('proteins.fasta') +``` + +**Key Loaders**: +- `CSVLoader`: Tabular data with molecular identifiers +- `SDFLoader`: Molecular structure files +- `FASTALoader`: Protein/DNA sequences +- `ImageLoader`: Molecular images +- `JsonLoader`: JSON-formatted datasets + +### 2. Molecular Featurization + +Convert molecules into numerical representations for ML models. + +#### Decision Tree for Featurizer Selection + +``` +Is the model a graph neural network? +├─ YES → Use graph featurizers +│ ├─ Standard GNN → MolGraphConvFeaturizer +│ ├─ Message passing → DMPNNFeaturizer +│ └─ Pretrained → GroverFeaturizer +│ +└─ NO → What type of model? + ├─ Traditional ML (RF, XGBoost, SVM) + │ ├─ Fast baseline → CircularFingerprint (ECFP) + │ ├─ Interpretable → RDKitDescriptors + │ └─ Maximum coverage → MordredDescriptors + │ + ├─ Deep learning (non-graph) + │ ├─ Dense networks → CircularFingerprint + │ └─ CNN → SmilesToImage + │ + ├─ Sequence models (LSTM, Transformer) + │ └─ SmilesToSeq + │ + └─ 3D structure analysis + └─ CoulombMatrix +``` + +#### Example Featurization + +```python +# Fingerprints (for traditional ML) +fp = dc.feat.CircularFingerprint(radius=2, size=2048) + +# Descriptors (for interpretable models) +desc = dc.feat.RDKitDescriptors() + +# Graph features (for GNNs) +graph_feat = dc.feat.MolGraphConvFeaturizer() + +# Apply featurization +features = fp.featurize(['CCO', 'c1ccccc1']) +``` + +**Selection Guide**: +- **Small datasets (<1K)**: CircularFingerprint or RDKitDescriptors +- **Medium datasets (1K-100K)**: CircularFingerprint or graph featurizers +- **Large datasets (>100K)**: Graph featurizers (MolGraphConvFeaturizer, DMPNNFeaturizer) +- **Transfer learning**: Pretrained model featurizers (GroverFeaturizer) + +See `references/api_reference.md` for complete featurizer documentation. + +### 3. Data Splitting + +**Critical**: For drug discovery tasks, use `ScaffoldSplitter` to prevent data leakage from similar molecular structures appearing in both training and test sets. 
+ +```python +# Scaffold splitting (recommended for molecules) +splitter = dc.splits.ScaffoldSplitter() +train, valid, test = splitter.train_valid_test_split( + dataset, + frac_train=0.8, + frac_valid=0.1, + frac_test=0.1 +) + +# Random splitting (for non-molecular data) +splitter = dc.splits.RandomSplitter() +train, test = splitter.train_test_split(dataset) + +# Stratified splitting (for imbalanced classification) +splitter = dc.splits.RandomStratifiedSplitter() +train, test = splitter.train_test_split(dataset) +``` + +**Available Splitters**: +- `ScaffoldSplitter`: Split by molecular scaffolds (prevents leakage) +- `ButinaSplitter`: Clustering-based molecular splitting +- `MaxMinSplitter`: Maximize diversity between sets +- `RandomSplitter`: Random splitting +- `RandomStratifiedSplitter`: Preserves class distributions + +### 4. Model Selection and Training + +#### Quick Model Selection Guide + +| Dataset Size | Task | Recommended Model | Featurizer | +|-------------|------|-------------------|------------| +| < 1K samples | Any | SklearnModel (RandomForest) | CircularFingerprint | +| 1K-100K | Classification/Regression | GBDTModel or MultitaskRegressor | CircularFingerprint | +| > 100K | Molecular properties | GCNModel, AttentiveFPModel, DMPNNModel | MolGraphConvFeaturizer | +| Any (small preferred) | Transfer learning | ChemBERTa, GROVER, MolFormer | Model-specific | +| Crystal structures | Materials properties | CGCNNModel, MEGNetModel | Structure-based | +| Protein sequences | Protein properties | ProtBERT | Sequence-based | + +#### Example: Traditional ML +```python +from sklearn.ensemble import RandomForestRegressor + +# Wrap scikit-learn model +sklearn_model = RandomForestRegressor(n_estimators=100) +model = dc.models.SklearnModel(model=sklearn_model) +model.fit(train) +``` + +#### Example: Deep Learning +```python +# Multitask regressor (for fingerprints) +model = dc.models.MultitaskRegressor( + n_tasks=2, + n_features=2048, + layer_sizes=[1000, 500], + dropouts=0.25, + learning_rate=0.001 +) +model.fit(train, nb_epoch=50) +``` + +#### Example: Graph Neural Networks +```python +# Graph Convolutional Network +model = dc.models.GCNModel( + n_tasks=1, + mode='regression', + batch_size=128, + learning_rate=0.001 +) +model.fit(train, nb_epoch=50) + +# Graph Attention Network +model = dc.models.GATModel(n_tasks=1, mode='classification') +model.fit(train, nb_epoch=50) + +# Attentive Fingerprint +model = dc.models.AttentiveFPModel(n_tasks=1, mode='regression') +model.fit(train, nb_epoch=50) +``` + +### 5. MoleculeNet Benchmarks + +Quick access to 30+ curated benchmark datasets with standardized train/valid/test splits: + +```python +# Load benchmark dataset +tasks, datasets, transformers = dc.molnet.load_tox21( + featurizer='GraphConv', # or 'ECFP', 'Weave', 'Raw' + splitter='scaffold', # or 'random', 'stratified' + reload=False +) +train, valid, test = datasets + +# Train and evaluate +model = dc.models.GCNModel(n_tasks=len(tasks), mode='classification') +model.fit(train, nb_epoch=50) + +metric = dc.metrics.Metric(dc.metrics.roc_auc_score) +test_score = model.evaluate(test, [metric]) +``` + +**Common Datasets**: +- **Classification**: `load_tox21()`, `load_bbbp()`, `load_hiv()`, `load_clintox()` +- **Regression**: `load_delaney()`, `load_freesolv()`, `load_lipo()` +- **Quantum properties**: `load_qm7()`, `load_qm8()`, `load_qm9()` +- **Materials**: `load_perovskite()`, `load_bandgap()`, `load_mp_formation_energy()` + +See `references/api_reference.md` for complete dataset list. + +### 6. 
Transfer Learning + +Leverage pretrained models for improved performance, especially on small datasets: + +```python +# ChemBERTa (BERT pretrained on 77M molecules) +model = dc.models.HuggingFaceModel( + model='seyonec/ChemBERTa-zinc-base-v1', + task='classification', + n_tasks=1, + learning_rate=2e-5 # Lower LR for fine-tuning +) +model.fit(train, nb_epoch=10) + +# GROVER (graph transformer pretrained on 10M molecules) +model = dc.models.GroverModel( + task='regression', + n_tasks=1 +) +model.fit(train, nb_epoch=20) +``` + +**When to use transfer learning**: +- Small datasets (< 1000 samples) +- Novel molecular scaffolds +- Limited computational resources +- Need for rapid prototyping + +Use the `scripts/transfer_learning.py` script for guided transfer learning workflows. + +### 7. Model Evaluation + +```python +# Define metrics +classification_metrics = [ + dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'), + dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'), + dc.metrics.Metric(dc.metrics.f1_score, name='F1') +] + +regression_metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + dc.metrics.Metric(dc.metrics.root_mean_squared_error, name='RMSE') +] + +# Evaluate +train_scores = model.evaluate(train, classification_metrics) +test_scores = model.evaluate(test, classification_metrics) +``` + +### 8. Making Predictions + +```python +# Predict on test set +predictions = model.predict(test) + +# Predict on new molecules +new_smiles = ['CCO', 'c1ccccc1', 'CC(C)O'] +new_features = featurizer.featurize(new_smiles) +new_dataset = dc.data.NumpyDataset(X=new_features) + +# Apply same transformations as training +for transformer in transformers: + new_dataset = transformer.transform(new_dataset) + +predictions = model.predict(new_dataset) +``` + +## Typical Workflows + +### Workflow A: Quick Benchmark Evaluation + +For evaluating a model on standard benchmarks: + +```python +import deepchem as dc + +# 1. Load benchmark +tasks, datasets, _ = dc.molnet.load_bbbp( + featurizer='GraphConv', + splitter='scaffold' +) +train, valid, test = datasets + +# 2. Train model +model = dc.models.GCNModel(n_tasks=len(tasks), mode='classification') +model.fit(train, nb_epoch=50) + +# 3. Evaluate +metric = dc.metrics.Metric(dc.metrics.roc_auc_score) +test_score = model.evaluate(test, [metric]) +print(f"Test ROC-AUC: {test_score}") +``` + +### Workflow B: Custom Data Prediction + +For training on custom molecular datasets: + +```python +import deepchem as dc + +# 1. Load and featurize data +featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) +loader = dc.data.CSVLoader( + tasks=['activity'], + feature_field='smiles', + featurizer=featurizer +) +dataset = loader.create_dataset('my_molecules.csv') + +# 2. Split data (use ScaffoldSplitter for molecules!) +splitter = dc.splits.ScaffoldSplitter() +train, valid, test = splitter.train_valid_test_split(dataset) + +# 3. Normalize (optional but recommended) +transformers = [dc.trans.NormalizationTransformer( + transform_y=True, dataset=train +)] +for transformer in transformers: + train = transformer.transform(train) + valid = transformer.transform(valid) + test = transformer.transform(test) + +# 4. Train model +model = dc.models.MultitaskRegressor( + n_tasks=1, + n_features=2048, + layer_sizes=[1000, 500], + dropouts=0.25 +) +model.fit(train, nb_epoch=50) + +# 5. 
Evaluate +metric = dc.metrics.Metric(dc.metrics.r2_score) +test_score = model.evaluate(test, [metric]) +``` + +### Workflow C: Transfer Learning on Small Dataset + +For leveraging pretrained models: + +```python +import deepchem as dc + +# 1. Load data (pretrained models often need raw SMILES) +loader = dc.data.CSVLoader( + tasks=['activity'], + feature_field='smiles', + featurizer=dc.feat.DummyFeaturizer() # Model handles featurization +) +dataset = loader.create_dataset('small_dataset.csv') + +# 2. Split data +splitter = dc.splits.ScaffoldSplitter() +train, test = splitter.train_test_split(dataset) + +# 3. Load pretrained model +model = dc.models.HuggingFaceModel( + model='seyonec/ChemBERTa-zinc-base-v1', + task='classification', + n_tasks=1, + learning_rate=2e-5 +) + +# 4. Fine-tune +model.fit(train, nb_epoch=10) + +# 5. Evaluate +predictions = model.predict(test) +``` + +See `references/workflows.md` for 8 detailed workflow examples covering molecular generation, materials science, protein analysis, and more. + +## Example Scripts + +This skill includes three production-ready scripts in the `scripts/` directory: + +### 1. `predict_solubility.py` +Train and evaluate solubility prediction models. Works with Delaney benchmark or custom CSV data. + +```bash +# Use Delaney benchmark +python scripts/predict_solubility.py + +# Use custom data +python scripts/predict_solubility.py \ + --data my_data.csv \ + --smiles-col smiles \ + --target-col solubility \ + --predict "CCO" "c1ccccc1" +``` + +### 2. `graph_neural_network.py` +Train various graph neural network architectures on molecular data. + +```bash +# Train GCN on Tox21 +python scripts/graph_neural_network.py --model gcn --dataset tox21 + +# Train AttentiveFP on custom data +python scripts/graph_neural_network.py \ + --model attentivefp \ + --data molecules.csv \ + --task-type regression \ + --targets activity \ + --epochs 100 +``` + +### 3. `transfer_learning.py` +Fine-tune pretrained models (ChemBERTa, GROVER) on molecular property prediction tasks. + +```bash +# Fine-tune ChemBERTa on BBBP +python scripts/transfer_learning.py --model chemberta --dataset bbbp + +# Fine-tune GROVER on custom data +python scripts/transfer_learning.py \ + --model grover \ + --data small_dataset.csv \ + --target activity \ + --task-type classification \ + --epochs 20 +``` + +## Common Patterns and Best Practices + +### Pattern 1: Always Use Scaffold Splitting for Molecules +```python +# GOOD: Prevents data leakage +splitter = dc.splits.ScaffoldSplitter() +train, test = splitter.train_test_split(dataset) + +# BAD: Similar molecules in train and test +splitter = dc.splits.RandomSplitter() +train, test = splitter.train_test_split(dataset) +``` + +### Pattern 2: Normalize Features and Targets +```python +transformers = [ + dc.trans.NormalizationTransformer( + transform_y=True, # Also normalize target values + dataset=train + ) +] +for transformer in transformers: + train = transformer.transform(train) + test = transformer.transform(test) +``` + +### Pattern 3: Start Simple, Then Scale +1. Start with Random Forest + CircularFingerprint (fast baseline) +2. Try XGBoost/LightGBM if RF works well +3. Move to deep learning (MultitaskRegressor) if you have >5K samples +4. Try GNNs if you have >10K samples +5. 
Use transfer learning for small datasets or novel scaffolds + +### Pattern 4: Handle Imbalanced Data +```python +# Option 1: Balancing transformer +transformer = dc.trans.BalancingTransformer(dataset=train) +train = transformer.transform(train) + +# Option 2: Use balanced metrics +metric = dc.metrics.Metric(dc.metrics.balanced_accuracy_score) +``` + +### Pattern 5: Avoid Memory Issues +```python +# Use DiskDataset for large datasets +dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids) + +# Use smaller batch sizes +model = dc.models.GCNModel(batch_size=32) # Instead of 128 +``` + +## Common Pitfalls + +### Issue 1: Data Leakage in Drug Discovery +**Problem**: Using random splitting allows similar molecules in train/test sets. +**Solution**: Always use `ScaffoldSplitter` for molecular datasets. + +### Issue 2: GNN Underperforming vs Fingerprints +**Problem**: Graph neural networks perform worse than simple fingerprints. +**Solutions**: +- Ensure dataset is large enough (>10K samples typically) +- Increase training epochs (50-100) +- Try different architectures (AttentiveFP, DMPNN instead of GCN) +- Use pretrained models (GROVER) + +### Issue 3: Overfitting on Small Datasets +**Problem**: Model memorizes training data. +**Solutions**: +- Use stronger regularization (increase dropout to 0.5) +- Use simpler models (Random Forest instead of deep learning) +- Apply transfer learning (ChemBERTa, GROVER) +- Collect more data + +### Issue 4: Import Errors +**Problem**: Module not found errors. +**Solution**: Ensure DeepChem is installed with required dependencies: +```bash +pip install deepchem +# For PyTorch models +pip install deepchem[torch] +# For all features +pip install deepchem[all] +``` + +## Reference Documentation + +This skill includes comprehensive reference documentation: + +### `references/api_reference.md` +Complete API documentation including: +- All data loaders and their use cases +- Dataset classes and when to use each +- Complete featurizer catalog with selection guide +- Model catalog organized by category (50+ models) +- MoleculeNet dataset descriptions +- Metrics and evaluation functions +- Common code patterns + +**When to reference**: Search this file when you need specific API details, parameter names, or want to explore available options. + +### `references/workflows.md` +Eight detailed end-to-end workflows: +1. Molecular property prediction from SMILES +2. Using MoleculeNet benchmarks +3. Hyperparameter optimization +4. Transfer learning with pretrained models +5. Molecular generation with GANs +6. Materials property prediction +7. Protein sequence analysis +8. Custom model integration + +**When to reference**: Use these workflows as templates for implementing complete solutions. + +## Installation Notes + +Basic installation: +```bash +pip install deepchem +``` + +For PyTorch models (GCN, GAT, etc.): +```bash +pip install deepchem[torch] +``` + +For all features: +```bash +pip install deepchem[all] +``` + +If import errors occur, the user may need specific dependencies. Check the DeepChem documentation for detailed installation instructions. 
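+
+A quick sanity check after installation (a minimal sketch; it only confirms that the core package imports and that RDKit-backed featurization runs):
+
+```python
+import deepchem as dc
+
+print(dc.__version__)                                   # confirm the installed version
+fp = dc.feat.CircularFingerprint(radius=2, size=2048)   # exercises the RDKit dependency
+print(fp.featurize(["CCO"]).shape)                      # expected shape: (1, 2048)
+```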
+ +## Additional Resources + +- Official documentation: https://deepchem.readthedocs.io/ +- GitHub repository: https://github.com/deepchem/deepchem +- Tutorials: https://deepchem.readthedocs.io/en/latest/get_started/tutorials.html +- Paper: "MoleculeNet: A Benchmark for Molecular Machine Learning" diff --git a/scientific-packages/deepchem/references/api_reference.md b/scientific-packages/deepchem/references/api_reference.md new file mode 100644 index 0000000..86ae0cf --- /dev/null +++ b/scientific-packages/deepchem/references/api_reference.md @@ -0,0 +1,303 @@ +# DeepChem API Reference + +This document provides a comprehensive reference for DeepChem's core APIs, organized by functionality. + +## Data Handling + +### Data Loaders + +#### File Format Loaders +- **CSVLoader**: Load tabular data from CSV files with customizable feature handling +- **UserCSVLoader**: User-defined CSV loading with flexible column specifications +- **SDFLoader**: Process molecular structure files (SDF format) +- **JsonLoader**: Import JSON-structured datasets +- **ImageLoader**: Load image data for computer vision tasks + +#### Biological Data Loaders +- **FASTALoader**: Handle protein/DNA sequences in FASTA format +- **FASTQLoader**: Process FASTQ sequencing data with quality scores +- **SAMLoader/BAMLoader/CRAMLoader**: Support sequence alignment formats + +#### Specialized Loaders +- **DFTYamlLoader**: Process density functional theory computational data +- **InMemoryLoader**: Load data directly from Python objects + +### Dataset Classes + +- **NumpyDataset**: Wrap NumPy arrays for in-memory data manipulation +- **DiskDataset**: Manage larger datasets stored on disk, reducing memory overhead +- **ImageDataset**: Specialized container for image-based ML tasks + +### Data Splitters + +#### General Splitters +- **RandomSplitter**: Random dataset partitioning +- **IndexSplitter**: Split by specified indices +- **SpecifiedSplitter**: Use pre-defined splits +- **RandomStratifiedSplitter**: Stratified random splitting +- **SingletaskStratifiedSplitter**: Stratified splitting for single tasks +- **TaskSplitter**: Split for multitask scenarios + +#### Molecule-Specific Splitters +- **ScaffoldSplitter**: Divide molecules by structural scaffolds (prevents data leakage) +- **ButinaSplitter**: Clustering-based molecular splitting +- **FingerprintSplitter**: Split based on molecular fingerprint similarity +- **MaxMinSplitter**: Maximize diversity between training/test sets +- **MolecularWeightSplitter**: Split by molecular weight properties + +**Best Practice**: For drug discovery tasks, use ScaffoldSplitter to prevent overfitting on similar molecular structures. 
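+
+A minimal sketch of scaffold-based splitting (assumes `dataset` was built with one of the loaders above):
+
+```python
+import deepchem as dc
+
+splitter = dc.splits.ScaffoldSplitter()
+train, valid, test = splitter.train_valid_test_split(
+    dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1
+)
+```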
+ +### Transformers + +#### Normalization +- **NormalizationTransformer**: Standard normalization (mean=0, std=1) +- **MinMaxTransformer**: Scale features to [0,1] range +- **LogTransformer**: Apply log transformation +- **PowerTransformer**: Box-Cox and Yeo-Johnson transformations +- **CDFTransformer**: Cumulative distribution function normalization + +#### Task-Specific +- **BalancingTransformer**: Address class imbalance +- **FeaturizationTransformer**: Apply dynamic feature engineering +- **CoulombFitTransformer**: Quantum chemistry specific +- **DAGTransformer**: Directed acyclic graph transformations +- **RxnSplitTransformer**: Chemical reaction preprocessing + +## Molecular Featurizers + +### Graph-Based Featurizers +Use these with graph neural networks (GCNs, MPNNs, etc.): + +- **ConvMolFeaturizer**: Graph representations for graph convolutional networks +- **WeaveFeaturizer**: "Weave" graph embeddings +- **MolGraphConvFeaturizer**: Graph convolution-ready representations +- **EquivariantGraphFeaturizer**: Maintains geometric invariance +- **DMPNNFeaturizer**: Directed message-passing neural network inputs +- **GroverFeaturizer**: Pre-trained molecular embeddings + +### Fingerprint-Based Featurizers +Use these with traditional ML (Random Forest, SVM, XGBoost): + +- **MACCSKeysFingerprint**: 167-bit structural keys +- **CircularFingerprint**: Extended connectivity fingerprints (Morgan fingerprints) + - Parameters: `radius` (default 2), `size` (default 2048), `useChirality` (default False) +- **PubChemFingerprint**: 881-bit structural descriptors +- **Mol2VecFingerprint**: Learned molecular vector representations + +### Descriptor Featurizers +Calculate molecular properties directly: + +- **RDKitDescriptors**: ~200 molecular descriptors (MW, LogP, H-donors, H-acceptors, TPSA, etc.) 
+- **MordredDescriptors**: Comprehensive structural and physicochemical descriptors +- **CoulombMatrix**: Interatomic distance matrices for 3D structures + +### Sequence-Based Featurizers +For recurrent networks and transformers: + +- **SmilesToSeq**: Convert SMILES strings to sequences +- **SmilesToImage**: Generate 2D image representations from SMILES +- **RawFeaturizer**: Pass through raw molecular data unchanged + +### Selection Guide + +| Use Case | Recommended Featurizer | Model Type | +|----------|----------------------|------------| +| Graph neural networks | ConvMolFeaturizer, MolGraphConvFeaturizer | GCN, MPNN, GAT | +| Traditional ML | CircularFingerprint, RDKitDescriptors | Random Forest, XGBoost, SVM | +| Deep learning (non-graph) | CircularFingerprint, Mol2VecFingerprint | Dense networks, CNN | +| Sequence models | SmilesToSeq | LSTM, GRU, Transformer | +| 3D molecular structures | CoulombMatrix | Specialized 3D models | +| Quick baseline | RDKitDescriptors | Linear, Ridge, Lasso | + +## Models + +### Scikit-Learn Integration +- **SklearnModel**: Wrapper for any scikit-learn algorithm + - Usage: `SklearnModel(model=RandomForestRegressor())` + +### Gradient Boosting +- **GBDTModel**: Gradient boosting decision trees (XGBoost, LightGBM) + +### PyTorch Models + +#### Molecular Property Prediction +- **MultitaskRegressor**: Multi-task regression with shared representations +- **MultitaskClassifier**: Multi-task classification +- **MultitaskFitTransformRegressor**: Regression with learned transformations +- **GCNModel**: Graph convolutional networks +- **GATModel**: Graph attention networks +- **AttentiveFPModel**: Attentive fingerprint networks +- **DMPNNModel**: Directed message passing neural networks +- **GroverModel**: GROVER pre-trained transformer +- **MATModel**: Molecule attention transformer + +#### Materials Science +- **CGCNNModel**: Crystal graph convolutional networks +- **MEGNetModel**: Materials graph networks +- **LCNNModel**: Lattice CNN for materials + +#### Generative Models +- **GANModel**: Generative adversarial networks +- **WGANModel**: Wasserstein GAN +- **BasicMolGANModel**: Molecular GAN +- **LSTMGenerator**: LSTM-based molecule generation +- **SeqToSeqModel**: Sequence-to-sequence models + +#### Physics-Informed Models +- **PINNModel**: Physics-informed neural networks +- **HNNModel**: Hamiltonian neural networks +- **LNN**: Lagrangian neural networks +- **FNOModel**: Fourier neural operators + +#### Computer Vision +- **CNN**: Convolutional neural networks +- **UNetModel**: U-Net architecture for segmentation +- **InceptionV3Model**: Pre-trained Inception v3 +- **MobileNetV2Model**: Lightweight mobile networks + +### Hugging Face Models + +- **HuggingFaceModel**: General wrapper for HF transformers +- **Chemberta**: Chemical BERT for molecular property prediction +- **MoLFormer**: Molecular transformer architecture +- **ProtBERT**: Protein sequence BERT +- **DeepAbLLM**: Antibody large language models + +### Model Selection Guide + +| Task | Recommended Model | Featurizer | +|------|------------------|------------| +| Small dataset (<1000 samples) | SklearnModel (Random Forest) | CircularFingerprint | +| Medium dataset (1K-100K) | GBDTModel or MultitaskRegressor | CircularFingerprint or ConvMolFeaturizer | +| Large dataset (>100K) | GCNModel, AttentiveFPModel, or DMPNN | MolGraphConvFeaturizer | +| Transfer learning | GroverModel, Chemberta, MoLFormer | Model-specific | +| Materials properties | CGCNNModel, MEGNetModel | Structure-based | +| Molecule 
generation | BasicMolGANModel, LSTMGenerator | SmilesToSeq | +| Protein sequences | ProtBERT | Sequence-based | + +## MoleculeNet Datasets + +Quick access to 30+ benchmark datasets via `dc.molnet.load_*()` functions. + +### Classification Datasets +- **load_bace()**: BACE-1 inhibitors (binary classification) +- **load_bbbp()**: Blood-brain barrier penetration +- **load_clintox()**: Clinical toxicity +- **load_hiv()**: HIV inhibition activity +- **load_muv()**: PubChem BioAssay (challenging, sparse) +- **load_pcba()**: PubChem screening data +- **load_sider()**: Adverse drug reactions (multi-label) +- **load_tox21()**: 12 toxicity assays (multi-task) +- **load_toxcast()**: EPA ToxCast screening + +### Regression Datasets +- **load_delaney()**: Aqueous solubility (ESOL) +- **load_freesolv()**: Solvation free energy +- **load_lipo()**: Lipophilicity (octanol-water partition) +- **load_qm7/qm8/qm9()**: Quantum mechanical properties +- **load_hopv()**: Organic photovoltaic properties + +### Protein-Ligand Binding +- **load_pdbbind()**: Binding affinity data + +### Materials Science +- **load_perovskite()**: Perovskite stability +- **load_mp_formation_energy()**: Materials Project formation energy +- **load_mp_metallicity()**: Metal vs. non-metal classification +- **load_bandgap()**: Electronic bandgap prediction + +### Chemical Reactions +- **load_uspto()**: USPTO reaction dataset + +### Usage Pattern +```python +tasks, datasets, transformers = dc.molnet.load_bbbp( + featurizer='GraphConv', # or 'ECFP', 'GraphConv', 'Weave', etc. + splitter='scaffold', # or 'random', 'stratified', etc. + reload=False # set True to skip caching +) +train, valid, test = datasets +``` + +## Metrics + +Common evaluation metrics available in `dc.metrics`: + +### Classification Metrics +- **roc_auc_score**: Area under ROC curve (binary/multi-class) +- **prc_auc_score**: Area under precision-recall curve +- **accuracy_score**: Classification accuracy +- **balanced_accuracy_score**: Balanced accuracy for imbalanced datasets +- **recall_score**: Sensitivity/recall +- **precision_score**: Precision +- **f1_score**: F1 score + +### Regression Metrics +- **mean_absolute_error**: MAE +- **mean_squared_error**: MSE +- **root_mean_squared_error**: RMSE +- **r2_score**: R² coefficient of determination +- **pearson_r2_score**: Pearson correlation +- **spearman_correlation**: Spearman rank correlation + +### Multi-Task Metrics +Most metrics support multi-task evaluation by averaging over tasks. + +## Training Pattern + +Standard DeepChem workflow: + +```python +# 1. Load data +loader = dc.data.CSVLoader(tasks=['task1'], feature_field='smiles', + featurizer=dc.feat.CircularFingerprint()) +dataset = loader.create_dataset('data.csv') + +# 2. Split data +splitter = dc.splits.ScaffoldSplitter() +train, valid, test = splitter.train_valid_test_split(dataset) + +# 3. Transform data (optional) +transformers = [dc.trans.NormalizationTransformer(dataset=train)] +for transformer in transformers: + train = transformer.transform(train) + valid = transformer.transform(valid) + test = transformer.transform(test) + +# 4. Create and train model +model = dc.models.MultitaskRegressor(n_tasks=1, n_features=2048, layer_sizes=[1000]) +model.fit(train, nb_epoch=50) + +# 5. 
Evaluate +metric = dc.metrics.Metric(dc.metrics.r2_score) +train_score = model.evaluate(train, [metric]) +test_score = model.evaluate(test, [metric]) +``` + +## Common Patterns + +### Pattern 1: Quick Baseline with MoleculeNet +```python +tasks, datasets, transformers = dc.molnet.load_tox21(featurizer='ECFP') +train, valid, test = datasets +model = dc.models.MultitaskClassifier(n_tasks=len(tasks), n_features=1024) +model.fit(train) +``` + +### Pattern 2: Custom Data with Graph Networks +```python +featurizer = dc.feat.MolGraphConvFeaturizer() +loader = dc.data.CSVLoader(tasks=['activity'], feature_field='smiles', + featurizer=featurizer) +dataset = loader.create_dataset('my_data.csv') +train, test = dc.splits.RandomSplitter().train_test_split(dataset) +model = dc.models.GCNModel(mode='classification', n_tasks=1) +model.fit(train) +``` + +### Pattern 3: Transfer Learning with Pretrained Models +```python +model = dc.models.GroverModel(task='classification', n_tasks=1) +model.fit(train_dataset) +predictions = model.predict(test_dataset) +``` diff --git a/scientific-packages/deepchem/references/workflows.md b/scientific-packages/deepchem/references/workflows.md new file mode 100644 index 0000000..9b98011 --- /dev/null +++ b/scientific-packages/deepchem/references/workflows.md @@ -0,0 +1,491 @@ +# DeepChem Workflows + +This document provides detailed workflows for common DeepChem use cases. + +## Workflow 1: Molecular Property Prediction from SMILES + +**Goal**: Predict molecular properties (e.g., solubility, toxicity, activity) from SMILES strings. + +### Step-by-Step Process + +#### 1. Prepare Your Data +Data should be in CSV format with at minimum: +- A column with SMILES strings +- One or more columns with property values (targets) + +Example CSV structure: +```csv +smiles,solubility,toxicity +CCO,-0.77,0 +CC(=O)OC1=CC=CC=C1C(=O)O,-1.19,1 +``` + +#### 2. Choose Featurizer +Decision tree: +- **Small dataset (<1K)**: Use `CircularFingerprint` or `RDKitDescriptors` +- **Medium dataset (1K-100K)**: Use `CircularFingerprint` or `MolGraphConvFeaturizer` +- **Large dataset (>100K)**: Use graph-based featurizers (`MolGraphConvFeaturizer`, `DMPNNFeaturizer`) +- **Transfer learning**: Use pretrained model featurizers (`GroverFeaturizer`) + +#### 3. Load and Featurize Data +```python +import deepchem as dc + +# For fingerprint-based +featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) +# OR for graph-based +featurizer = dc.feat.MolGraphConvFeaturizer() + +loader = dc.data.CSVLoader( + tasks=['solubility', 'toxicity'], # column names to predict + feature_field='smiles', # column with SMILES + featurizer=featurizer +) +dataset = loader.create_dataset('data.csv') +``` + +#### 4. Split Data +**Critical**: Use `ScaffoldSplitter` for drug discovery to prevent data leakage. + +```python +splitter = dc.splits.ScaffoldSplitter() +train, valid, test = splitter.train_valid_test_split( + dataset, + frac_train=0.8, + frac_valid=0.1, + frac_test=0.1 +) +``` + +#### 5. Transform Data (Optional but Recommended) +```python +transformers = [ + dc.trans.NormalizationTransformer( + transform_y=True, + dataset=train + ) +] + +for transformer in transformers: + train = transformer.transform(train) + valid = transformer.transform(valid) + test = transformer.transform(test) +``` + +#### 6. 
Select and Train Model +```python +# For fingerprints +model = dc.models.MultitaskRegressor( + n_tasks=2, # number of properties to predict + n_features=2048, # fingerprint size + layer_sizes=[1000, 500], # hidden layer sizes + dropouts=0.25, + learning_rate=0.001 +) + +# OR for graphs +model = dc.models.GCNModel( + n_tasks=2, + mode='regression', + batch_size=128, + learning_rate=0.001 +) + +# Train +model.fit(train, nb_epoch=50) +``` + +#### 7. Evaluate +```python +metric = dc.metrics.Metric(dc.metrics.r2_score) +train_score = model.evaluate(train, [metric]) +valid_score = model.evaluate(valid, [metric]) +test_score = model.evaluate(test, [metric]) + +print(f"Train R²: {train_score}") +print(f"Valid R²: {valid_score}") +print(f"Test R²: {test_score}") +``` + +#### 8. Make Predictions +```python +# Predict on new molecules +new_smiles = ['CCO', 'CC(C)O', 'c1ccccc1'] +new_featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) +new_features = new_featurizer.featurize(new_smiles) +new_dataset = dc.data.NumpyDataset(X=new_features) + +# Apply same transformations +for transformer in transformers: + new_dataset = transformer.transform(new_dataset) + +predictions = model.predict(new_dataset) +``` + +--- + +## Workflow 2: Using MoleculeNet Benchmark Datasets + +**Goal**: Quickly train and evaluate models on standard benchmarks. + +### Quick Start +```python +import deepchem as dc + +# Load benchmark dataset +tasks, datasets, transformers = dc.molnet.load_tox21( + featurizer='GraphConv', + splitter='scaffold' +) +train, valid, test = datasets + +# Train model +model = dc.models.GCNModel( + n_tasks=len(tasks), + mode='classification' +) +model.fit(train, nb_epoch=50) + +# Evaluate +metric = dc.metrics.Metric(dc.metrics.roc_auc_score) +test_score = model.evaluate(test, [metric]) +print(f"Test ROC-AUC: {test_score}") +``` + +### Available Featurizer Options +When calling `load_*()` functions: +- `'ECFP'`: Extended-connectivity fingerprints (circular fingerprints) +- `'GraphConv'`: Graph convolution features +- `'Weave'`: Weave features +- `'Raw'`: Raw SMILES strings +- `'smiles2img'`: 2D molecular images + +### Available Splitter Options +- `'scaffold'`: Scaffold-based splitting (recommended for drug discovery) +- `'random'`: Random splitting +- `'stratified'`: Stratified splitting (preserves class distributions) +- `'butina'`: Butina clustering-based splitting + +--- + +## Workflow 3: Hyperparameter Optimization + +**Goal**: Find optimal model hyperparameters systematically. 
+ +### Using GridHyperparamOpt +```python +import deepchem as dc +import numpy as np + +# Load data +tasks, datasets, transformers = dc.molnet.load_bbbp( + featurizer='ECFP', + splitter='scaffold' +) +train, valid, test = datasets + +# Define parameter grid +params_dict = { + 'layer_sizes': [[1000], [1000, 500], [1000, 1000]], + 'dropouts': [0.0, 0.25, 0.5], + 'learning_rate': [0.001, 0.0001] +} + +# Define model builder function +def model_builder(model_params, model_dir): + return dc.models.MultitaskClassifier( + n_tasks=len(tasks), + n_features=1024, + **model_params + ) + +# Setup optimizer +metric = dc.metrics.Metric(dc.metrics.roc_auc_score) +optimizer = dc.hyper.GridHyperparamOpt(model_builder) + +# Run optimization +best_model, best_params, all_results = optimizer.hyperparam_search( + params_dict, + train, + valid, + metric, + transformers=transformers +) + +print(f"Best parameters: {best_params}") +print(f"Best validation score: {all_results['best_validation_score']}") +``` + +--- + +## Workflow 4: Transfer Learning with Pretrained Models + +**Goal**: Leverage pretrained models for improved performance on small datasets. + +### Using ChemBERTa +```python +import deepchem as dc +from transformers import AutoTokenizer + +# Load your data +loader = dc.data.CSVLoader( + tasks=['activity'], + feature_field='smiles', + featurizer=dc.feat.DummyFeaturizer() # ChemBERTa handles featurization +) +dataset = loader.create_dataset('data.csv') + +# Split data +splitter = dc.splits.ScaffoldSplitter() +train, test = splitter.train_test_split(dataset) + +# Load pretrained ChemBERTa +model = dc.models.HuggingFaceModel( + model='seyonec/ChemBERTa-zinc-base-v1', + task='regression', + n_tasks=1 +) + +# Fine-tune +model.fit(train, nb_epoch=10) + +# Evaluate +predictions = model.predict(test) +``` + +### Using GROVER +```python +# GROVER: pre-trained on molecular graphs +model = dc.models.GroverModel( + task='classification', + n_tasks=1, + model_dir='./grover_model' +) + +# Fine-tune on your data +model.fit(train_dataset, nb_epoch=20) +``` + +--- + +## Workflow 5: Molecular Generation with GANs + +**Goal**: Generate novel molecules with desired properties. + +### Basic MolGAN +```python +import deepchem as dc + +# Load training data (molecules for the generator to learn from) +tasks, datasets, _ = dc.molnet.load_qm9( + featurizer='GraphConv', + splitter='random' +) +train, _, _ = datasets + +# Create and train MolGAN +gan = dc.models.BasicMolGANModel( + learning_rate=0.001, + vertices=9, # max atoms in molecule + edges=5, # max bonds + nodes=[128, 256, 512] +) + +# Train +gan.fit_gan( + train, + nb_epoch=100, + generator_steps=0.2, + checkpoint_interval=10 +) + +# Generate new molecules +generated_molecules = gan.predict_gan_generator(1000) +``` + +### Conditional Generation +```python +# For property-targeted generation +from deepchem.models.optimizers import ExponentialDecay + +gan = dc.models.BasicMolGANModel( + learning_rate=ExponentialDecay(0.001, 0.9, 1000), + conditional=True # enable conditional generation +) + +# Train with properties +gan.fit_gan(train, nb_epoch=100) + +# Generate molecules with target properties +target_properties = np.array([[5.0, 300.0]]) # e.g., [logP, MW] +molecules = gan.predict_gan_generator( + 1000, + conditional_inputs=target_properties +) +``` + +--- + +## Workflow 6: Materials Property Prediction + +**Goal**: Predict properties of crystalline materials. 
+ +### Using Crystal Graph Convolutional Networks +```python +import deepchem as dc + +# Load materials data (structure files in CIF format) +loader = dc.data.CIFLoader() +dataset = loader.create_dataset('materials.csv') + +# Split data +splitter = dc.splits.RandomSplitter() +train, test = splitter.train_test_split(dataset) + +# Create CGCNN model +model = dc.models.CGCNNModel( + n_tasks=1, + mode='regression', + batch_size=32, + learning_rate=0.001 +) + +# Train +model.fit(train, nb_epoch=100) + +# Evaluate +metric = dc.metrics.Metric(dc.metrics.mae_score) +test_score = model.evaluate(test, [metric]) +``` + +--- + +## Workflow 7: Protein Sequence Analysis + +**Goal**: Predict protein properties from sequences. + +### Using ProtBERT +```python +import deepchem as dc + +# Load protein sequence data +loader = dc.data.FASTALoader() +dataset = loader.create_dataset('proteins.fasta') + +# Use ProtBERT +model = dc.models.HuggingFaceModel( + model='Rostlab/prot_bert', + task='classification', + n_tasks=1 +) + +# Split and train +splitter = dc.splits.RandomSplitter() +train, test = splitter.train_test_split(dataset) +model.fit(train, nb_epoch=5) + +# Predict +predictions = model.predict(test) +``` + +--- + +## Workflow 8: Custom Model Integration + +**Goal**: Use your own PyTorch/scikit-learn models with DeepChem. + +### Wrapping Scikit-Learn Models +```python +from sklearn.ensemble import RandomForestRegressor +import deepchem as dc + +# Create scikit-learn model +sklearn_model = RandomForestRegressor( + n_estimators=100, + max_depth=10, + random_state=42 +) + +# Wrap in DeepChem +model = dc.models.SklearnModel(model=sklearn_model) + +# Use with DeepChem datasets +model.fit(train) +predictions = model.predict(test) + +# Evaluate +metric = dc.metrics.Metric(dc.metrics.r2_score) +score = model.evaluate(test, [metric]) +``` + +### Creating Custom PyTorch Models +```python +import torch +import torch.nn as nn +import deepchem as dc + +class CustomNetwork(nn.Module): + def __init__(self, n_features, n_tasks): + super().__init__() + self.fc1 = nn.Linear(n_features, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, n_tasks) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.2) + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.dropout(x) + x = self.relu(self.fc2(x)) + x = self.dropout(x) + return self.fc3(x) + +# Wrap in DeepChem TorchModel +model = dc.models.TorchModel( + model=CustomNetwork(n_features=2048, n_tasks=1), + loss=nn.MSELoss(), + output_types=['prediction'] +) + +# Train +model.fit(train, nb_epoch=50) +``` + +--- + +## Common Pitfalls and Solutions + +### Issue 1: Data Leakage in Drug Discovery +**Problem**: Using random splitting allows similar molecules in train and test sets. +**Solution**: Always use `ScaffoldSplitter` for molecular datasets. + +### Issue 2: Imbalanced Classification +**Problem**: Poor performance on minority class. +**Solution**: Use `BalancingTransformer` or weighted metrics. +```python +transformer = dc.trans.BalancingTransformer(dataset=train) +train = transformer.transform(train) +``` + +### Issue 3: Memory Issues with Large Datasets +**Problem**: Dataset doesn't fit in memory. +**Solution**: Use `DiskDataset` instead of `NumpyDataset`. +```python +dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids) +``` + +### Issue 4: Overfitting on Small Datasets +**Problem**: Model memorizes training data. +**Solutions**: +1. Use stronger regularization (increase dropout) +2. Use simpler models (Random Forest, Ridge) +3. 
Apply transfer learning (pretrained models) +4. Collect more data + +### Issue 5: Poor Graph Neural Network Performance +**Problem**: GNN performs worse than fingerprints. +**Solutions**: +1. Check if dataset is large enough (GNNs need >10K samples typically) +2. Increase training epochs +3. Try different GNN architectures (AttentiveFP, DMPNN) +4. Use pretrained models (GROVER) diff --git a/scientific-packages/deepchem/scripts/graph_neural_network.py b/scientific-packages/deepchem/scripts/graph_neural_network.py new file mode 100644 index 0000000..6863b5b --- /dev/null +++ b/scientific-packages/deepchem/scripts/graph_neural_network.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +Graph Neural Network Training Script + +This script demonstrates training Graph Convolutional Networks (GCNs) and other +graph-based models for molecular property prediction. + +Usage: + python graph_neural_network.py --dataset tox21 --model gcn + python graph_neural_network.py --dataset bbbp --model attentivefp + python graph_neural_network.py --data custom.csv --task-type regression +""" + +import argparse +import deepchem as dc +import sys + + +AVAILABLE_MODELS = { + 'gcn': 'Graph Convolutional Network', + 'gat': 'Graph Attention Network', + 'attentivefp': 'Attentive Fingerprint', + 'mpnn': 'Message Passing Neural Network', + 'dmpnn': 'Directed Message Passing Neural Network' +} + +MOLNET_DATASETS = { + 'tox21': ('classification', 12), + 'bbbp': ('classification', 1), + 'bace': ('classification', 1), + 'hiv': ('classification', 1), + 'delaney': ('regression', 1), + 'freesolv': ('regression', 1), + 'lipo': ('regression', 1) +} + + +def create_model(model_type, n_tasks, mode='classification'): + """ + Create a graph neural network model. + + Args: + model_type: Type of model ('gcn', 'gat', 'attentivefp', etc.) + n_tasks: Number of prediction tasks + mode: 'classification' or 'regression' + + Returns: + DeepChem model + """ + if model_type == 'gcn': + return dc.models.GCNModel( + n_tasks=n_tasks, + mode=mode, + batch_size=128, + learning_rate=0.001, + dropout=0.0 + ) + elif model_type == 'gat': + return dc.models.GATModel( + n_tasks=n_tasks, + mode=mode, + batch_size=128, + learning_rate=0.001 + ) + elif model_type == 'attentivefp': + return dc.models.AttentiveFPModel( + n_tasks=n_tasks, + mode=mode, + batch_size=128, + learning_rate=0.001 + ) + elif model_type == 'mpnn': + return dc.models.MPNNModel( + n_tasks=n_tasks, + mode=mode, + batch_size=128, + learning_rate=0.001 + ) + elif model_type == 'dmpnn': + return dc.models.DMPNNModel( + n_tasks=n_tasks, + mode=mode, + batch_size=128, + learning_rate=0.001 + ) + else: + raise ValueError(f"Unknown model type: {model_type}") + + +def train_on_molnet(dataset_name, model_type, n_epochs=50): + """ + Train a graph neural network on a MoleculeNet benchmark dataset. 
+ + Args: + dataset_name: Name of MoleculeNet dataset + model_type: Type of model to train + n_epochs: Number of training epochs + + Returns: + Trained model and test scores + """ + print("=" * 70) + print(f"Training {AVAILABLE_MODELS[model_type]} on {dataset_name.upper()}") + print("=" * 70) + + # Get dataset info + task_type, n_tasks_default = MOLNET_DATASETS[dataset_name] + + # Load dataset with graph featurization + print(f"\nLoading {dataset_name} dataset with GraphConv featurizer...") + load_func = getattr(dc.molnet, f'load_{dataset_name}') + tasks, datasets, transformers = load_func( + featurizer='GraphConv', + splitter='scaffold' + ) + train, valid, test = datasets + + n_tasks = len(tasks) + print(f"\nDataset Information:") + print(f" Task type: {task_type}") + print(f" Number of tasks: {n_tasks}") + print(f" Training samples: {len(train)}") + print(f" Validation samples: {len(valid)}") + print(f" Test samples: {len(test)}") + + # Create model + print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...") + model = create_model(model_type, n_tasks, mode=task_type) + + # Train + print(f"\nTraining for {n_epochs} epochs...") + model.fit(train, nb_epoch=n_epochs) + print("Training complete!") + + # Evaluate + print("\n" + "=" * 70) + print("Model Evaluation") + print("=" * 70) + + if task_type == 'classification': + metrics = [ + dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'), + dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'), + dc.metrics.Metric(dc.metrics.f1_score, name='F1'), + ] + else: + metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + dc.metrics.Metric(dc.metrics.root_mean_squared_error, name='RMSE'), + ] + + results = {} + for dataset_name_eval, dataset in [('Train', train), ('Valid', valid), ('Test', test)]: + print(f"\n{dataset_name_eval} Set:") + scores = model.evaluate(dataset, metrics) + results[dataset_name_eval] = scores + for metric_name, score in scores.items(): + print(f" {metric_name}: {score:.4f}") + + return model, results + + +def train_on_custom_data(data_path, model_type, task_type, target_cols, smiles_col='smiles', n_epochs=50): + """ + Train a graph neural network on custom CSV data. 
+ + Args: + data_path: Path to CSV file + model_type: Type of model to train + task_type: 'classification' or 'regression' + target_cols: List of target column names + smiles_col: Name of SMILES column + n_epochs: Number of training epochs + + Returns: + Trained model and test dataset + """ + print("=" * 70) + print(f"Training {AVAILABLE_MODELS[model_type]} on Custom Data") + print("=" * 70) + + # Load and featurize data + print(f"\nLoading data from {data_path}...") + featurizer = dc.feat.MolGraphConvFeaturizer() + loader = dc.data.CSVLoader( + tasks=target_cols, + feature_field=smiles_col, + featurizer=featurizer + ) + dataset = loader.create_dataset(data_path) + + print(f"Loaded {len(dataset)} molecules") + + # Split data + print("\nSplitting data with scaffold splitter...") + splitter = dc.splits.ScaffoldSplitter() + train, valid, test = splitter.train_valid_test_split( + dataset, + frac_train=0.8, + frac_valid=0.1, + frac_test=0.1 + ) + + print(f" Training: {len(train)}") + print(f" Validation: {len(valid)}") + print(f" Test: {len(test)}") + + # Create model + print(f"\nCreating {AVAILABLE_MODELS[model_type]} model...") + n_tasks = len(target_cols) + model = create_model(model_type, n_tasks, mode=task_type) + + # Train + print(f"\nTraining for {n_epochs} epochs...") + model.fit(train, nb_epoch=n_epochs) + print("Training complete!") + + # Evaluate + print("\n" + "=" * 70) + print("Model Evaluation") + print("=" * 70) + + if task_type == 'classification': + metrics = [ + dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'), + dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'), + ] + else: + metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + ] + + for dataset_name, dataset in [('Train', train), ('Valid', valid), ('Test', test)]: + print(f"\n{dataset_name} Set:") + scores = model.evaluate(dataset, metrics) + for metric_name, score in scores.items(): + print(f" {metric_name}: {score:.4f}") + + return model, test + + +def main(): + parser = argparse.ArgumentParser( + description='Train graph neural networks for molecular property prediction' + ) + parser.add_argument( + '--model', + type=str, + choices=list(AVAILABLE_MODELS.keys()), + default='gcn', + help='Type of graph neural network model' + ) + parser.add_argument( + '--dataset', + type=str, + choices=list(MOLNET_DATASETS.keys()), + default=None, + help='MoleculeNet dataset to use' + ) + parser.add_argument( + '--data', + type=str, + default=None, + help='Path to custom CSV file' + ) + parser.add_argument( + '--task-type', + type=str, + choices=['classification', 'regression'], + default='classification', + help='Type of prediction task (for custom data)' + ) + parser.add_argument( + '--targets', + nargs='+', + default=['target'], + help='Names of target columns (for custom data)' + ) + parser.add_argument( + '--smiles-col', + type=str, + default='smiles', + help='Name of SMILES column' + ) + parser.add_argument( + '--epochs', + type=int, + default=50, + help='Number of training epochs' + ) + + args = parser.parse_args() + + # Validate arguments + if args.dataset is None and args.data is None: + print("Error: Must specify either --dataset (MoleculeNet) or --data (custom CSV)", + file=sys.stderr) + return 1 + + if args.dataset and args.data: + print("Error: Cannot specify both --dataset and --data", + file=sys.stderr) + return 1 + + # Train model + try: + if args.dataset: + model, results = train_on_molnet( + args.dataset, + args.model, + 
n_epochs=args.epochs + ) + else: + model, test_set = train_on_custom_data( + args.data, + args.model, + args.task_type, + args.targets, + smiles_col=args.smiles_col, + n_epochs=args.epochs + ) + + print("\n" + "=" * 70) + print("Training Complete!") + print("=" * 70) + return 0 + + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/deepchem/scripts/predict_solubility.py b/scientific-packages/deepchem/scripts/predict_solubility.py new file mode 100644 index 0000000..a33ba35 --- /dev/null +++ b/scientific-packages/deepchem/scripts/predict_solubility.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +""" +Molecular Solubility Prediction Script + +This script trains a model to predict aqueous solubility from SMILES strings +using the Delaney (ESOL) dataset as an example. Can be adapted for custom datasets. + +Usage: + python predict_solubility.py --data custom_data.csv --smiles-col smiles --target-col solubility + python predict_solubility.py # Uses Delaney dataset by default +""" + +import argparse +import deepchem as dc +import numpy as np +import sys + + +def train_solubility_model(data_path=None, smiles_col='smiles', target_col='measured log solubility in mols per litre'): + """ + Train a solubility prediction model. + + Args: + data_path: Path to CSV file with SMILES and solubility data. If None, uses Delaney dataset. + smiles_col: Name of column containing SMILES strings + target_col: Name of column containing solubility values + + Returns: + Trained model, test dataset, and transformers + """ + print("=" * 60) + print("DeepChem Solubility Prediction") + print("=" * 60) + + # Load data + if data_path is None: + print("\nUsing Delaney (ESOL) benchmark dataset...") + tasks, datasets, transformers = dc.molnet.load_delaney( + featurizer='ECFP', + splitter='scaffold' + ) + train, valid, test = datasets + else: + print(f"\nLoading custom data from {data_path}...") + featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) + loader = dc.data.CSVLoader( + tasks=[target_col], + feature_field=smiles_col, + featurizer=featurizer + ) + dataset = loader.create_dataset(data_path) + + # Split data + print("Splitting data with scaffold splitter...") + splitter = dc.splits.ScaffoldSplitter() + train, valid, test = splitter.train_valid_test_split( + dataset, + frac_train=0.8, + frac_valid=0.1, + frac_test=0.1 + ) + + # Normalize data + print("Normalizing features and targets...") + transformers = [ + dc.trans.NormalizationTransformer( + transform_y=True, + dataset=train + ) + ] + for transformer in transformers: + train = transformer.transform(train) + valid = transformer.transform(valid) + test = transformer.transform(test) + + tasks = [target_col] + + print(f"\nDataset sizes:") + print(f" Training: {len(train)} molecules") + print(f" Validation: {len(valid)} molecules") + print(f" Test: {len(test)} molecules") + + # Create model + print("\nCreating multitask regressor...") + model = dc.models.MultitaskRegressor( + n_tasks=len(tasks), + n_features=2048, # ECFP fingerprint size + layer_sizes=[1000, 500], + dropouts=0.25, + learning_rate=0.001, + batch_size=50 + ) + + # Train model + print("\nTraining model...") + model.fit(train, nb_epoch=50) + print("Training complete!") + + # Evaluate model + print("\n" + "=" * 60) + print("Model Evaluation") + print("=" * 60) + + metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + 
dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + dc.metrics.Metric(dc.metrics.root_mean_squared_error, name='RMSE'), + ] + + for dataset_name, dataset in [('Train', train), ('Valid', valid), ('Test', test)]: + print(f"\n{dataset_name} Set:") + scores = model.evaluate(dataset, metrics) + for metric_name, score in scores.items(): + print(f" {metric_name}: {score:.4f}") + + return model, test, transformers + + +def predict_new_molecules(model, smiles_list, transformers=None): + """ + Predict solubility for new molecules. + + Args: + model: Trained DeepChem model + smiles_list: List of SMILES strings + transformers: List of data transformers to apply + + Returns: + Array of predictions + """ + print("\n" + "=" * 60) + print("Predicting New Molecules") + print("=" * 60) + + # Featurize new molecules + featurizer = dc.feat.CircularFingerprint(radius=2, size=2048) + features = featurizer.featurize(smiles_list) + + # Create dataset + new_dataset = dc.data.NumpyDataset(X=features) + + # Apply transformers (if any) + if transformers: + for transformer in transformers: + new_dataset = transformer.transform(new_dataset) + + # Predict + predictions = model.predict(new_dataset) + + # Display results + print("\nPredictions:") + for smiles, pred in zip(smiles_list, predictions): + print(f" {smiles:30s} -> {pred[0]:.3f} log(mol/L)") + + return predictions + + +def main(): + parser = argparse.ArgumentParser( + description='Train a molecular solubility prediction model' + ) + parser.add_argument( + '--data', + type=str, + default=None, + help='Path to CSV file with molecular data' + ) + parser.add_argument( + '--smiles-col', + type=str, + default='smiles', + help='Name of column containing SMILES strings' + ) + parser.add_argument( + '--target-col', + type=str, + default='solubility', + help='Name of column containing target values' + ) + parser.add_argument( + '--predict', + nargs='+', + default=None, + help='SMILES strings to predict after training' + ) + + args = parser.parse_args() + + # Train model + try: + model, test_set, transformers = train_solubility_model( + data_path=args.data, + smiles_col=args.smiles_col, + target_col=args.target_col + ) + except Exception as e: + print(f"\nError during training: {e}", file=sys.stderr) + return 1 + + # Make predictions on new molecules if provided + if args.predict: + try: + predict_new_molecules(model, args.predict, transformers) + except Exception as e: + print(f"\nError during prediction: {e}", file=sys.stderr) + return 1 + else: + # Example predictions + example_smiles = [ + 'CCO', # Ethanol + 'CC(=O)O', # Acetic acid + 'c1ccccc1', # Benzene + 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', # Caffeine + ] + predict_new_molecules(model, example_smiles, transformers) + + print("\n" + "=" * 60) + print("Complete!") + print("=" * 60) + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/deepchem/scripts/transfer_learning.py b/scientific-packages/deepchem/scripts/transfer_learning.py new file mode 100644 index 0000000..1665334 --- /dev/null +++ b/scientific-packages/deepchem/scripts/transfer_learning.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 +""" +Transfer Learning Script for DeepChem + +Use pretrained models (ChemBERTa, GROVER, MolFormer) for molecular property prediction +with transfer learning. Particularly useful for small datasets. 
+ +Usage: + python transfer_learning.py --model chemberta --data my_data.csv --target activity + python transfer_learning.py --model grover --dataset bbbp +""" + +import argparse +import deepchem as dc +import sys + + +PRETRAINED_MODELS = { + 'chemberta': { + 'name': 'ChemBERTa', + 'description': 'BERT pretrained on 77M molecules from ZINC15', + 'model_id': 'seyonec/ChemBERTa-zinc-base-v1' + }, + 'grover': { + 'name': 'GROVER', + 'description': 'Graph transformer pretrained on 10M molecules', + 'model_id': None # GROVER uses its own loading mechanism + }, + 'molformer': { + 'name': 'MolFormer', + 'description': 'Transformer pretrained on molecular structures', + 'model_id': 'ibm/MoLFormer-XL-both-10pct' + } +} + + +def train_chemberta(train_dataset, valid_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=10): + """ + Fine-tune ChemBERTa on a dataset. + + Args: + train_dataset: Training dataset + valid_dataset: Validation dataset + test_dataset: Test dataset + task_type: 'classification' or 'regression' + n_tasks: Number of prediction tasks + n_epochs: Number of fine-tuning epochs + + Returns: + Trained model and evaluation results + """ + print("=" * 70) + print("Fine-tuning ChemBERTa") + print("=" * 70) + print("\nChemBERTa is a BERT model pretrained on 77M molecules from ZINC15.") + print("It uses SMILES strings as input and has learned rich molecular") + print("representations that transfer well to downstream tasks.") + + print(f"\nLoading pretrained ChemBERTa model...") + model = dc.models.HuggingFaceModel( + model=PRETRAINED_MODELS['chemberta']['model_id'], + task=task_type, + n_tasks=n_tasks, + batch_size=32, + learning_rate=2e-5 # Lower LR for fine-tuning + ) + + print(f"\nFine-tuning for {n_epochs} epochs...") + print("(This may take a while on the first run as the model is downloaded)") + model.fit(train_dataset, nb_epoch=n_epochs) + print("Fine-tuning complete!") + + # Evaluate + print("\n" + "=" * 70) + print("Model Evaluation") + print("=" * 70) + + if task_type == 'classification': + metrics = [ + dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'), + dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'), + ] + else: + metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + ] + + results = {} + for name, dataset in [('Train', train_dataset), ('Valid', valid_dataset), ('Test', test_dataset)]: + print(f"\n{name} Set:") + scores = model.evaluate(dataset, metrics) + results[name] = scores + for metric_name, score in scores.items(): + print(f" {metric_name}: {score:.4f}") + + return model, results + + +def train_grover(train_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=20): + """ + Fine-tune GROVER on a dataset. + + Args: + train_dataset: Training dataset + test_dataset: Test dataset + task_type: 'classification' or 'regression' + n_tasks: Number of prediction tasks + n_epochs: Number of fine-tuning epochs + + Returns: + Trained model and evaluation results + """ + print("=" * 70) + print("Fine-tuning GROVER") + print("=" * 70) + print("\nGROVER is a graph transformer pretrained on 10M molecules using") + print("self-supervised learning. 
It learns both node and graph-level") + print("representations through masked atom/bond prediction tasks.") + + print(f"\nCreating GROVER model...") + model = dc.models.GroverModel( + task=task_type, + n_tasks=n_tasks, + model_dir='./grover_pretrained' + ) + + print(f"\nFine-tuning for {n_epochs} epochs...") + model.fit(train_dataset, nb_epoch=n_epochs) + print("Fine-tuning complete!") + + # Evaluate + print("\n" + "=" * 70) + print("Model Evaluation") + print("=" * 70) + + if task_type == 'classification': + metrics = [ + dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'), + dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'), + ] + else: + metrics = [ + dc.metrics.Metric(dc.metrics.r2_score, name='R²'), + dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'), + ] + + results = {} + for name, dataset in [('Train', train_dataset), ('Test', test_dataset)]: + print(f"\n{name} Set:") + scores = model.evaluate(dataset, metrics) + results[name] = scores + for metric_name, score in scores.items(): + print(f" {metric_name}: {score:.4f}") + + return model, results + + +def load_molnet_dataset(dataset_name, model_type): + """ + Load a MoleculeNet dataset with appropriate featurization. + + Args: + dataset_name: Name of MoleculeNet dataset + model_type: Type of pretrained model being used + + Returns: + tasks, train/valid/test datasets, transformers + """ + # Map of MoleculeNet datasets + molnet_datasets = { + 'tox21': dc.molnet.load_tox21, + 'bbbp': dc.molnet.load_bbbp, + 'bace': dc.molnet.load_bace_classification, + 'hiv': dc.molnet.load_hiv, + 'delaney': dc.molnet.load_delaney, + 'freesolv': dc.molnet.load_freesolv, + 'lipo': dc.molnet.load_lipo + } + + if dataset_name not in molnet_datasets: + raise ValueError(f"Unknown dataset: {dataset_name}") + + # ChemBERTa and MolFormer use raw SMILES + if model_type in ['chemberta', 'molformer']: + featurizer = 'Raw' + # GROVER needs graph features + elif model_type == 'grover': + featurizer = 'GraphConv' + else: + featurizer = 'ECFP' + + print(f"\nLoading {dataset_name} dataset...") + load_func = molnet_datasets[dataset_name] + tasks, datasets, transformers = load_func( + featurizer=featurizer, + splitter='scaffold' + ) + + return tasks, datasets, transformers + + +def load_custom_dataset(data_path, target_cols, smiles_col, model_type): + """ + Load a custom CSV dataset. 
+ + Args: + data_path: Path to CSV file + target_cols: List of target column names + smiles_col: Name of SMILES column + model_type: Type of pretrained model being used + + Returns: + train, valid, test datasets + """ + print(f"\nLoading custom data from {data_path}...") + + # Choose featurizer based on model + if model_type in ['chemberta', 'molformer']: + featurizer = dc.feat.DummyFeaturizer() # Models handle featurization + elif model_type == 'grover': + featurizer = dc.feat.MolGraphConvFeaturizer() + else: + featurizer = dc.feat.CircularFingerprint() + + loader = dc.data.CSVLoader( + tasks=target_cols, + feature_field=smiles_col, + featurizer=featurizer + ) + dataset = loader.create_dataset(data_path) + + print(f"Loaded {len(dataset)} molecules") + + # Split data + print("Splitting data with scaffold splitter...") + splitter = dc.splits.ScaffoldSplitter() + train, valid, test = splitter.train_valid_test_split( + dataset, + frac_train=0.8, + frac_valid=0.1, + frac_test=0.1 + ) + + print(f" Training: {len(train)}") + print(f" Validation: {len(valid)}") + print(f" Test: {len(test)}") + + return train, valid, test + + +def main(): + parser = argparse.ArgumentParser( + description='Transfer learning for molecular property prediction' + ) + parser.add_argument( + '--model', + type=str, + choices=list(PRETRAINED_MODELS.keys()), + required=True, + help='Pretrained model to use' + ) + parser.add_argument( + '--dataset', + type=str, + choices=['tox21', 'bbbp', 'bace', 'hiv', 'delaney', 'freesolv', 'lipo'], + default=None, + help='MoleculeNet dataset to use' + ) + parser.add_argument( + '--data', + type=str, + default=None, + help='Path to custom CSV file' + ) + parser.add_argument( + '--target', + nargs='+', + default=['target'], + help='Target column name(s) for custom data' + ) + parser.add_argument( + '--smiles-col', + type=str, + default='smiles', + help='SMILES column name for custom data' + ) + parser.add_argument( + '--task-type', + type=str, + choices=['classification', 'regression'], + default='classification', + help='Type of prediction task' + ) + parser.add_argument( + '--epochs', + type=int, + default=10, + help='Number of fine-tuning epochs' + ) + + args = parser.parse_args() + + # Validate arguments + if args.dataset is None and args.data is None: + print("Error: Must specify either --dataset or --data", file=sys.stderr) + return 1 + + if args.dataset and args.data: + print("Error: Cannot specify both --dataset and --data", file=sys.stderr) + return 1 + + # Print model info + model_info = PRETRAINED_MODELS[args.model] + print("\n" + "=" * 70) + print(f"Transfer Learning with {model_info['name']}") + print("=" * 70) + print(f"\n{model_info['description']}") + + try: + # Load dataset + if args.dataset: + tasks, datasets, transformers = load_molnet_dataset(args.dataset, args.model) + train, valid, test = datasets + task_type = 'classification' if args.dataset in ['tox21', 'bbbp', 'bace', 'hiv'] else 'regression' + n_tasks = len(tasks) + else: + train, valid, test = load_custom_dataset( + args.data, + args.target, + args.smiles_col, + args.model + ) + task_type = args.task_type + n_tasks = len(args.target) + + # Train model + if args.model == 'chemberta': + model, results = train_chemberta( + train, valid, test, + task_type=task_type, + n_tasks=n_tasks, + n_epochs=args.epochs + ) + elif args.model == 'grover': + model, results = train_grover( + train, test, + task_type=task_type, + n_tasks=n_tasks, + n_epochs=args.epochs + ) + else: + print(f"Error: Model {args.model} not yet 
implemented", file=sys.stderr) + return 1 + + print("\n" + "=" * 70) + print("Transfer Learning Complete!") + print("=" * 70) + print("\nTip: Pretrained models often work best with:") + print(" - Small datasets (< 1000 samples)") + print(" - Lower learning rates (1e-5 to 5e-5)") + print(" - Fewer epochs (5-20)") + print(" - Avoiding overfitting through early stopping") + + return 0 + + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/deeptools/SKILL.md b/scientific-packages/deeptools/SKILL.md new file mode 100644 index 0000000..3adc503 --- /dev/null +++ b/scientific-packages/deeptools/SKILL.md @@ -0,0 +1,537 @@ +--- +name: deeptools +description: Comprehensive toolkit for analyzing next-generation sequencing (NGS) data including ChIP-seq, RNA-seq, ATAC-seq, and related experiments. Use this skill when working with BAM files, bigWig coverage tracks, or when creating heatmaps, profile plots, and quality control visualizations for genomic data. Applicable for tasks involving read coverage analysis, sample correlation, ChIP enrichment assessment, normalization, and publication-quality visualization generation. +--- + +# deepTools: NGS Data Analysis Toolkit + +## Overview + +deepTools is a comprehensive suite of Python command-line tools designed for processing and analyzing high-throughput sequencing data. This skill provides guidance for using deepTools to perform quality control, normalize data, compare samples, and generate publication-quality visualizations for ChIP-seq, RNA-seq, ATAC-seq, MNase-seq, and other NGS experiments. + +**Core capabilities:** +- Convert BAM alignments to normalized coverage tracks (bigWig/bedGraph) +- Quality control assessment (fingerprint, correlation, coverage) +- Sample comparison and correlation analysis +- Heatmap and profile plot generation around genomic features +- Enrichment analysis and peak region visualization + +## When to Use This Skill + +Invoke this skill when users request tasks involving: + +- **File conversion**: "Convert BAM to bigWig", "generate coverage tracks", "normalize ChIP-seq data" +- **Quality control**: "check ChIP quality", "compare replicates", "assess sequencing depth", "QC analysis" +- **Visualization**: "create heatmap around TSS", "plot ChIP signal", "visualize enrichment", "generate profile plot" +- **Sample comparison**: "compare treatment vs control", "correlate samples", "PCA analysis" +- **Analysis workflows**: "analyze ChIP-seq data", "RNA-seq coverage", "ATAC-seq analysis", "complete workflow" +- **Working with specific file types**: BAM files, bigWig files, BED region files in genomics context + +## Quick Start + +For users new to deepTools, start with file validation and common workflows: + +### 1. Validate Input Files + +Before running any analysis, validate BAM, bigWig, and BED files using the validation script: + +```bash +python scripts/validate_files.py --bam sample1.bam sample2.bam --bed regions.bed +``` + +This checks file existence, BAM indices, and format correctness. + +### 2. 
Generate Workflow Template + +For standard analyses, use the workflow generator to create customized scripts: + +```bash +# List available workflows +python scripts/workflow_generator.py --list + +# Generate ChIP-seq QC workflow +python scripts/workflow_generator.py chipseq_qc -o qc_workflow.sh \ + --input-bam Input.bam --chip-bams "ChIP1.bam ChIP2.bam" \ + --genome-size 2913022398 + +# Make executable and run +chmod +x qc_workflow.sh +./qc_workflow.sh +``` + +### 3. Most Common Operations + +See `assets/quick_reference.md` for frequently used commands and parameters. + +## Installation + +Guide users to install deepTools using conda (recommended): + +```bash +# Standard installation +conda install -c conda-forge -c bioconda deeptools + +# For M1 Macs +CONDA_SUBDIR=osx-64 conda create -c conda-forge -c bioconda -n deeptools deeptools +``` + +Or using pip: + +```bash +pip install deeptools +``` + +## Core Workflows + +deepTools workflows typically follow this pattern: **QC → Normalization → Comparison/Visualization** + +### ChIP-seq Quality Control Workflow + +When users request ChIP-seq QC or quality assessment: + +1. **Generate workflow script** using `scripts/workflow_generator.py chipseq_qc` +2. **Key QC steps**: + - Sample correlation (multiBamSummary + plotCorrelation) + - PCA analysis (plotPCA) + - Coverage assessment (plotCoverage) + - Fragment size validation (bamPEFragmentSize) + - ChIP enrichment strength (plotFingerprint) + +**Interpreting results:** +- **Correlation**: Replicates should cluster together with high correlation (>0.9) +- **Fingerprint**: Strong ChIP shows steep rise; flat diagonal indicates poor enrichment +- **Coverage**: Assess if sequencing depth is adequate for analysis + +Full workflow details in `references/workflows.md` → "ChIP-seq Quality Control Workflow" + +### ChIP-seq Complete Analysis Workflow + +For full ChIP-seq analysis from BAM to visualizations: + +1. **Generate coverage tracks** with normalization (bamCoverage) +2. **Create comparison tracks** (bamCompare for log2 ratio) +3. **Compute signal matrices** around features (computeMatrix) +4. **Generate visualizations** (plotHeatmap, plotProfile) +5. **Enrichment analysis** at peaks (plotEnrichment) + +Use `scripts/workflow_generator.py chipseq_analysis` to generate template. + +Complete command sequences in `references/workflows.md` → "ChIP-seq Analysis Workflow" + +### RNA-seq Coverage Workflow + +For strand-specific RNA-seq coverage tracks: + +Use bamCoverage with `--filterRNAstrand` to separate forward and reverse strands. + +**Important:** NEVER use `--extendReads` for RNA-seq (would extend over splice junctions). + +Use normalization: CPM for fixed bins, RPKM for gene-level analysis. + +Template available: `scripts/workflow_generator.py rnaseq_coverage` + +Details in `references/workflows.md` → "RNA-seq Coverage Workflow" + +### ATAC-seq Analysis Workflow + +ATAC-seq requires Tn5 offset correction: + +1. **Shift reads** using alignmentSieve with `--ATACshift` +2. **Generate coverage** with bamCoverage +3. **Analyze fragment sizes** (expect nucleosome ladder pattern) +4. 
**Visualize at peaks** if available + +Template: `scripts/workflow_generator.py atacseq` + +Full workflow in `references/workflows.md` → "ATAC-seq Workflow" + +## Tool Categories and Common Tasks + +### BAM/bigWig Processing + +**Convert BAM to normalized coverage:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPGC --effectiveGenomeSize 2913022398 \ + --binSize 10 --numberOfProcessors 8 +``` + +**Compare two samples (log2 ratio):** +```bash +bamCompare -b1 treatment.bam -b2 control.bam -o ratio.bw \ + --operation log2 --scaleFactorsMethod readCount +``` + +**Key tools:** bamCoverage, bamCompare, multiBamSummary, multiBigwigSummary, correctGCBias, alignmentSieve + +Complete reference: `references/tools_reference.md` → "BAM and bigWig File Processing Tools" + +### Quality Control + +**Check ChIP enrichment:** +```bash +plotFingerprint -b input.bam chip.bam -o fingerprint.png \ + --extendReads 200 --ignoreDuplicates +``` + +**Sample correlation:** +```bash +multiBamSummary bins --bamfiles *.bam -o counts.npz +plotCorrelation -in counts.npz --corMethod pearson \ + --whatToShow heatmap -o correlation.png +``` + +**Key tools:** plotFingerprint, plotCoverage, plotCorrelation, plotPCA, bamPEFragmentSize + +Complete reference: `references/tools_reference.md` → "Quality Control Tools" + +### Visualization + +**Create heatmap around TSS:** +```bash +# Compute matrix +computeMatrix reference-point -S signal.bw -R genes.bed \ + -b 3000 -a 3000 --referencePoint TSS -o matrix.gz + +# Generate heatmap +plotHeatmap -m matrix.gz -o heatmap.png \ + --colorMap RdBu --kmeans 3 +``` + +**Create profile plot:** +```bash +plotProfile -m matrix.gz -o profile.png \ + --plotType lines --colors blue red +``` + +**Key tools:** computeMatrix, plotHeatmap, plotProfile, plotEnrichment + +Complete reference: `references/tools_reference.md` → "Visualization Tools" + +## Normalization Methods + +Choosing the correct normalization is critical for valid comparisons. Consult `references/normalization_methods.md` for comprehensive guidance. + +**Quick selection guide:** + +- **ChIP-seq coverage**: Use RPGC or CPM +- **ChIP-seq comparison**: Use bamCompare with log2 and readCount +- **RNA-seq bins**: Use CPM +- **RNA-seq genes**: Use RPKM (accounts for gene length) +- **ATAC-seq**: Use RPGC or CPM + +**Normalization methods:** +- **RPGC**: 1× genome coverage (requires --effectiveGenomeSize) +- **CPM**: Counts per million mapped reads +- **RPKM**: Reads per kb per million (accounts for region length) +- **BPM**: Bins per million +- **None**: Raw counts (not recommended for comparisons) + +Full explanation: `references/normalization_methods.md` + +## Effective Genome Sizes + +RPGC normalization requires effective genome size. Common values: + +| Organism | Assembly | Size | Usage | +|----------|----------|------|-------| +| Human | GRCh38/hg38 | 2,913,022,398 | `--effectiveGenomeSize 2913022398` | +| Mouse | GRCm38/mm10 | 2,652,783,500 | `--effectiveGenomeSize 2652783500` | +| Zebrafish | GRCz11 | 1,368,780,147 | `--effectiveGenomeSize 1368780147` | +| *Drosophila* | dm6 | 142,573,017 | `--effectiveGenomeSize 142573017` | +| *C. 
elegans* | ce10/ce11 | 100,286,401 | `--effectiveGenomeSize 100286401` | + +Complete table with read-length-specific values: `references/effective_genome_sizes.md` + +## Common Parameters Across Tools + +Many deepTools commands share these options: + +**Performance:** +- `--numberOfProcessors, -p`: Enable parallel processing (always use available cores) +- `--region`: Process specific regions for testing (e.g., `chr1:1-1000000`) + +**Read Filtering:** +- `--ignoreDuplicates`: Remove PCR duplicates (recommended for most analyses) +- `--minMappingQuality`: Filter by alignment quality (e.g., `--minMappingQuality 10`) +- `--minFragmentLength` / `--maxFragmentLength`: Fragment length bounds +- `--samFlagInclude` / `--samFlagExclude`: SAM flag filtering + +**Read Processing:** +- `--extendReads`: Extend to fragment length (ChIP-seq: YES, RNA-seq: NO) +- `--centerReads`: Center at fragment midpoint for sharper signals + +## Best Practices + +### File Validation +**Always validate files first** using `scripts/validate_files.py` to check: +- File existence and readability +- BAM indices present (.bai files) +- BED format correctness +- File sizes reasonable + +### Analysis Strategy + +1. **Start with QC**: Run correlation, coverage, and fingerprint analysis before proceeding +2. **Test on small regions**: Use `--region chr1:1-10000000` for parameter testing +3. **Document commands**: Save full command lines for reproducibility +4. **Use consistent normalization**: Apply same method across samples in comparisons +5. **Verify genome assembly**: Ensure BAM and BED files use matching genome builds + +### ChIP-seq Specific + +- **Always extend reads** for ChIP-seq: `--extendReads 200` +- **Remove duplicates**: Use `--ignoreDuplicates` in most cases +- **Check enrichment first**: Run plotFingerprint before detailed analysis +- **GC correction**: Only apply if significant bias detected; never use `--ignoreDuplicates` after GC correction + +### RNA-seq Specific + +- **Never extend reads** for RNA-seq (would span splice junctions) +- **Strand-specific**: Use `--filterRNAstrand forward/reverse` for stranded libraries +- **Normalization**: CPM for bins, RPKM for genes + +### ATAC-seq Specific + +- **Apply Tn5 correction**: Use alignmentSieve with `--ATACshift` +- **Fragment filtering**: Set appropriate min/max fragment lengths +- **Check nucleosome pattern**: Fragment size plot should show ladder pattern + +### Performance Optimization + +1. **Use multiple processors**: `--numberOfProcessors 8` (or available cores) +2. **Increase bin size** for faster processing and smaller files +3. **Process chromosomes separately** for memory-limited systems +4. **Pre-filter BAM files** using alignmentSieve to create reusable filtered files +5. **Use bigWig over bedGraph**: Compressed and faster to process + +## Troubleshooting + +### Common Issues + +**BAM index missing:** +```bash +samtools index input.bam +``` + +**Out of memory:** +Process chromosomes individually using `--region`: +```bash +bamCoverage --bam input.bam -o chr1.bw --region chr1 +``` + +**Slow processing:** +Increase `--numberOfProcessors` and/or increase `--binSize` + +**bigWig files too large:** +Increase bin size: `--binSize 50` or larger + +### Validation Errors + +Run validation script to identify issues: +```bash +python scripts/validate_files.py --bam *.bam --bed regions.bed +``` + +Common errors and solutions explained in script output. 
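+
+The index and validation checks above can be combined into a single pre-flight step. A minimal sketch, assuming `samtools` is on the PATH, the BAM files sit in the working directory, and `regions.bed` stands in for whatever region file the analysis uses:
+
+```bash
+# Index any BAM that is missing a .bai, then run the bundled validator
+for bam in *.bam; do
+    if [ ! -f "${bam}.bai" ] && [ ! -f "${bam%.bam}.bai" ]; then
+        echo "Indexing ${bam}..."
+        samtools index "${bam}"
+    fi
+done
+
+python scripts/validate_files.py --bam *.bam --bed regions.bed
+```
+
+Running this before any coverage or plotting command catches the most common setup errors early.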
+ +## Reference Documentation + +This skill includes comprehensive reference documentation: + +### references/tools_reference.md +Complete documentation of all deepTools commands organized by category: +- BAM and bigWig processing tools (9 tools) +- Quality control tools (6 tools) +- Visualization tools (3 tools) +- Miscellaneous tools (2 tools) + +Each tool includes: +- Purpose and overview +- Key parameters with explanations +- Usage examples +- Important notes and best practices + +**Use this reference when:** Users ask about specific tools, parameters, or detailed usage. + +### references/workflows.md +Complete workflow examples for common analyses: +- ChIP-seq quality control workflow +- ChIP-seq complete analysis workflow +- RNA-seq coverage workflow +- ATAC-seq analysis workflow +- Multi-sample comparison workflow +- Peak region analysis workflow +- Troubleshooting and performance tips + +**Use this reference when:** Users need complete analysis pipelines or workflow examples. + +### references/normalization_methods.md +Comprehensive guide to normalization methods: +- Detailed explanation of each method (RPGC, CPM, RPKM, BPM, etc.) +- When to use each method +- Formulas and interpretation +- Selection guide by experiment type +- Common pitfalls and solutions +- Quick reference table + +**Use this reference when:** Users ask about normalization, comparing samples, or which method to use. + +### references/effective_genome_sizes.md +Effective genome size values and usage: +- Common organism values (human, mouse, fly, worm, zebrafish) +- Read-length-specific values +- Calculation methods +- When and how to use in commands +- Custom genome calculation instructions + +**Use this reference when:** Users need genome size for RPGC normalization or GC bias correction. + +## Helper Scripts + +### scripts/validate_files.py + +Validates BAM, bigWig, and BED files for deepTools analysis. Checks file existence, indices, and format. + +**Usage:** +```bash +python scripts/validate_files.py --bam sample1.bam sample2.bam \ + --bed peaks.bed --bigwig signal.bw +``` + +**When to use:** Before starting any analysis, or when troubleshooting errors. + +### scripts/workflow_generator.py + +Generates customizable bash script templates for common deepTools workflows. + +**Available workflows:** +- `chipseq_qc`: ChIP-seq quality control +- `chipseq_analysis`: Complete ChIP-seq analysis +- `rnaseq_coverage`: Strand-specific RNA-seq coverage +- `atacseq`: ATAC-seq with Tn5 correction + +**Usage:** +```bash +# List workflows +python scripts/workflow_generator.py --list + +# Generate workflow +python scripts/workflow_generator.py chipseq_qc -o qc.sh \ + --input-bam Input.bam --chip-bams "ChIP1.bam ChIP2.bam" \ + --genome-size 2913022398 --threads 8 + +# Run generated workflow +chmod +x qc.sh +./qc.sh +``` + +**When to use:** Users request standard workflows or need template scripts to customize. + +## Assets + +### assets/quick_reference.md + +Quick reference card with most common commands, effective genome sizes, and typical workflow pattern. + +**When to use:** Users need quick command examples without detailed documentation. + +## Handling User Requests + +### For New Users + +1. Start with installation verification +2. Validate input files using `scripts/validate_files.py` +3. Recommend appropriate workflow based on experiment type +4. Generate workflow template using `scripts/workflow_generator.py` +5. Guide through customization and execution + +### For Experienced Users + +1. 
Provide specific tool commands for requested operations +2. Reference appropriate sections in `references/tools_reference.md` +3. Suggest optimizations and best practices +4. Offer troubleshooting for issues + +### For Specific Tasks + +**"Convert BAM to bigWig":** +- Use bamCoverage with appropriate normalization +- Recommend RPGC or CPM based on use case +- Provide effective genome size for organism +- Suggest relevant parameters (extendReads, ignoreDuplicates, binSize) + +**"Check ChIP quality":** +- Run full QC workflow or use plotFingerprint specifically +- Explain interpretation of results +- Suggest follow-up actions based on results + +**"Create heatmap":** +- Guide through two-step process: computeMatrix → plotHeatmap +- Help choose appropriate matrix mode (reference-point vs scale-regions) +- Suggest visualization parameters and clustering options + +**"Compare samples":** +- Recommend bamCompare for two-sample comparison +- Suggest multiBamSummary + plotCorrelation for multiple samples +- Guide normalization method selection + +### Referencing Documentation + +When users need detailed information: +- **Tool details**: Direct to specific sections in `references/tools_reference.md` +- **Workflows**: Use `references/workflows.md` for complete analysis pipelines +- **Normalization**: Consult `references/normalization_methods.md` for method selection +- **Genome sizes**: Reference `references/effective_genome_sizes.md` + +Search references using grep patterns: +```bash +# Find tool documentation +grep -A 20 "^### toolname" references/tools_reference.md + +# Find workflow +grep -A 50 "^## Workflow Name" references/workflows.md + +# Find normalization method +grep -A 15 "^### Method Name" references/normalization_methods.md +``` + +## Example Interactions + +**User: "I need to analyze my ChIP-seq data"** + +Response approach: +1. Ask about files available (BAM files, peaks, genes) +2. Validate files using validation script +3. Generate chipseq_analysis workflow template +4. Customize for their specific files and organism +5. Explain each step as script runs + +**User: "Which normalization should I use?"** + +Response approach: +1. Ask about experiment type (ChIP-seq, RNA-seq, etc.) +2. Ask about comparison goal (within-sample or between-sample) +3. Consult `references/normalization_methods.md` selection guide +4. Recommend appropriate method with justification +5. Provide command example with parameters + +**User: "Create a heatmap around TSS"** + +Response approach: +1. Verify bigWig and gene BED files available +2. Use computeMatrix with reference-point mode at TSS +3. Generate plotHeatmap with appropriate visualization parameters +4. Suggest clustering if dataset is large +5. 
Offer profile plot as complement + +## Key Reminders + +- **File validation first**: Always validate input files before analysis +- **Normalization matters**: Choose appropriate method for comparison type +- **Extend reads carefully**: YES for ChIP-seq, NO for RNA-seq +- **Use all cores**: Set `--numberOfProcessors` to available cores +- **Test on regions**: Use `--region` for parameter testing +- **Check QC first**: Run quality control before detailed analysis +- **Document everything**: Save commands for reproducibility +- **Reference documentation**: Use comprehensive references for detailed guidance diff --git a/scientific-packages/deeptools/assets/quick_reference.md b/scientific-packages/deeptools/assets/quick_reference.md new file mode 100644 index 0000000..31ce31d --- /dev/null +++ b/scientific-packages/deeptools/assets/quick_reference.md @@ -0,0 +1,58 @@ +# deepTools Quick Reference + +## Most Common Commands + +### BAM to bigWig (normalized) +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPGC --effectiveGenomeSize 2913022398 \ + --binSize 10 --numberOfProcessors 8 +``` + +### Compare two BAM files +```bash +bamCompare -b1 treatment.bam -b2 control.bam -o ratio.bw \ + --operation log2 --scaleFactorsMethod readCount +``` + +### Correlation heatmap +```bash +multiBamSummary bins --bamfiles *.bam -o counts.npz +plotCorrelation -in counts.npz --corMethod pearson \ + --whatToShow heatmap -o correlation.png +``` + +### Heatmap around TSS +```bash +computeMatrix reference-point -S signal.bw -R genes.bed \ + -b 3000 -a 3000 --referencePoint TSS -o matrix.gz + +plotHeatmap -m matrix.gz -o heatmap.png +``` + +### ChIP enrichment check +```bash +plotFingerprint -b input.bam chip.bam -o fingerprint.png \ + --extendReads 200 --ignoreDuplicates +``` + +## Effective Genome Sizes + +| Organism | Assembly | Size | +|----------|----------|------| +| Human | hg38 | 2913022398 | +| Mouse | mm10 | 2652783500 | +| Fly | dm6 | 142573017 | + +## Common Normalization Methods + +- **RPGC**: 1× genome coverage (requires --effectiveGenomeSize) +- **CPM**: Counts per million (for fixed bins) +- **RPKM**: Reads per kb per million (for genes) + +## Typical Workflow + +1. **QC**: plotFingerprint, plotCorrelation +2. **Coverage**: bamCoverage with normalization +3. **Comparison**: bamCompare for treatment vs control +4. **Visualization**: computeMatrix → plotHeatmap/plotProfile diff --git a/scientific-packages/deeptools/references/effective_genome_sizes.md b/scientific-packages/deeptools/references/effective_genome_sizes.md new file mode 100644 index 0000000..b4c7031 --- /dev/null +++ b/scientific-packages/deeptools/references/effective_genome_sizes.md @@ -0,0 +1,116 @@ +# Effective Genome Sizes + +## Definition + +Effective genome size refers to the length of the "mappable" genome - regions that can be uniquely mapped by sequencing reads. This metric is crucial for proper normalization in many deepTools commands. + +## Why It Matters + +- Required for RPGC normalization (`--normalizeUsing RPGC`) +- Affects accuracy of coverage calculations +- Must match your data processing approach (filtered vs unfiltered reads) + +## Calculation Methods + +1. **Non-N bases**: Count of non-N nucleotides in genome sequence +2. 
**Unique mappability**: Regions of specific size that can be uniquely mapped (may consider edit distance) + +## Common Organism Values + +### Using Non-N Bases Method + +| Organism | Assembly | Effective Size | Full Command | +|----------|----------|----------------|--------------| +| Human | GRCh38/hg38 | 2,913,022,398 | `--effectiveGenomeSize 2913022398` | +| Human | GRCh37/hg19 | 2,864,785,220 | `--effectiveGenomeSize 2864785220` | +| Mouse | GRCm39/mm39 | 2,654,621,837 | `--effectiveGenomeSize 2654621837` | +| Mouse | GRCm38/mm10 | 2,652,783,500 | `--effectiveGenomeSize 2652783500` | +| Zebrafish | GRCz11 | 1,368,780,147 | `--effectiveGenomeSize 1368780147` | +| *Drosophila* | dm6 | 142,573,017 | `--effectiveGenomeSize 142573017` | +| *C. elegans* | WBcel235/ce11 | 100,286,401 | `--effectiveGenomeSize 100286401` | +| *C. elegans* | ce10 | 100,258,171 | `--effectiveGenomeSize 100258171` | + +### Human (GRCh38) by Read Length + +For quality-filtered reads, values vary by read length: + +| Read Length | Effective Size | +|-------------|----------------| +| 50bp | ~2.7 billion | +| 75bp | ~2.8 billion | +| 100bp | ~2.8 billion | +| 150bp | ~2.9 billion | +| 250bp | ~2.9 billion | + +### Mouse (GRCm38) by Read Length + +| Read Length | Effective Size | +|-------------|----------------| +| 50bp | ~2.3 billion | +| 75bp | ~2.5 billion | +| 100bp | ~2.6 billion | + +## Usage in deepTools + +The effective genome size is most commonly used with: + +### bamCoverage with RPGC normalization +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 +``` + +### bamCompare with RPGC normalization +```bash +bamCompare -b1 treatment.bam -b2 control.bam \ + --outFileName comparison.bw \ + --scaleFactorsMethod RPGC \ + --effectiveGenomeSize 2913022398 +``` + +### computeGCBias / correctGCBias +```bash +computeGCBias --bamfile input.bam \ + --effectiveGenomeSize 2913022398 \ + --genome genome.2bit \ + --fragmentLength 200 \ + --biasPlot bias.png +``` + +## Choosing the Right Value + +**For most analyses:** Use the non-N bases method value for your reference genome + +**For filtered data:** If you apply strict quality filters or remove multimapping reads, consider using the read-length-specific values + +**When unsure:** Use the conservative non-N bases value - it's more widely applicable + +## Common Shortcuts + +deepTools also accepts these shorthand values in some contexts: + +- `hs` or `GRCh38`: 2913022398 +- `mm` or `GRCm38`: 2652783500 +- `dm` or `dm6`: 142573017 +- `ce` or `ce10`: 100286401 + +Check your specific deepTools version documentation for supported shortcuts. 
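+
+In pipeline scripts it can help to keep these values in a single lookup rather than hard-coding them in each command. A minimal sketch, assuming bash 4+ (associative arrays) and placeholder file names (`sample.bam`, `sample.bw`); the sizes are the non-N-bases values from the table above:
+
+```bash
+# Map assembly names to effective genome sizes (non-N bases method)
+declare -A GENOME_SIZE=(
+    [hg38]=2913022398
+    [hg19]=2864785220
+    [mm10]=2652783500
+    [dm6]=142573017
+    [ce11]=100286401
+)
+
+ASSEMBLY=hg38
+bamCoverage --bam sample.bam --outFileName sample.bw \
+    --normalizeUsing RPGC \
+    --effectiveGenomeSize "${GENOME_SIZE[$ASSEMBLY]}" \
+    --numberOfProcessors 8
+```
+
+Changing `ASSEMBLY` then propagates the correct size to every RPGC-normalized command in the script.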
+ +## Calculating Custom Values + +For custom genomes or assemblies, calculate the non-N bases count: + +```bash +# Using faCount (UCSC tools) +faCount genome.fa | grep "total" | awk '{print $2-$7}' + +# Using seqtk +seqtk comp genome.fa | awk '{x+=$2}END{print x}' +``` + +## References + +For the most up-to-date effective genome sizes and detailed calculation methods, see: +- deepTools documentation: https://deeptools.readthedocs.io/en/latest/content/feature/effectiveGenomeSize.html +- ENCODE documentation for reference genome details diff --git a/scientific-packages/deeptools/references/normalization_methods.md b/scientific-packages/deeptools/references/normalization_methods.md new file mode 100644 index 0000000..dd84096 --- /dev/null +++ b/scientific-packages/deeptools/references/normalization_methods.md @@ -0,0 +1,410 @@ +# deepTools Normalization Methods + +This document explains the various normalization methods available in deepTools and when to use each one. + +## Why Normalize? + +Normalization is essential for: +1. **Comparing samples with different sequencing depths** +2. **Accounting for library size differences** +3. **Making coverage values interpretable across experiments** +4. **Enabling fair comparisons between conditions** + +Without normalization, a sample with 100 million reads will appear to have higher coverage than a sample with 50 million reads, even if the true biological signal is identical. + +--- + +## Available Normalization Methods + +### 1. RPKM (Reads Per Kilobase per Million mapped reads) + +**Formula:** `(Number of reads) / (Length of region in kb × Total mapped reads in millions)` + +**When to use:** +- Comparing different genomic regions within the same sample +- Adjusting for both sequencing depth AND region length +- RNA-seq gene expression analysis + +**Available in:** `bamCoverage` + +**Example:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPKM +``` + +**Interpretation:** RPKM of 10 means 10 reads per kilobase of feature per million mapped reads. + +**Pros:** +- Accounts for both region length and library size +- Widely used and understood in genomics + +**Cons:** +- Not ideal for comparing between samples if total RNA content differs +- Can be misleading when comparing samples with very different compositions + +--- + +### 2. CPM (Counts Per Million mapped reads) + +**Formula:** `(Number of reads) / (Total mapped reads in millions)` + +**Also known as:** RPM (Reads Per Million) + +**When to use:** +- Comparing the same genomic regions across different samples +- When region length is constant or not relevant +- ChIP-seq, ATAC-seq, DNase-seq analyses + +**Available in:** `bamCoverage`, `bamCompare` + +**Example:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing CPM +``` + +**Interpretation:** CPM of 5 means 5 reads per million mapped reads in that bin. + +**Pros:** +- Simple and intuitive +- Good for comparing samples with different sequencing depths +- Appropriate when comparing fixed-size bins + +**Cons:** +- Does not account for region length +- Affected by highly abundant regions (e.g., rRNA in RNA-seq) + +--- + +### 3. BPM (Bins Per Million mapped reads) + +**Formula:** `(Number of reads in bin) / (Sum of all reads in bins in millions)` + +**Key difference from CPM:** Only considers reads that fall within the analyzed bins, not all mapped reads. 
+ +**When to use:** +- Similar to CPM, but when you want to exclude reads outside analyzed regions +- Comparing specific genomic regions while ignoring background + +**Available in:** `bamCoverage`, `bamCompare` + +**Example:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing BPM +``` + +**Interpretation:** BPM accounts only for reads in the binned regions. + +**Pros:** +- Focuses normalization on analyzed regions +- Less affected by reads in unanalyzed areas + +**Cons:** +- Less commonly used, may be harder to compare with published data + +--- + +### 4. RPGC (Reads Per Genomic Content) + +**Formula:** `(Number of reads × Scaling factor) / Effective genome size` + +**Scaling factor:** Calculated to achieve 1× genomic coverage (1 read per base) + +**When to use:** +- Want comparable coverage values across samples +- Need interpretable absolute coverage values +- Comparing samples with very different total read counts +- ChIP-seq with spike-in normalization context + +**Available in:** `bamCoverage`, `bamCompare` + +**Requires:** `--effectiveGenomeSize` parameter + +**Example:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 +``` + +**Interpretation:** Signal value approximates the coverage depth (e.g., value of 2 ≈ 2× coverage). + +**Pros:** +- Produces 1× normalized coverage +- Interpretable in terms of genomic coverage +- Good for comparing samples with different sequencing depths + +**Cons:** +- Requires knowing effective genome size +- Assumes uniform coverage (not true for ChIP-seq with peaks) + +--- + +### 5. None (No Normalization) + +**Formula:** Raw read counts + +**When to use:** +- Preliminary analysis +- When samples have identical library sizes (rare) +- When downstream tool will perform normalization +- Debugging or quality control + +**Available in:** All tools (usually default) + +**Example:** +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing None +``` + +**Interpretation:** Raw read counts per bin. + +**Pros:** +- No assumptions made +- Useful for seeing raw data +- Fastest computation + +**Cons:** +- Cannot fairly compare samples with different sequencing depths +- Not suitable for publication figures + +--- + +### 6. SES (Selective Enrichment Statistics) + +**Method:** Signal Extraction Scaling - more sophisticated method for comparing ChIP to control + +**When to use:** +- ChIP-seq analysis with bamCompare +- Want sophisticated background correction +- Alternative to simple readCount scaling + +**Available in:** `bamCompare` only + +**Example:** +```bash +bamCompare -b1 chip.bam -b2 input.bam -o output.bw \ + --scaleFactorsMethod SES +``` + +**Note:** SES is specifically designed for ChIP-seq data and may work better than simple read count scaling for noisy data. + +--- + +### 7. readCount (Read Count Scaling) + +**Method:** Scale by ratio of total read counts between samples + +**When to use:** +- Default for `bamCompare` +- Compensating for sequencing depth differences in comparisons +- When you trust that total read counts reflect library size + +**Available in:** `bamCompare` + +**Example:** +```bash +bamCompare -b1 treatment.bam -b2 control.bam -o output.bw \ + --scaleFactorsMethod readCount +``` + +**How it works:** If sample1 has 100M reads and sample2 has 50M reads, sample2 is scaled by 2× before comparison. 
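+
+To make the scaling concrete, here is a toy calculation (not deepTools code) using the 100M/50M read totals above and made-up bin counts of 30 and 12 reads:
+
+```bash
+awk -v r1=100000000 -v r2=50000000 -v c1=30 -v c2=12 'BEGIN {
+    sf = r1 / r2                          # scale factor applied to sample2 (= 2)
+    ratio = log(c1 / (c2 * sf)) / log(2)  # log2(30 / 24), roughly 0.32
+    printf "scale factor: %.2f  log2 ratio: %.3f\n", sf, ratio
+}'
+```
+
+bamCompare performs this bin-by-bin across the genome (with safeguards for empty bins), but the per-bin arithmetic follows the same idea.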
+ +--- + +## Normalization Method Selection Guide + +### For ChIP-seq Coverage Tracks + +**Recommended:** RPGC or CPM + +```bash +bamCoverage --bam chip.bam --outFileName chip.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --extendReads 200 \ + --ignoreDuplicates +``` + +**Reasoning:** Accounts for sequencing depth differences; RPGC provides interpretable coverage values. + +--- + +### For ChIP-seq Comparisons (Treatment vs Control) + +**Recommended:** log2 ratio with readCount or SES scaling + +```bash +bamCompare -b1 chip.bam -b2 input.bam -o ratio.bw \ + --operation log2 \ + --scaleFactorsMethod readCount \ + --extendReads 200 \ + --ignoreDuplicates +``` + +**Reasoning:** Log2 ratio shows enrichment (positive) and depletion (negative); readCount adjusts for depth. + +--- + +### For RNA-seq Coverage Tracks + +**Recommended:** CPM or RPKM + +```bash +# Strand-specific forward +bamCoverage --bam rnaseq.bam --outFileName forward.bw \ + --normalizeUsing CPM \ + --filterRNAstrand forward + +# For gene-level: RPKM accounts for gene length +bamCoverage --bam rnaseq.bam --outFileName output.bw \ + --normalizeUsing RPKM +``` + +**Reasoning:** CPM for comparing fixed-width bins; RPKM for genes (accounts for length). + +--- + +### For ATAC-seq + +**Recommended:** RPGC or CPM + +```bash +bamCoverage --bam atac_shifted.bam --outFileName atac.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 +``` + +**Reasoning:** Similar to ChIP-seq; want comparable coverage across samples. + +--- + +### For Sample Correlation Analysis + +**Recommended:** CPM or RPGC + +```bash +multiBamSummary bins \ + --bamfiles sample1.bam sample2.bam sample3.bam \ + -o readCounts.npz + +plotCorrelation -in readCounts.npz \ + --corMethod pearson \ + --whatToShow heatmap \ + -o correlation.png +``` + +**Note:** `multiBamSummary` doesn't explicitly normalize, but correlation analysis is robust to scaling. For very different library sizes, consider normalizing BAM files first or using CPM-normalized bigWig files with `multiBigwigSummary`. + +--- + +## Advanced Normalization Considerations + +### Spike-in Normalization + +For experiments with spike-in controls (e.g., *Drosophila* chromatin spike-in for ChIP-seq): + +1. Calculate scaling factors from spike-in reads +2. Apply custom scaling factors using `--scaleFactor` parameter + +```bash +# Calculate spike-in factor (example: 0.8) +SCALE_FACTOR=0.8 + +bamCoverage --bam chip.bam --outFileName chip_spikenorm.bw \ + --scaleFactor ${SCALE_FACTOR} \ + --extendReads 200 +``` + +--- + +### Manual Scaling Factors + +You can apply custom scaling factors: + +```bash +# Apply 2× scaling +bamCoverage --bam input.bam --outFileName output.bw \ + --scaleFactor 2.0 +``` + +--- + +### Chromosome Exclusion + +Exclude specific chromosomes from normalization calculations: + +```bash +bamCoverage --bam input.bam --outFileName output.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --ignoreForNormalization chrX chrY chrM +``` + +**When to use:** Sex chromosomes in mixed-sex samples, mitochondrial DNA, or chromosomes with unusual coverage. + +--- + +## Common Pitfalls + +### 1. Using RPKM for bin-based data +**Problem:** RPKM accounts for region length, but all bins are the same size +**Solution:** Use CPM or RPGC instead + +### 2. Comparing unnormalized samples +**Problem:** Sample with 2× sequencing depth appears to have 2× signal +**Solution:** Always normalize when comparing samples + +### 3. 
Wrong effective genome size +**Problem:** Using hg19 genome size for hg38 data +**Solution:** Double-check genome assembly and use correct size + +### 4. Ignoring duplicates after GC correction +**Problem:** Can introduce bias +**Solution:** Never use `--ignoreDuplicates` after `correctGCBias` + +### 5. Using RPGC without effective genome size +**Problem:** Command fails +**Solution:** Always specify `--effectiveGenomeSize` with RPGC + +--- + +## Normalization for Different Comparisons + +### Within-sample comparisons (different regions) +**Use:** RPKM (accounts for region length) + +### Between-sample comparisons (same regions) +**Use:** CPM, RPGC, or BPM (accounts for library size) + +### Treatment vs Control +**Use:** bamCompare with log2 ratio and readCount/SES scaling + +### Multiple samples correlation +**Use:** CPM or RPGC normalized bigWig files, then multiBigwigSummary + +--- + +## Quick Reference Table + +| Method | Accounts for Depth | Accounts for Length | Best For | Command | +|--------|-------------------|---------------------|----------|---------| +| RPKM | ✓ | ✓ | RNA-seq genes | `--normalizeUsing RPKM` | +| CPM | ✓ | ✗ | Fixed-size bins | `--normalizeUsing CPM` | +| BPM | ✓ | ✗ | Specific regions | `--normalizeUsing BPM` | +| RPGC | ✓ | ✗ | Interpretable coverage | `--normalizeUsing RPGC --effectiveGenomeSize X` | +| None | ✗ | ✗ | Raw data | `--normalizeUsing None` | +| SES | ✓ | ✗ | ChIP comparisons | `bamCompare --scaleFactorsMethod SES` | +| readCount | ✓ | ✗ | ChIP comparisons | `bamCompare --scaleFactorsMethod readCount` | + +--- + +## Further Reading + +For more details on normalization theory and best practices: +- deepTools documentation: https://deeptools.readthedocs.io/ +- ENCODE guidelines for ChIP-seq analysis +- RNA-seq normalization papers (DESeq2, TMM methods) diff --git a/scientific-packages/deeptools/references/tools_reference.md b/scientific-packages/deeptools/references/tools_reference.md new file mode 100644 index 0000000..6b39614 --- /dev/null +++ b/scientific-packages/deeptools/references/tools_reference.md @@ -0,0 +1,533 @@ +# deepTools Complete Tool Reference + +This document provides a comprehensive reference for all deepTools command-line utilities organized by category. + +## BAM and bigWig File Processing Tools + +### multiBamSummary + +Computes read coverages for genomic regions across multiple BAM files, outputting compressed numpy arrays for downstream correlation and PCA analysis. 
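+
+The `.npz` output is a compressed NumPy archive that is normally passed straight to `plotCorrelation`/`plotPCA`, but it can also be opened directly to inspect the per-bin counts. A minimal sketch (the stored array names, typically `matrix` and `labels`, and the row/column layout may vary between deepTools versions):
+
+```python
+import numpy as np
+
+# Load the archive produced by multiBamSummary
+archive = np.load("results.npz", allow_pickle=True)
+print(archive.files)                    # list the stored arrays
+
+counts = archive["matrix"]              # assumed layout: rows = bins/regions, columns = samples
+labels = archive["labels"]              # sample names in column order
+print(labels, counts.shape)
+
+# Quick pairwise sample correlation, similar in spirit to plotCorrelation
+print(np.corrcoef(counts, rowvar=False))
+```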
+ +**Modes:** +- **bins**: Genome-wide analysis using consecutive equal-sized windows (default 10kb) +- **BED-file**: Restricts analysis to user-specified genomic regions + +**Key Parameters:** +- `--bamfiles, -b`: Indexed BAM files (space-separated, required) +- `--outFileName, -o`: Output coverage matrix file (required) +- `--BED`: Region specification file (BED-file mode only) +- `--binSize`: Window size in bases (default: 10,000) +- `--labels`: Custom sample identifiers +- `--minMappingQuality`: Quality threshold for read inclusion +- `--numberOfProcessors, -p`: Parallel processing cores +- `--extendReads`: Fragment size extension +- `--ignoreDuplicates`: Remove PCR duplicates +- `--outRawCounts`: Export tab-delimited file with coordinate columns and per-sample counts + +**Output:** Compressed numpy array (.npz) for plotCorrelation and plotPCA + +**Common Usage:** +```bash +# Genome-wide comparison +multiBamSummary bins --bamfiles sample1.bam sample2.bam -o results.npz + +# Peak region comparison +multiBamSummary BED-file --BED peaks.bed --bamfiles sample1.bam sample2.bam -o results.npz +``` + +--- + +### multiBigwigSummary + +Similar to multiBamSummary but operates on bigWig files instead of BAM files. Used for comparing coverage tracks across samples. + +**Modes:** +- **bins**: Genome-wide analysis +- **BED-file**: Region-specific analysis + +**Key Parameters:** Similar to multiBamSummary but accepts bigWig files + +--- + +### bamCoverage + +Converts BAM alignment files into normalized coverage tracks in bigWig or bedGraph formats. Calculates coverage as number of reads per bin. + +**Key Parameters:** +- `--bam, -b`: Input BAM file (required) +- `--outFileName, -o`: Output filename (required) +- `--outFileFormat, -of`: Output type (bigwig or bedgraph) +- `--normalizeUsing`: Normalization method + - **RPKM**: Reads Per Kilobase per Million mapped reads + - **CPM**: Counts Per Million mapped reads + - **BPM**: Bins Per Million mapped reads + - **RPGC**: Reads per genomic content (requires --effectiveGenomeSize) + - **None**: No normalization (default) +- `--effectiveGenomeSize`: Mappable genome size (required for RPGC) +- `--binSize`: Resolution in base pairs (default: 50) +- `--extendReads, -e`: Extend reads to fragment length (recommended for ChIP-seq, NOT for RNA-seq) +- `--centerReads`: Center reads at fragment length for sharper signals +- `--ignoreDuplicates`: Count identical reads only once +- `--minMappingQuality`: Filter reads below quality threshold +- `--minFragmentLength / --maxFragmentLength`: Fragment length filtering +- `--smoothLength`: Window averaging for noise reduction +- `--MNase`: Analyze MNase-seq data for nucleosome positioning +- `--Offset`: Position-specific offsets (useful for RiboSeq, GROseq) +- `--filterRNAstrand`: Separate forward/reverse strand reads +- `--ignoreForNormalization`: Exclude chromosomes from normalization (e.g., sex chromosomes) +- `--numberOfProcessors, -p`: Parallel processing + +**Important Notes:** +- For RNA-seq: Do NOT use --extendReads (would extend over splice junctions) +- For ChIP-seq: Use --extendReads with smaller bin sizes +- Never apply --ignoreDuplicates after GC bias correction + +**Common Usage:** +```bash +# Basic coverage with RPKM normalization +bamCoverage --bam input.bam --outFileName coverage.bw --normalizeUsing RPKM + +# ChIP-seq with extension +bamCoverage --bam chip.bam --outFileName chip_coverage.bw \ + --binSize 10 --extendReads 200 --ignoreDuplicates + +# Strand-specific RNA-seq +bamCoverage --bam rnaseq.bam 
--outFileName forward.bw \ + --filterRNAstrand forward +``` + +--- + +### bamCompare + +Compares two BAM files by generating bigWig or bedGraph files, normalizing for sequencing depth differences. Processes genome in equal-sized bins and performs per-bin calculations. + +**Comparison Methods:** +- **log2** (default): Log2 ratio of samples +- **ratio**: Direct ratio calculation +- **subtract**: Difference between files +- **add**: Sum of samples +- **mean**: Average across samples +- **reciprocal_ratio**: Negative inverse for ratios < 0 +- **first/second**: Output scaled signal from single file + +**Normalization Methods:** +- **readCount** (default): Compensates for sequencing depth +- **SES**: Selective enrichment statistics +- **RPKM**: Reads per kilobase per million +- **CPM**: Counts per million +- **BPM**: Bins per million +- **RPGC**: Reads per genomic content (requires --effectiveGenomeSize) + +**Key Parameters:** +- `--bamfile1, -b1`: First BAM file (required) +- `--bamfile2, -b2`: Second BAM file (required) +- `--outFileName, -o`: Output filename (required) +- `--outFileFormat`: bigwig or bedgraph +- `--operation`: Comparison method (see above) +- `--scaleFactorsMethod`: Normalization method (see above) +- `--binSize`: Bin width for output (default: 50bp) +- `--pseudocount`: Avoid division by zero (default: 1) +- `--extendReads`: Extend reads to fragment length +- `--ignoreDuplicates`: Count identical reads once +- `--minMappingQuality`: Quality threshold +- `--numberOfProcessors, -p`: Parallelization + +**Common Usage:** +```bash +# Log2 ratio of treatment vs control +bamCompare -b1 treatment.bam -b2 control.bam -o log2ratio.bw + +# Subtract control from treatment +bamCompare -b1 treatment.bam -b2 control.bam -o difference.bw \ + --operation subtract --scaleFactorsMethod readCount +``` + +--- + +### correctGCBias / computeGCBias + +**computeGCBias:** Identifies GC-content bias from sequencing and PCR amplification. + +**correctGCBias:** Corrects BAM files for GC bias detected by computeGCBias. + +**Key Parameters (computeGCBias):** +- `--bamfile, -b`: Input BAM file +- `--effectiveGenomeSize`: Mappable genome size +- `--genome, -g`: Reference genome in 2bit format +- `--fragmentLength, -l`: Fragment length (for single-end) +- `--biasPlot`: Output diagnostic plot + +**Key Parameters (correctGCBias):** +- `--bamfile, -b`: Input BAM file +- `--effectiveGenomeSize`: Mappable genome size +- `--genome, -g`: Reference genome in 2bit format +- `--GCbiasFrequenciesFile`: Frequencies from computeGCBias +- `--correctedFile, -o`: Output corrected BAM + +**Important:** Never use --ignoreDuplicates after GC bias correction + +--- + +### alignmentSieve + +Filters BAM files by various quality metrics on-the-fly. Useful for creating filtered BAM files for specific analyses. + +**Key Parameters:** +- `--bam, -b`: Input BAM file +- `--outFile, -o`: Output BAM file +- `--minMappingQuality`: Minimum mapping quality +- `--ignoreDuplicates`: Remove duplicates +- `--minFragmentLength / --maxFragmentLength`: Fragment length filters +- `--samFlagInclude / --samFlagExclude`: SAM flag filtering +- `--shift`: Shift reads (e.g., for ATACseq Tn5 correction) +- `--ATACshift`: Automatically shift for ATAC-seq data + +--- + +### computeMatrix + +Calculates scores per genomic region and prepares matrices for plotHeatmap and plotProfile. Processes bigWig score files and BED/GTF region files. 
+ +**Modes:** +- **reference-point**: Signal distribution relative to specific position (TSS, TES, or center) +- **scale-regions**: Signal across regions standardized to uniform lengths + +**Key Parameters:** +- `-R`: Region file(s) in BED/GTF format (required) +- `-S`: BigWig score file(s) (required) +- `-o`: Output matrix file (required) +- `-b`: Upstream distance from reference point +- `-a`: Downstream distance from reference point +- `-m`: Region body length (scale-regions only) +- `-bs, --binSize`: Bin size for averaging scores +- `--skipZeros`: Skip regions with all zeros +- `--minThreshold / --maxThreshold`: Filter by signal intensity +- `--sortRegions`: ascending, descending, keep, no +- `--sortUsing`: mean, median, max, min, sum, region_length +- `-p, --numberOfProcessors`: Parallel processing +- `--averageTypeBins`: Statistical method (mean, median, min, max, sum, std) + +**Output Options:** +- `--outFileNameMatrix`: Export tab-delimited data +- `--outFileSortedRegions`: Save filtered/sorted BED file + +**Common Usage:** +```bash +# TSS analysis +computeMatrix reference-point -S signal.bw -R genes.bed \ + -o matrix.gz -b 2000 -a 2000 --referencePoint TSS + +# Scaled gene body +computeMatrix scale-regions -S signal.bw -R genes.bed \ + -o matrix.gz -b 1000 -a 1000 -m 3000 +``` + +--- + +## Quality Control Tools + +### plotFingerprint + +Quality control tool primarily for ChIP-seq experiments. Assesses whether antibody enrichment was successful. Generates cumulative read coverage profiles to distinguish signal from noise. + +**Key Parameters:** +- `--bamfiles, -b`: Indexed BAM files (required) +- `--plotFile, -plot, -o`: Output image filename (required) +- `--extendReads, -e`: Extend reads to fragment length +- `--ignoreDuplicates`: Count identical reads once +- `--minMappingQuality`: Mapping quality filter +- `--centerReads`: Center reads at fragment length +- `--minFragmentLength / --maxFragmentLength`: Fragment filters +- `--outRawCounts`: Save per-bin read counts +- `--outQualityMetrics`: Output QC metrics (Jensen-Shannon distance) +- `--labels`: Custom sample names +- `--numberOfProcessors, -p`: Parallel processing + +**Interpretation:** +- Ideal control: Straight diagonal line +- Strong ChIP: Steep rise towards highest rank (concentrated reads in few bins) +- Weak enrichment: Flatter curve approaching diagonal + +**Common Usage:** +```bash +plotFingerprint -b input.bam chip1.bam chip2.bam \ + --labels Input ChIP1 ChIP2 -o fingerprint.png \ + --extendReads 200 --ignoreDuplicates +``` + +--- + +### plotCoverage + +Visualizes average read distribution across the genome. Shows genome coverage and helps determine if sequencing depth is adequate. + +**Key Parameters:** +- `--bamfiles, -b`: BAM files to analyze (required) +- `--plotFile, -o`: Output plot filename (required) +- `--ignoreDuplicates`: Remove PCR duplicates +- `--minMappingQuality`: Quality threshold +- `--outRawCounts`: Save underlying data +- `--labels`: Sample names +- `--numberOfSamples`: Number of positions to sample (default: 1,000,000) + +--- + +### bamPEFragmentSize + +Determines fragment length distribution for paired-end sequencing data. Essential QC to verify expected fragment sizes from library preparation. 
+ +**Key Parameters:** +- `--bamfiles, -b`: BAM files (required) +- `--histogram, -hist`: Output histogram filename (required) +- `--plotTitle, -T`: Plot title +- `--maxFragmentLength`: Maximum length to consider (default: 1000) +- `--logScale`: Use logarithmic Y-axis +- `--outRawFragmentLengths`: Save raw fragment lengths + +--- + +### plotCorrelation + +Analyzes sample correlations from multiBamSummary or multiBigwigSummary outputs. Shows how similar different samples are. + +**Correlation Methods:** +- **Pearson**: Measures metric differences; sensitive to outliers; appropriate for normally distributed data +- **Spearman**: Rank-based; less influenced by outliers; better for non-normal distributions + +**Visualization Options:** +- **heatmap**: Color intensity with hierarchical clustering (complete linkage) +- **scatterplot**: Pairwise scatter plots with correlation coefficients + +**Key Parameters:** +- `--corData, -in`: Input matrix from multiBamSummary/multiBigwigSummary (required) +- `--corMethod`: pearson or spearman (required) +- `--whatToShow`: heatmap or scatterplot (required) +- `--plotFile, -o`: Output filename (required) +- `--skipZeros`: Exclude zero-value regions +- `--removeOutliers`: Use median absolute deviation (MAD) filtering +- `--outFileCorMatrix`: Export correlation matrix +- `--labels`: Custom sample names +- `--plotTitle`: Plot title +- `--colorMap`: Color scheme (50+ options) +- `--plotNumbers`: Display correlation values on heatmap + +**Common Usage:** +```bash +# Heatmap with Pearson correlation +plotCorrelation -in readCounts.npz --corMethod pearson \ + --whatToShow heatmap -o correlation_heatmap.png --plotNumbers + +# Scatterplot with Spearman correlation +plotCorrelation -in readCounts.npz --corMethod spearman \ + --whatToShow scatterplot -o correlation_scatter.png +``` + +--- + +### plotPCA + +Generates principal component analysis plots from multiBamSummary or multiBigwigSummary output. Displays sample relationships in reduced dimensionality. + +**Key Parameters:** +- `--corData, -in`: Coverage file from multiBamSummary/multiBigwigSummary (required) +- `--plotFile, -o`: Output image (png, eps, pdf, svg) (required) +- `--outFileNameData`: Export PCA data (loadings/rotation and eigenvalues) +- `--labels, -l`: Custom sample labels +- `--plotTitle, -T`: Plot title +- `--plotHeight / --plotWidth`: Dimensions in centimeters +- `--colors`: Custom symbol colors +- `--markers`: Symbol shapes +- `--transpose`: Perform PCA on transposed matrix (rows=samples) +- `--ntop`: Use top N variable rows (default: 1000) +- `--PCs`: Components to plot (default: 1 2) +- `--log2`: Log2-transform data before analysis +- `--rowCenter`: Center each row at 0 + +**Common Usage:** +```bash +plotPCA -in readCounts.npz -o PCA_plot.png \ + -T "PCA of read counts" --transpose +``` + +--- + +## Visualization Tools + +### plotHeatmap + +Creates genomic region heatmaps from computeMatrix output. Generates publication-quality visualizations. 
+
+**Key Parameters:**
+- `--matrixFile, -m`: Matrix from computeMatrix (required)
+- `--outFileName, -o`: Output image (png, eps, pdf, svg) (required)
+- `--outFileSortedRegions`: Save regions after filtering
+- `--outFileNameMatrix`: Export matrix values
+- `--interpolationMethod`: auto, nearest, bilinear, bicubic, gaussian
+  - Default: nearest (≤1000 columns), bilinear (>1000 columns)
+- `--dpi`: Figure resolution
+
+**Clustering:**
+- `--kmeans`: k-means clustering
+- `--hclust`: Hierarchical clustering (slower for >1000 regions)
+- `--silhouette`: Calculate cluster quality metrics
+
+**Visual Customization:**
+- `--heatmapHeight / --heatmapWidth`: Dimensions (3-100 cm)
+- `--whatToShow`: plot, heatmap, colorbar (combinations)
+- `--alpha`: Transparency (0-1)
+- `--colorMap`: 50+ color schemes
+- `--colorList`: Custom gradient colors
+- `--zMin / --zMax`: Intensity scale limits
+- `--boxAroundHeatmaps`: yes/no (default: yes)
+
+**Labels:**
+- `--xAxisLabel / --yAxisLabel`: Axis labels
+- `--regionsLabel`: Region set identifiers
+- `--samplesLabel`: Sample names
+- `--refPointLabel`: Reference point label
+- `--startLabel / --endLabel`: Region boundary labels
+
+**Common Usage:**
+```bash
+# Basic heatmap
+plotHeatmap -m matrix.gz -o heatmap.png
+
+# With clustering and custom colors
+plotHeatmap -m matrix.gz -o heatmap.png \
+    --kmeans 3 --colorMap RdBu --zMin -3 --zMax 3
+```
+
+---
+
+### plotProfile
+
+Generates profile plots showing scores across genomic regions using computeMatrix output.
+
+**Key Parameters:**
+- `--matrixFile, -m`: Matrix from computeMatrix (required)
+- `--outFileName, -o`: Output image (png, eps, pdf, svg) (required)
+- `--plotType`: lines, fill, se, std, overlapped_lines, heatmap
+- `--colors`: Color palette (names or hex codes)
+- `--plotHeight / --plotWidth`: Dimensions in centimeters
+- `--yMin / --yMax`: Y-axis range
+- `--averageType`: mean, median, min, max, std, sum
+
+**Clustering:**
+- `--kmeans`: k-means clustering
+- `--hclust`: Hierarchical clustering
+- `--silhouette`: Cluster quality metrics
+
+**Labels:**
+- `--plotTitle`: Main heading
+- `--regionsLabel`: Region set identifiers
+- `--samplesLabel`: Sample names
+- `--startLabel / --endLabel`: Region boundary labels (scale-regions mode)
+
+**Output Options:**
+- `--outFileNameData`: Export data as tab-separated values
+- `--outFileSortedRegions`: Save filtered/sorted regions as BED
+
+**Common Usage:**
+```bash
+# Line plot
+plotProfile -m matrix.gz -o profile.png --plotType lines
+
+# With standard error shading
+plotProfile -m matrix.gz -o profile.png --plotType se \
+    --colors blue red green
+```
+
+---
+
+### plotEnrichment
+
+Calculates and visualizes signal enrichment across genomic regions. Measures the percentage of alignments overlapping region groups. Useful for FRiP (Fraction of Reads in Peaks) scores.
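+
+As a rough worked example of what a FRiP score means (the numbers are illustrative, not from a real dataset):
+
+```python
+# Fraction of Reads in Peaks (FRiP): reads overlapping peak regions / total mapped reads
+reads_in_peaks = 4_500_000
+total_mapped_reads = 30_000_000
+frip = reads_in_peaks / total_mapped_reads
+print(f"FRiP = {frip:.2f} ({frip:.0%})")   # FRiP = 0.15 (15%)
+```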
+ +**Key Parameters:** +- `--bamfiles, -b`: Indexed BAM files (required) +- `--BED`: Region files in BED/GTF format (required) +- `--plotFile, -o`: Output visualization (png, pdf, eps, svg) +- `--labels, -l`: Custom sample identifiers +- `--outRawCounts`: Export numerical data +- `--perSample`: Group by sample instead of feature (default) +- `--regionLabels`: Custom region names + +**Read Processing:** +- `--minFragmentLength / --maxFragmentLength`: Fragment filters +- `--minMappingQuality`: Quality threshold +- `--samFlagInclude / --samFlagExclude`: SAM flag filters +- `--ignoreDuplicates`: Remove duplicates +- `--centerReads`: Center reads for sharper signal + +**Common Usage:** +```bash +plotEnrichment -b Input.bam H3K4me3.bam \ + --BED peaks_up.bed peaks_down.bed \ + --regionLabels "Up regulated" "Down regulated" \ + -o enrichment.png +``` + +--- + +## Miscellaneous Tools + +### computeMatrixOperations + +Advanced matrix manipulation tool for combining or subsetting matrices from computeMatrix. Enables complex multi-sample, multi-region analyses. + +**Operations:** +- `cbind`: Combine matrices column-wise +- `rbind`: Combine matrices row-wise +- `subset`: Extract specific samples or regions +- `filterStrand`: Keep only regions on specific strand +- `filterValues`: Apply signal intensity filters +- `sort`: Order regions by various criteria +- `dataRange`: Report min/max values + +**Common Usage:** +```bash +# Combine matrices +computeMatrixOperations cbind -m matrix1.gz matrix2.gz -o combined.gz + +# Extract specific samples +computeMatrixOperations subset -m matrix.gz --samples 0 2 -o subset.gz +``` + +--- + +### estimateReadFiltering + +Predicts the impact of various filtering parameters without actually filtering. Helps optimize filtering strategies before running full analyses. + +**Key Parameters:** +- `--bamfiles, -b`: BAM files to analyze +- `--sampleSize`: Number of reads to sample (default: 100,000) +- `--binSize`: Bin size for analysis +- `--distanceBetweenBins`: Spacing between sampled bins + +**Filtration Options to Test:** +- `--minMappingQuality`: Test quality thresholds +- `--ignoreDuplicates`: Assess duplicate impact +- `--minFragmentLength / --maxFragmentLength`: Test fragment filters + +--- + +## Common Parameters Across Tools + +Many deepTools commands share these filtering and performance options: + +**Read Filtering:** +- `--ignoreDuplicates`: Remove PCR duplicates +- `--minMappingQuality`: Filter by alignment confidence +- `--samFlagInclude / --samFlagExclude`: SAM format filtering +- `--minFragmentLength / --maxFragmentLength`: Fragment length bounds + +**Performance:** +- `--numberOfProcessors, -p`: Enable parallel processing +- `--region`: Process specific genomic regions (chr:start-end) + +**Read Processing:** +- `--extendReads`: Extend to fragment length +- `--centerReads`: Center at fragment midpoint +- `--ignoreDuplicates`: Count unique reads only diff --git a/scientific-packages/deeptools/references/workflows.md b/scientific-packages/deeptools/references/workflows.md new file mode 100644 index 0000000..2bcc6ff --- /dev/null +++ b/scientific-packages/deeptools/references/workflows.md @@ -0,0 +1,474 @@ +# deepTools Common Workflows + +This document provides complete workflow examples for common deepTools analyses. + +## ChIP-seq Quality Control Workflow + +Complete quality control assessment for ChIP-seq experiments. 
+ +### Step 1: Initial Correlation Assessment + +Compare replicates and samples to verify experimental quality: + +```bash +# Generate coverage matrix across genome +multiBamSummary bins \ + --bamfiles Input1.bam Input2.bam ChIP1.bam ChIP2.bam \ + --labels Input_rep1 Input_rep2 ChIP_rep1 ChIP_rep2 \ + -o readCounts.npz \ + --numberOfProcessors 8 + +# Create correlation heatmap +plotCorrelation \ + -in readCounts.npz \ + --corMethod pearson \ + --whatToShow heatmap \ + --plotFile correlation_heatmap.png \ + --plotNumbers + +# Generate PCA plot +plotPCA \ + -in readCounts.npz \ + -o PCA_plot.png \ + -T "PCA of ChIP-seq samples" +``` + +**Expected Results:** +- Replicates should cluster together +- Input samples should be distinct from ChIP samples + +--- + +### Step 2: Coverage and Depth Assessment + +```bash +# Check sequencing depth and coverage +plotCoverage \ + --bamfiles Input1.bam ChIP1.bam ChIP2.bam \ + --labels Input ChIP_rep1 ChIP_rep2 \ + --plotFile coverage.png \ + --ignoreDuplicates \ + --numberOfProcessors 8 +``` + +**Interpretation:** Assess whether sequencing depth is adequate for downstream analysis. + +--- + +### Step 3: Fragment Size Validation (Paired-end) + +```bash +# Verify expected fragment sizes +bamPEFragmentSize \ + --bamfiles Input1.bam ChIP1.bam ChIP2.bam \ + --histogram fragmentSizes.png \ + --plotTitle "Fragment Size Distribution" +``` + +**Expected Results:** Fragment sizes should match library preparation protocols (typically 200-600bp for ChIP-seq). + +--- + +### Step 4: GC Bias Detection and Correction + +```bash +# Compute GC bias +computeGCBias \ + --bamfile ChIP1.bam \ + --effectiveGenomeSize 2913022398 \ + --genome genome.2bit \ + --fragmentLength 200 \ + --biasPlot GCbias.png \ + --frequenciesFile freq.txt + +# If bias detected, correct it +correctGCBias \ + --bamfile ChIP1.bam \ + --effectiveGenomeSize 2913022398 \ + --genome genome.2bit \ + --GCbiasFrequenciesFile freq.txt \ + --correctedFile ChIP1_GCcorrected.bam +``` + +**Note:** Only correct if significant bias is observed. Do NOT use `--ignoreDuplicates` with GC-corrected files. + +--- + +### Step 5: ChIP Signal Strength Assessment + +```bash +# Evaluate ChIP enrichment quality +plotFingerprint \ + --bamfiles Input1.bam ChIP1.bam ChIP2.bam \ + --labels Input ChIP_rep1 ChIP_rep2 \ + --plotFile fingerprint.png \ + --extendReads 200 \ + --ignoreDuplicates \ + --numberOfProcessors 8 \ + --outQualityMetrics fingerprint_metrics.txt +``` + +**Interpretation:** +- Strong ChIP: Steep rise in cumulative curve +- Weak enrichment: Curve close to diagonal (input-like) + +--- + +## ChIP-seq Analysis Workflow + +Complete workflow from BAM files to publication-quality visualizations. 
+ +### Step 1: Generate Normalized Coverage Tracks + +```bash +# Input control +bamCoverage \ + --bam Input.bam \ + --outFileName Input_coverage.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --binSize 10 \ + --extendReads 200 \ + --ignoreDuplicates \ + --numberOfProcessors 8 + +# ChIP sample +bamCoverage \ + --bam ChIP.bam \ + --outFileName ChIP_coverage.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --binSize 10 \ + --extendReads 200 \ + --ignoreDuplicates \ + --numberOfProcessors 8 +``` + +--- + +### Step 2: Create Log2 Ratio Track + +```bash +# Compare ChIP to Input +bamCompare \ + --bamfile1 ChIP.bam \ + --bamfile2 Input.bam \ + --outFileName ChIP_vs_Input_log2ratio.bw \ + --operation log2 \ + --scaleFactorsMethod readCount \ + --binSize 10 \ + --extendReads 200 \ + --ignoreDuplicates \ + --numberOfProcessors 8 +``` + +**Result:** Log2 ratio track showing enrichment (positive values) and depletion (negative values). + +--- + +### Step 3: Compute Matrix Around TSS + +```bash +# Prepare data for heatmap/profile around transcription start sites +computeMatrix reference-point \ + --referencePoint TSS \ + --scoreFileName ChIP_coverage.bw \ + --regionsFileName genes.bed \ + --beforeRegionStartLength 3000 \ + --afterRegionStartLength 3000 \ + --binSize 10 \ + --sortRegions descend \ + --sortUsing mean \ + --outFileName matrix_TSS.gz \ + --outFileNameMatrix matrix_TSS.tab \ + --numberOfProcessors 8 +``` + +--- + +### Step 4: Generate Heatmap + +```bash +# Create heatmap around TSS +plotHeatmap \ + --matrixFile matrix_TSS.gz \ + --outFileName heatmap_TSS.png \ + --colorMap RdBu \ + --whatToShow 'plot, heatmap and colorbar' \ + --zMin -3 --zMax 3 \ + --yAxisLabel "Genes" \ + --xAxisLabel "Distance from TSS (bp)" \ + --refPointLabel "TSS" \ + --heatmapHeight 15 \ + --kmeans 3 +``` + +--- + +### Step 5: Generate Profile Plot + +```bash +# Create meta-profile around TSS +plotProfile \ + --matrixFile matrix_TSS.gz \ + --outFileName profile_TSS.png \ + --plotType lines \ + --perGroup \ + --colors blue \ + --plotTitle "ChIP-seq signal around TSS" \ + --yAxisLabel "Average signal" \ + --xAxisLabel "Distance from TSS (bp)" \ + --refPointLabel "TSS" +``` + +--- + +### Step 6: Enrichment at Peaks + +```bash +# Calculate enrichment in peak regions +plotEnrichment \ + --bamfiles Input.bam ChIP.bam \ + --BED peaks.bed \ + --labels Input ChIP \ + --plotFile enrichment.png \ + --outRawCounts enrichment_counts.tab \ + --extendReads 200 \ + --ignoreDuplicates +``` + +--- + +## RNA-seq Coverage Workflow + +Generate strand-specific coverage tracks for RNA-seq data. + +### Forward Strand + +```bash +bamCoverage \ + --bam rnaseq.bam \ + --outFileName forward_coverage.bw \ + --filterRNAstrand forward \ + --normalizeUsing CPM \ + --binSize 1 \ + --numberOfProcessors 8 +``` + +### Reverse Strand + +```bash +bamCoverage \ + --bam rnaseq.bam \ + --outFileName reverse_coverage.bw \ + --filterRNAstrand reverse \ + --normalizeUsing CPM \ + --binSize 1 \ + --numberOfProcessors 8 +``` + +**Important:** Do NOT use `--extendReads` for RNA-seq (would extend over splice junctions). + +--- + +## Multi-Sample Comparison Workflow + +Compare multiple ChIP-seq samples (e.g., different conditions or time points). 
+ +### Step 1: Generate Coverage Files + +```bash +# For each sample +for sample in Control_ChIP Treated_ChIP; do + bamCoverage \ + --bam ${sample}.bam \ + --outFileName ${sample}.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --binSize 10 \ + --extendReads 200 \ + --ignoreDuplicates \ + --numberOfProcessors 8 +done +``` + +--- + +### Step 2: Compute Multi-Sample Matrix + +```bash +computeMatrix scale-regions \ + --scoreFileName Control_ChIP.bw Treated_ChIP.bw \ + --regionsFileName genes.bed \ + --beforeRegionStartLength 1000 \ + --afterRegionStartLength 1000 \ + --regionBodyLength 3000 \ + --binSize 10 \ + --sortRegions descend \ + --sortUsing mean \ + --outFileName matrix_multi.gz \ + --numberOfProcessors 8 +``` + +--- + +### Step 3: Multi-Sample Heatmap + +```bash +plotHeatmap \ + --matrixFile matrix_multi.gz \ + --outFileName heatmap_comparison.png \ + --colorMap Blues \ + --whatToShow 'plot, heatmap and colorbar' \ + --samplesLabel Control Treated \ + --yAxisLabel "Genes" \ + --heatmapHeight 15 \ + --kmeans 4 +``` + +--- + +### Step 4: Multi-Sample Profile + +```bash +plotProfile \ + --matrixFile matrix_multi.gz \ + --outFileName profile_comparison.png \ + --plotType lines \ + --perGroup \ + --colors blue red \ + --samplesLabel Control Treated \ + --plotTitle "ChIP-seq signal comparison" \ + --startLabel "TSS" \ + --endLabel "TES" +``` + +--- + +## ATAC-seq Workflow + +Specialized workflow for ATAC-seq data with Tn5 offset correction. + +### Step 1: Shift Reads for Tn5 Correction + +```bash +alignmentSieve \ + --bam atacseq.bam \ + --outFile atacseq_shifted.bam \ + --ATACshift \ + --minFragmentLength 38 \ + --maxFragmentLength 2000 \ + --ignoreDuplicates +``` + +--- + +### Step 2: Generate Coverage Track + +```bash +bamCoverage \ + --bam atacseq_shifted.bam \ + --outFileName atacseq_coverage.bw \ + --normalizeUsing RPGC \ + --effectiveGenomeSize 2913022398 \ + --binSize 1 \ + --numberOfProcessors 8 +``` + +--- + +### Step 3: Fragment Size Analysis + +```bash +bamPEFragmentSize \ + --bamfiles atacseq.bam \ + --histogram fragmentSizes_atac.png \ + --maxFragmentLength 1000 +``` + +**Expected Pattern:** Nucleosome ladder with peaks at ~50bp (nucleosome-free), ~200bp (mono-nucleosome), ~400bp (di-nucleosome). + +--- + +## Peak Region Analysis Workflow + +Analyze ChIP-seq signal specifically at peak regions. 
+ +### Step 1: Matrix at Peaks + +```bash +computeMatrix reference-point \ + --referencePoint center \ + --scoreFileName ChIP_coverage.bw \ + --regionsFileName peaks.bed \ + --beforeRegionStartLength 2000 \ + --afterRegionStartLength 2000 \ + --binSize 10 \ + --outFileName matrix_peaks.gz \ + --numberOfProcessors 8 +``` + +--- + +### Step 2: Heatmap at Peaks + +```bash +plotHeatmap \ + --matrixFile matrix_peaks.gz \ + --outFileName heatmap_peaks.png \ + --colorMap YlOrRd \ + --refPointLabel "Peak Center" \ + --heatmapHeight 15 \ + --sortUsing max +``` + +--- + +## Troubleshooting Common Issues + +### Issue: Out of Memory +**Solution:** Use `--region` parameter to process chromosomes individually: +```bash +bamCoverage --bam input.bam -o chr1.bw --region chr1 +``` + +### Issue: BAM Index Missing +**Solution:** Index BAM files before running deepTools: +```bash +samtools index input.bam +``` + +### Issue: Slow Processing +**Solution:** Increase `--numberOfProcessors`: +```bash +# Use 8 cores instead of default +--numberOfProcessors 8 +``` + +### Issue: bigWig Files Too Large +**Solution:** Increase bin size: +```bash +--binSize 50 # or larger (default is 10-50) +``` + +--- + +## Performance Tips + +1. **Use multiple processors:** Always set `--numberOfProcessors` to available cores +2. **Process regions:** Use `--region` for testing or memory-limited environments +3. **Adjust bin size:** Larger bins = faster processing and smaller files +4. **Pre-filter BAM files:** Use `alignmentSieve` to create filtered BAM files once, then reuse +5. **Use bigWig over bedGraph:** bigWig format is compressed and faster to process + +--- + +## Best Practices + +1. **Always check QC first:** Run correlation, coverage, and fingerprint analysis before proceeding +2. **Document parameters:** Save command lines for reproducibility +3. **Use consistent normalization:** Apply same normalization method across samples in a comparison +4. **Verify reference genome match:** Ensure BAM files and region files use same genome build +5. **Check strand orientation:** For RNA-seq, verify correct strand orientation +6. **Test on small regions first:** Use `--region chr1:1-1000000` for testing parameters +7. **Keep intermediate files:** Save matrices for regenerating plots with different settings diff --git a/scientific-packages/deeptools/scripts/validate_files.py b/scientific-packages/deeptools/scripts/validate_files.py new file mode 100644 index 0000000..d988514 --- /dev/null +++ b/scientific-packages/deeptools/scripts/validate_files.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +deepTools File Validation Script + +Validates BAM, bigWig, and BED files for deepTools analysis. +Checks for file existence, proper indexing, and basic format requirements. 
+""" + +import os +import sys +import argparse +from pathlib import Path + + +def check_file_exists(filepath): + """Check if file exists and is readable.""" + if not os.path.exists(filepath): + return False, f"File not found: {filepath}" + if not os.access(filepath, os.R_OK): + return False, f"File not readable: {filepath}" + return True, f"✓ File exists: {filepath}" + + +def check_bam_index(bam_file): + """Check if BAM file has an index (.bai or .bam.bai).""" + bai_file1 = bam_file + ".bai" + bai_file2 = bam_file.replace(".bam", ".bai") + + if os.path.exists(bai_file1): + return True, f"✓ BAM index found: {bai_file1}" + elif os.path.exists(bai_file2): + return True, f"✓ BAM index found: {bai_file2}" + else: + return False, f"✗ BAM index missing for: {bam_file}\n Run: samtools index {bam_file}" + + +def check_bigwig_file(bw_file): + """Basic check for bigWig file.""" + # Check file size (bigWig files should have reasonable size) + file_size = os.path.getsize(bw_file) + if file_size < 100: + return False, f"✗ bigWig file suspiciously small: {bw_file} ({file_size} bytes)" + return True, f"✓ bigWig file appears valid: {bw_file} ({file_size} bytes)" + + +def check_bed_file(bed_file): + """Basic validation of BED file format.""" + try: + with open(bed_file, 'r') as f: + lines = [line.strip() for line in f if line.strip() and not line.startswith('#')] + + if len(lines) == 0: + return False, f"✗ BED file is empty: {bed_file}" + + # Check first few lines for basic format + for i, line in enumerate(lines[:10], 1): + fields = line.split('\t') + if len(fields) < 3: + return False, f"✗ BED file format error at line {i}: expected at least 3 columns\n Line: {line}" + + # Check if start and end are integers + try: + start = int(fields[1]) + end = int(fields[2]) + if start >= end: + return False, f"✗ BED file error at line {i}: start >= end ({start} >= {end})" + except ValueError: + return False, f"✗ BED file format error at line {i}: start and end must be integers\n Line: {line}" + + return True, f"✓ BED file format appears valid: {bed_file} ({len(lines)} regions)" + + except Exception as e: + return False, f"✗ Error reading BED file: {bed_file}\n Error: {str(e)}" + + +def validate_files(bam_files=None, bigwig_files=None, bed_files=None): + """ + Validate all provided files. 
+ + Args: + bam_files: List of BAM file paths + bigwig_files: List of bigWig file paths + bed_files: List of BED file paths + + Returns: + Tuple of (success: bool, messages: list) + """ + all_success = True + messages = [] + + # Validate BAM files + if bam_files: + messages.append("\n=== Validating BAM Files ===") + for bam_file in bam_files: + # Check existence + success, msg = check_file_exists(bam_file) + messages.append(msg) + if not success: + all_success = False + continue + + # Check index + success, msg = check_bam_index(bam_file) + messages.append(msg) + if not success: + all_success = False + + # Validate bigWig files + if bigwig_files: + messages.append("\n=== Validating bigWig Files ===") + for bw_file in bigwig_files: + # Check existence + success, msg = check_file_exists(bw_file) + messages.append(msg) + if not success: + all_success = False + continue + + # Basic bigWig check + success, msg = check_bigwig_file(bw_file) + messages.append(msg) + if not success: + all_success = False + + # Validate BED files + if bed_files: + messages.append("\n=== Validating BED Files ===") + for bed_file in bed_files: + # Check existence + success, msg = check_file_exists(bed_file) + messages.append(msg) + if not success: + all_success = False + continue + + # Check BED format + success, msg = check_bed_file(bed_file) + messages.append(msg) + if not success: + all_success = False + + return all_success, messages + + +def main(): + parser = argparse.ArgumentParser( + description="Validate files for deepTools analysis", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Validate BAM files + python validate_files.py --bam sample1.bam sample2.bam + + # Validate all file types + python validate_files.py --bam input.bam chip.bam --bed peaks.bed --bigwig signal.bw + + # Validate from a directory + python validate_files.py --bam *.bam --bed *.bed + """ + ) + + parser.add_argument('--bam', nargs='+', help='BAM files to validate') + parser.add_argument('--bigwig', '--bw', nargs='+', help='bigWig files to validate') + parser.add_argument('--bed', nargs='+', help='BED files to validate') + + args = parser.parse_args() + + # Check if any files were provided + if not any([args.bam, args.bigwig, args.bed]): + parser.print_help() + sys.exit(1) + + # Run validation + success, messages = validate_files( + bam_files=args.bam, + bigwig_files=args.bigwig, + bed_files=args.bed + ) + + # Print results + for msg in messages: + print(msg) + + # Summary + print("\n" + "="*50) + if success: + print("✓ All validations passed!") + sys.exit(0) + else: + print("✗ Some validations failed. Please fix the issues above.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/deeptools/scripts/workflow_generator.py b/scientific-packages/deeptools/scripts/workflow_generator.py new file mode 100644 index 0000000..03e1512 --- /dev/null +++ b/scientific-packages/deeptools/scripts/workflow_generator.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 +""" +deepTools Workflow Generator + +Generates bash script templates for common deepTools workflows. 
+""" + +import argparse +import sys + + +WORKFLOWS = { + 'chipseq_qc': { + 'name': 'ChIP-seq Quality Control', + 'description': 'Complete QC workflow for ChIP-seq experiments', + }, + 'chipseq_analysis': { + 'name': 'ChIP-seq Complete Analysis', + 'description': 'Full ChIP-seq analysis from BAM to heatmaps', + }, + 'rnaseq_coverage': { + 'name': 'RNA-seq Coverage Tracks', + 'description': 'Generate strand-specific RNA-seq coverage', + }, + 'atacseq': { + 'name': 'ATAC-seq Analysis', + 'description': 'ATAC-seq workflow with Tn5 correction', + }, +} + + +def generate_chipseq_qc_workflow(output_file, params): + """Generate ChIP-seq QC workflow script.""" + + script = f"""#!/bin/bash +# deepTools ChIP-seq Quality Control Workflow +# Generated by deepTools workflow generator + +# Configuration +INPUT_BAM="{params.get('input_bam', 'Input.bam')}" +CHIP_BAM=("{params.get('chip_bams', 'ChIP1.bam ChIP2.bam')}") +GENOME_SIZE={params.get('genome_size', '2913022398')} +THREADS={params.get('threads', '8')} +OUTPUT_DIR="{params.get('output_dir', 'deeptools_qc')}" + +# Create output directory +mkdir -p $OUTPUT_DIR + +echo "=== Starting ChIP-seq QC workflow ===" + +# Step 1: Correlation analysis +echo "Step 1: Computing correlation matrix..." +multiBamSummary bins \\ + --bamfiles $INPUT_BAM ${{CHIP_BAM[@]}} \\ + -o $OUTPUT_DIR/readCounts.npz \\ + --numberOfProcessors $THREADS + +echo "Step 2: Generating correlation heatmap..." +plotCorrelation \\ + -in $OUTPUT_DIR/readCounts.npz \\ + --corMethod pearson \\ + --whatToShow heatmap \\ + --plotFile $OUTPUT_DIR/correlation_heatmap.png \\ + --plotNumbers + +echo "Step 3: Generating PCA plot..." +plotPCA \\ + -in $OUTPUT_DIR/readCounts.npz \\ + -o $OUTPUT_DIR/PCA_plot.png \\ + -T "PCA of ChIP-seq samples" + +# Step 2: Coverage assessment +echo "Step 4: Assessing coverage..." +plotCoverage \\ + --bamfiles $INPUT_BAM ${{CHIP_BAM[@]}} \\ + --plotFile $OUTPUT_DIR/coverage.png \\ + --ignoreDuplicates \\ + --numberOfProcessors $THREADS + +# Step 3: Fragment size (for paired-end data) +echo "Step 5: Analyzing fragment sizes..." +bamPEFragmentSize \\ + --bamfiles $INPUT_BAM ${{CHIP_BAM[@]}} \\ + --histogram $OUTPUT_DIR/fragmentSizes.png \\ + --plotTitle "Fragment Size Distribution" + +# Step 4: ChIP signal strength +echo "Step 6: Evaluating ChIP enrichment..." 
+plotFingerprint \\ + --bamfiles $INPUT_BAM ${{CHIP_BAM[@]}} \\ + --plotFile $OUTPUT_DIR/fingerprint.png \\ + --extendReads 200 \\ + --ignoreDuplicates \\ + --numberOfProcessors $THREADS \\ + --outQualityMetrics $OUTPUT_DIR/fingerprint_metrics.txt + +echo "=== ChIP-seq QC workflow complete ===" +echo "Results are in: $OUTPUT_DIR" +""" + + with open(output_file, 'w') as f: + f.write(script) + + return f"✓ Generated ChIP-seq QC workflow: {output_file}" + + +def generate_chipseq_analysis_workflow(output_file, params): + """Generate complete ChIP-seq analysis workflow script.""" + + script = f"""#!/bin/bash +# deepTools ChIP-seq Complete Analysis Workflow +# Generated by deepTools workflow generator + +# Configuration +INPUT_BAM="{params.get('input_bam', 'Input.bam')}" +CHIP_BAM="{params.get('chip_bam', 'ChIP.bam')}" +GENES_BED="{params.get('genes_bed', 'genes.bed')}" +PEAKS_BED="{params.get('peaks_bed', 'peaks.bed')}" +GENOME_SIZE={params.get('genome_size', '2913022398')} +THREADS={params.get('threads', '8')} +OUTPUT_DIR="{params.get('output_dir', 'chipseq_analysis')}" + +# Create output directory +mkdir -p $OUTPUT_DIR + +echo "=== Starting ChIP-seq analysis workflow ===" + +# Step 1: Generate normalized coverage tracks +echo "Step 1: Generating coverage tracks..." + +bamCoverage \\ + --bam $INPUT_BAM \\ + --outFileName $OUTPUT_DIR/Input_coverage.bw \\ + --normalizeUsing RPGC \\ + --effectiveGenomeSize $GENOME_SIZE \\ + --binSize 10 \\ + --extendReads 200 \\ + --ignoreDuplicates \\ + --numberOfProcessors $THREADS + +bamCoverage \\ + --bam $CHIP_BAM \\ + --outFileName $OUTPUT_DIR/ChIP_coverage.bw \\ + --normalizeUsing RPGC \\ + --effectiveGenomeSize $GENOME_SIZE \\ + --binSize 10 \\ + --extendReads 200 \\ + --ignoreDuplicates \\ + --numberOfProcessors $THREADS + +# Step 2: Create log2 ratio track +echo "Step 2: Creating log2 ratio track..." +bamCompare \\ + --bamfile1 $CHIP_BAM \\ + --bamfile2 $INPUT_BAM \\ + --outFileName $OUTPUT_DIR/ChIP_vs_Input_log2ratio.bw \\ + --operation log2 \\ + --scaleFactorsMethod readCount \\ + --binSize 10 \\ + --extendReads 200 \\ + --ignoreDuplicates \\ + --numberOfProcessors $THREADS + +# Step 3: Compute matrix around TSS +echo "Step 3: Computing matrix around TSS..." +computeMatrix reference-point \\ + --referencePoint TSS \\ + --scoreFileName $OUTPUT_DIR/ChIP_coverage.bw \\ + --regionsFileName $GENES_BED \\ + --beforeRegionStartLength 3000 \\ + --afterRegionStartLength 3000 \\ + --binSize 10 \\ + --sortRegions descend \\ + --sortUsing mean \\ + --outFileName $OUTPUT_DIR/matrix_TSS.gz \\ + --numberOfProcessors $THREADS + +# Step 4: Generate heatmap +echo "Step 4: Generating heatmap..." +plotHeatmap \\ + --matrixFile $OUTPUT_DIR/matrix_TSS.gz \\ + --outFileName $OUTPUT_DIR/heatmap_TSS.png \\ + --colorMap RdBu \\ + --whatToShow 'plot, heatmap and colorbar' \\ + --yAxisLabel "Genes" \\ + --xAxisLabel "Distance from TSS (bp)" \\ + --refPointLabel "TSS" \\ + --heatmapHeight 15 \\ + --kmeans 3 + +# Step 5: Generate profile plot +echo "Step 5: Generating profile plot..." +plotProfile \\ + --matrixFile $OUTPUT_DIR/matrix_TSS.gz \\ + --outFileName $OUTPUT_DIR/profile_TSS.png \\ + --plotType lines \\ + --perGroup \\ + --colors blue \\ + --plotTitle "ChIP-seq signal around TSS" \\ + --yAxisLabel "Average signal" \\ + --refPointLabel "TSS" + +# Step 6: Enrichment at peaks (if peaks provided) +if [ -f "$PEAKS_BED" ]; then + echo "Step 6: Calculating enrichment at peaks..." 
+ plotEnrichment \\ + --bamfiles $INPUT_BAM $CHIP_BAM \\ + --BED $PEAKS_BED \\ + --labels Input ChIP \\ + --plotFile $OUTPUT_DIR/enrichment.png \\ + --outRawCounts $OUTPUT_DIR/enrichment_counts.tab \\ + --extendReads 200 \\ + --ignoreDuplicates +fi + +echo "=== ChIP-seq analysis complete ===" +echo "Results are in: $OUTPUT_DIR" +""" + + with open(output_file, 'w') as f: + f.write(script) + + return f"✓ Generated ChIP-seq analysis workflow: {output_file}" + + +def generate_rnaseq_coverage_workflow(output_file, params): + """Generate RNA-seq coverage workflow script.""" + + script = f"""#!/bin/bash +# deepTools RNA-seq Coverage Workflow +# Generated by deepTools workflow generator + +# Configuration +RNASEQ_BAM="{params.get('rnaseq_bam', 'rnaseq.bam')}" +THREADS={params.get('threads', '8')} +OUTPUT_DIR="{params.get('output_dir', 'rnaseq_coverage')}" + +# Create output directory +mkdir -p $OUTPUT_DIR + +echo "=== Starting RNA-seq coverage workflow ===" + +# Generate strand-specific coverage tracks +echo "Step 1: Generating forward strand coverage..." +bamCoverage \\ + --bam $RNASEQ_BAM \\ + --outFileName $OUTPUT_DIR/forward_coverage.bw \\ + --filterRNAstrand forward \\ + --normalizeUsing CPM \\ + --binSize 1 \\ + --numberOfProcessors $THREADS + +echo "Step 2: Generating reverse strand coverage..." +bamCoverage \\ + --bam $RNASEQ_BAM \\ + --outFileName $OUTPUT_DIR/reverse_coverage.bw \\ + --filterRNAstrand reverse \\ + --normalizeUsing CPM \\ + --binSize 1 \\ + --numberOfProcessors $THREADS + +echo "=== RNA-seq coverage workflow complete ===" +echo "Results are in: $OUTPUT_DIR" +echo "" +echo "Note: These bigWig files can be loaded into genome browsers" +echo "for strand-specific visualization of RNA-seq data." +""" + + with open(output_file, 'w') as f: + f.write(script) + + return f"✓ Generated RNA-seq coverage workflow: {output_file}" + + +def generate_atacseq_workflow(output_file, params): + """Generate ATAC-seq workflow script.""" + + script = f"""#!/bin/bash +# deepTools ATAC-seq Analysis Workflow +# Generated by deepTools workflow generator + +# Configuration +ATAC_BAM="{params.get('atac_bam', 'atacseq.bam')}" +PEAKS_BED="{params.get('peaks_bed', 'peaks.bed')}" +GENOME_SIZE={params.get('genome_size', '2913022398')} +THREADS={params.get('threads', '8')} +OUTPUT_DIR="{params.get('output_dir', 'atacseq_analysis')}" + +# Create output directory +mkdir -p $OUTPUT_DIR + +echo "=== Starting ATAC-seq analysis workflow ===" + +# Step 1: Shift reads for Tn5 correction +echo "Step 1: Applying Tn5 offset correction..." +alignmentSieve \\ + --bam $ATAC_BAM \\ + --outFile $OUTPUT_DIR/atacseq_shifted.bam \\ + --ATACshift \\ + --minFragmentLength 38 \\ + --maxFragmentLength 2000 \\ + --ignoreDuplicates + +# Index the shifted BAM +samtools index $OUTPUT_DIR/atacseq_shifted.bam + +# Step 2: Generate coverage track +echo "Step 2: Generating coverage track..." +bamCoverage \\ + --bam $OUTPUT_DIR/atacseq_shifted.bam \\ + --outFileName $OUTPUT_DIR/atacseq_coverage.bw \\ + --normalizeUsing RPGC \\ + --effectiveGenomeSize $GENOME_SIZE \\ + --binSize 1 \\ + --numberOfProcessors $THREADS + +# Step 3: Fragment size analysis +echo "Step 3: Analyzing fragment sizes..." +bamPEFragmentSize \\ + --bamfiles $ATAC_BAM \\ + --histogram $OUTPUT_DIR/fragmentSizes.png \\ + --maxFragmentLength 1000 + +# Step 4: Compute matrix at peaks (if peaks provided) +if [ -f "$PEAKS_BED" ]; then + echo "Step 4: Computing matrix at peaks..." 
+ computeMatrix reference-point \\ + --referencePoint center \\ + --scoreFileName $OUTPUT_DIR/atacseq_coverage.bw \\ + --regionsFileName $PEAKS_BED \\ + --beforeRegionStartLength 2000 \\ + --afterRegionStartLength 2000 \\ + --binSize 10 \\ + --outFileName $OUTPUT_DIR/matrix_peaks.gz \\ + --numberOfProcessors $THREADS + + echo "Step 5: Generating heatmap..." + plotHeatmap \\ + --matrixFile $OUTPUT_DIR/matrix_peaks.gz \\ + --outFileName $OUTPUT_DIR/heatmap_peaks.png \\ + --colorMap YlOrRd \\ + --refPointLabel "Peak Center" \\ + --heatmapHeight 15 +fi + +echo "=== ATAC-seq analysis complete ===" +echo "Results are in: $OUTPUT_DIR" +echo "" +echo "Expected fragment size pattern:" +echo " ~50bp: nucleosome-free regions" +echo " ~200bp: mono-nucleosome" +echo " ~400bp: di-nucleosome" +""" + + with open(output_file, 'w') as f: + f.write(script) + + return f"✓ Generated ATAC-seq workflow: {output_file}" + + +def main(): + parser = argparse.ArgumentParser( + description="Generate deepTools workflow scripts", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +Available workflows: +{chr(10).join(f" {key}: {value['name']}" for key, value in WORKFLOWS.items())} + +Examples: + # Generate ChIP-seq QC workflow + python workflow_generator.py chipseq_qc -o chipseq_qc.sh + + # Generate ChIP-seq analysis with custom parameters + python workflow_generator.py chipseq_analysis -o analysis.sh \\ + --chip-bam H3K4me3.bam --input-bam Input.bam + + # List all available workflows + python workflow_generator.py --list + """ + ) + + parser.add_argument('workflow', nargs='?', choices=list(WORKFLOWS.keys()), + help='Workflow type to generate') + parser.add_argument('-o', '--output', default='deeptools_workflow.sh', + help='Output script filename (default: deeptools_workflow.sh)') + parser.add_argument('--list', action='store_true', + help='List all available workflows') + + # Common parameters + parser.add_argument('--threads', type=int, default=8, + help='Number of threads (default: 8)') + parser.add_argument('--genome-size', type=int, default=2913022398, + help='Effective genome size (default: 2913022398 for hg38)') + parser.add_argument('--output-dir', default=None, + help='Output directory for results') + + # Workflow-specific parameters + parser.add_argument('--input-bam', help='Input/control BAM file') + parser.add_argument('--chip-bam', help='ChIP BAM file') + parser.add_argument('--chip-bams', help='Multiple ChIP BAM files (space-separated)') + parser.add_argument('--rnaseq-bam', help='RNA-seq BAM file') + parser.add_argument('--atac-bam', help='ATAC-seq BAM file') + parser.add_argument('--genes-bed', help='Genes BED file') + parser.add_argument('--peaks-bed', help='Peaks BED file') + + args = parser.parse_args() + + # List workflows + if args.list: + print("\nAvailable deepTools workflows:\n") + for key, value in WORKFLOWS.items(): + print(f" {key}") + print(f" {value['name']}") + print(f" {value['description']}\n") + sys.exit(0) + + # Check if workflow was specified + if not args.workflow: + parser.print_help() + sys.exit(1) + + # Prepare parameters + params = { + 'threads': args.threads, + 'genome_size': args.genome_size, + 'output_dir': args.output_dir or f"{args.workflow}_output", + 'input_bam': args.input_bam, + 'chip_bam': args.chip_bam, + 'chip_bams': args.chip_bams, + 'rnaseq_bam': args.rnaseq_bam, + 'atac_bam': args.atac_bam, + 'genes_bed': args.genes_bed, + 'peaks_bed': args.peaks_bed, + } + + # Generate workflow + if args.workflow == 'chipseq_qc': + message = 
generate_chipseq_qc_workflow(args.output, params) + elif args.workflow == 'chipseq_analysis': + message = generate_chipseq_analysis_workflow(args.output, params) + elif args.workflow == 'rnaseq_coverage': + message = generate_rnaseq_coverage_workflow(args.output, params) + elif args.workflow == 'atacseq': + message = generate_atacseq_workflow(args.output, params) + + print(message) + print(f"\nTo run the workflow:") + print(f" chmod +x {args.output}") + print(f" ./{args.output}") + print(f"\nNote: Edit the script to customize file paths and parameters.") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/diffdock/SKILL.md b/scientific-packages/diffdock/SKILL.md new file mode 100644 index 0000000..2718f33 --- /dev/null +++ b/scientific-packages/diffdock/SKILL.md @@ -0,0 +1,477 @@ +--- +name: diffdock +description: This skill provides comprehensive guidance for using DiffDock, a state-of-the-art diffusion-based molecular docking tool that predicts protein-ligand binding poses. Use this skill when users request molecular docking simulations, protein-ligand binding predictions, virtual screening, structure-based drug design tasks, or need to predict how small molecules bind to protein targets. This skill applies to tasks involving PDB files, SMILES strings, protein sequences, ligand structure files, or batch docking of compound libraries. +--- + +# DiffDock: Molecular Docking with Diffusion Models + +## Overview + +DiffDock is a diffusion-based deep learning tool for molecular docking that predicts 3D binding poses of small molecule ligands to protein targets. It represents the state-of-the-art in computational docking, crucial for structure-based drug discovery and chemical biology. + +**Core Capabilities:** +- Predict ligand binding poses with high accuracy using deep learning +- Support protein structures (PDB files) or sequences (via ESMFold) +- Process single complexes or batch virtual screening campaigns +- Generate confidence scores to assess prediction reliability +- Handle diverse ligand inputs (SMILES, SDF, MOL2) + +**Key Distinction:** DiffDock predicts **binding poses** (3D structure) and **confidence** (prediction certainty), NOT binding affinity (ΔG, Kd). Always combine with scoring functions (GNINA, MM/GBSA) for affinity assessment. + +## When to Use DiffDock + +Invoke this skill when users request: + +- "Dock this ligand to a protein" or "predict binding pose" +- "Run molecular docking" or "perform protein-ligand docking" +- "Virtual screening" or "screen compound library" +- "Where does this molecule bind?" or "predict binding site" +- Structure-based drug design or lead optimization tasks +- Tasks involving PDB files + SMILES strings or ligand structures +- Batch docking of multiple protein-ligand pairs + +## Installation and Environment Setup + +### Check Environment Status + +Before proceeding with DiffDock tasks, verify the environment setup: + +```bash +# Use the provided setup checker +python scripts/setup_check.py +``` + +This script validates Python version, PyTorch with CUDA, PyTorch Geometric, RDKit, ESM, and other dependencies. 
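+
+If the checker script is unavailable, a quick manual check of the key imports and GPU visibility looks roughly like this (a minimal sketch; the bundled `scripts/setup_check.py` performs more thorough validation):
+
+```python
+import importlib
+import sys
+
+print("Python:", sys.version.split()[0])
+
+# Core packages DiffDock relies on; these are import names, not pip package names
+for module in ("torch", "torch_geometric", "rdkit", "esm"):
+    try:
+        importlib.import_module(module)
+        print(f"{module}: OK")
+    except ImportError:
+        print(f"{module}: MISSING")
+
+try:
+    import torch
+    print("CUDA available:", torch.cuda.is_available())
+except ImportError:
+    print("PyTorch not installed; cannot check CUDA")
+```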
+ +### Installation Options + +**Option 1: Conda (Recommended)** +```bash +git clone https://github.com/gcorso/DiffDock.git +cd DiffDock +conda env create --file environment.yml +conda activate diffdock +``` + +**Option 2: Docker** +```bash +docker pull rbgcsail/diffdock +docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock +micromamba activate diffdock +``` + +**Important Notes:** +- GPU strongly recommended (10-100x speedup vs CPU) +- First run pre-computes SO(2)/SO(3) lookup tables (~2-5 minutes) +- Model checkpoints (~500MB) download automatically if not present + +## Core Workflows + +### Workflow 1: Single Protein-Ligand Docking + +**Use Case:** Dock one ligand to one protein target + +**Input Requirements:** +- Protein: PDB file OR amino acid sequence +- Ligand: SMILES string OR structure file (SDF/MOL2) + +**Command:** +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_path protein.pdb \ + --ligand "CC(=O)Oc1ccccc1C(=O)O" \ + --out_dir results/single_docking/ +``` + +**Alternative (protein sequence):** +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_sequence "MSKGEELFTGVVPILVELDGDVNGHKF..." \ + --ligand ligand.sdf \ + --out_dir results/sequence_docking/ +``` + +**Output Structure:** +``` +results/single_docking/ +├── rank_1.sdf # Top-ranked pose +├── rank_2.sdf # Second-ranked pose +├── ... +├── rank_10.sdf # 10th pose (default: 10 samples) +└── confidence_scores.txt +``` + +### Workflow 2: Batch Processing Multiple Complexes + +**Use Case:** Dock multiple ligands to proteins, virtual screening campaigns + +**Step 1: Prepare Batch CSV** + +Use the provided script to create or validate batch input: + +```bash +# Create template +python scripts/prepare_batch_csv.py --create --output batch_input.csv + +# Validate existing CSV +python scripts/prepare_batch_csv.py my_input.csv --validate +``` + +**CSV Format:** +```csv +complex_name,protein_path,ligand_description,protein_sequence +complex1,protein1.pdb,CC(=O)Oc1ccccc1C(=O)O, +complex2,,COc1ccc(C#N)cc1,MSKGEELFT... 
+complex3,protein3.pdb,ligand3.sdf, +``` + +**Required Columns:** +- `complex_name`: Unique identifier +- `protein_path`: PDB file path (leave empty if using sequence) +- `ligand_description`: SMILES string or ligand file path +- `protein_sequence`: Amino acid sequence (leave empty if using PDB) + +**Step 2: Run Batch Docking** + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv batch_input.csv \ + --out_dir results/batch/ \ + --batch_size 10 +``` + +**For Large Virtual Screening (>100 compounds):** + +Pre-compute protein embeddings for faster processing: +```bash +# Pre-compute embeddings +python datasets/esm_embedding_preparation.py \ + --protein_ligand_csv screening_input.csv \ + --out_file protein_embeddings.pt + +# Run with pre-computed embeddings +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv screening_input.csv \ + --esm_embeddings_path protein_embeddings.pt \ + --out_dir results/screening/ +``` + +### Workflow 3: Analyzing Results + +After docking completes, analyze confidence scores and rank predictions: + +```bash +# Analyze all results +python scripts/analyze_results.py results/batch/ + +# Show top 5 per complex +python scripts/analyze_results.py results/batch/ --top 5 + +# Filter by confidence threshold +python scripts/analyze_results.py results/batch/ --threshold 0.0 + +# Export to CSV +python scripts/analyze_results.py results/batch/ --export summary.csv + +# Show top 20 predictions across all complexes +python scripts/analyze_results.py results/batch/ --best 20 +``` + +The analysis script: +- Parses confidence scores from all predictions +- Classifies as High (>0), Moderate (-1.5 to 0), or Low (<-1.5) +- Ranks predictions within and across complexes +- Generates statistical summaries +- Exports results to CSV for downstream analysis + +## Confidence Score Interpretation + +**Understanding Scores:** + +| Score Range | Confidence Level | Interpretation | +|------------|------------------|----------------| +| **> 0** | High | Strong prediction, likely accurate | +| **-1.5 to 0** | Moderate | Reasonable prediction, validate carefully | +| **< -1.5** | Low | Uncertain prediction, requires validation | + +**Critical Notes:** +1. **Confidence ≠ Affinity**: High confidence means model certainty about structure, NOT strong binding +2. **Context Matters**: Adjust expectations for: + - Large ligands (>500 Da): Lower confidence expected + - Multiple protein chains: May decrease confidence + - Novel protein families: May underperform +3. 
**Multiple Samples**: Review top 3-5 predictions, look for consensus + +**For detailed guidance:** Read `references/confidence_and_limitations.md` using the Read tool + +## Parameter Customization + +### Using Custom Configuration + +Create custom configuration for specific use cases: + +```bash +# Copy template +cp assets/custom_inference_config.yaml my_config.yaml + +# Edit parameters (see template for presets) +# Then run with custom config +python -m inference \ + --config my_config.yaml \ + --protein_ligand_csv input.csv \ + --out_dir results/ +``` + +### Key Parameters to Adjust + +**Sampling Density:** +- `samples_per_complex: 10` → Increase to 20-40 for difficult cases +- More samples = better coverage but longer runtime + +**Inference Steps:** +- `inference_steps: 20` → Increase to 25-30 for higher accuracy +- More steps = potentially better quality but slower + +**Temperature Parameters (control diversity):** +- `temp_sampling_tor: 7.04` → Increase for flexible ligands (8-10) +- `temp_sampling_tor: 7.04` → Decrease for rigid ligands (5-6) +- Higher temperature = more diverse poses + +**Presets Available in Template:** +1. High Accuracy: More samples + steps, lower temperature +2. Fast Screening: Fewer samples, faster +3. Flexible Ligands: Increased torsion temperature +4. Rigid Ligands: Decreased torsion temperature + +**For complete parameter reference:** Read `references/parameters_reference.md` using the Read tool + +## Advanced Techniques + +### Ensemble Docking (Protein Flexibility) + +For proteins with known flexibility, dock to multiple conformations: + +```python +# Create ensemble CSV +import pandas as pd + +conformations = ["conf1.pdb", "conf2.pdb", "conf3.pdb"] +ligand = "CC(=O)Oc1ccccc1C(=O)O" + +data = { + "complex_name": [f"ensemble_{i}" for i in range(len(conformations))], + "protein_path": conformations, + "ligand_description": [ligand] * len(conformations), + "protein_sequence": [""] * len(conformations) +} + +pd.DataFrame(data).to_csv("ensemble_input.csv", index=False) +``` + +Run docking with increased sampling: +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv ensemble_input.csv \ + --samples_per_complex 20 \ + --out_dir results/ensemble/ +``` + +### Integration with Scoring Functions + +DiffDock generates poses; combine with other tools for affinity: + +**GNINA (Fast neural network scoring):** +```bash +for pose in results/*.sdf; do + gnina -r protein.pdb -l "$pose" --score_only +done +``` + +**MM/GBSA (More accurate, slower):** +Use AmberTools MMPBSA.py or gmx_MMPBSA after energy minimization + +**Free Energy Calculations (Most accurate):** +Use OpenMM + OpenFE or GROMACS for FEP/TI calculations + +**Recommended Workflow:** +1. DiffDock → Generate poses with confidence scores +2. Visual inspection → Check structural plausibility +3. GNINA or MM/GBSA → Rescore and rank by affinity +4. 
Experimental validation → Biochemical assays + +## Limitations and Scope + +**DiffDock IS Designed For:** +- Small molecule ligands (typically 100-1000 Da) +- Drug-like organic compounds +- Small peptides (<20 residues) +- Single or multi-chain proteins + +**DiffDock IS NOT Designed For:** +- Large biomolecules (protein-protein docking) → Use DiffDock-PP or AlphaFold-Multimer +- Large peptides (>20 residues) → Use alternative methods +- Covalent docking → Use specialized covalent docking tools +- Binding affinity prediction → Combine with scoring functions +- Membrane proteins → Not specifically trained, use with caution + +**For complete limitations:** Read `references/confidence_and_limitations.md` using the Read tool + +## Troubleshooting + +### Common Issues + +**Issue: Low confidence scores across all predictions** +- Cause: Large/unusual ligands, unclear binding site, protein flexibility +- Solution: Increase `samples_per_complex` (20-40), try ensemble docking, validate protein structure + +**Issue: Out of memory errors** +- Cause: GPU memory insufficient for batch size +- Solution: Reduce `--batch_size 2` or process fewer complexes at once + +**Issue: Slow performance** +- Cause: Running on CPU instead of GPU +- Solution: Verify CUDA with `python -c "import torch; print(torch.cuda.is_available())"`, use GPU + +**Issue: Unrealistic binding poses** +- Cause: Poor protein preparation, ligand too large, wrong binding site +- Solution: Check protein for missing residues, remove far waters, consider specifying binding site + +**Issue: "Module not found" errors** +- Cause: Missing dependencies or wrong environment +- Solution: Run `python scripts/setup_check.py` to diagnose + +### Performance Optimization + +**For Best Results:** +1. Use GPU (essential for practical use) +2. Pre-compute ESM embeddings for repeated protein use +3. Batch process multiple complexes together +4. Start with default parameters, then tune if needed +5. Validate protein structures (resolve missing residues) +6. 
Use canonical SMILES for ligands + +## Graphical User Interface + +For interactive use, launch the web interface: + +```bash +python app/main.py +# Navigate to http://localhost:7860 +``` + +Or use the online demo without installation: +- https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web + +## Resources + +### Helper Scripts (`scripts/`) + +**`prepare_batch_csv.py`**: Create and validate batch input CSV files +- Create templates with example entries +- Validate file paths and SMILES strings +- Check for required columns and format issues + +**`analyze_results.py`**: Analyze confidence scores and rank predictions +- Parse results from single or batch runs +- Generate statistical summaries +- Export to CSV for downstream analysis +- Identify top predictions across complexes + +**`setup_check.py`**: Verify DiffDock environment setup +- Check Python version and dependencies +- Verify PyTorch and CUDA availability +- Test RDKit and PyTorch Geometric installation +- Provide installation instructions if needed + +### Reference Documentation (`references/`) + +**`parameters_reference.md`**: Complete parameter documentation +- All command-line options and configuration parameters +- Default values and acceptable ranges +- Temperature parameters for controlling diversity +- Model checkpoint locations and version flags + +Read this file when users need: +- Detailed parameter explanations +- Fine-tuning guidance for specific systems +- Alternative sampling strategies + +**`confidence_and_limitations.md`**: Confidence score interpretation and tool limitations +- Detailed confidence score interpretation +- When to trust predictions +- Scope and limitations of DiffDock +- Integration with complementary tools +- Troubleshooting prediction quality + +Read this file when users need: +- Help interpreting confidence scores +- Understanding when NOT to use DiffDock +- Guidance on combining with other tools +- Validation strategies + +**`workflows_examples.md`**: Comprehensive workflow examples +- Detailed installation instructions +- Step-by-step examples for all workflows +- Advanced integration patterns +- Troubleshooting common issues +- Best practices and optimization tips + +Read this file when users need: +- Complete workflow examples with code +- Integration with GNINA, OpenMM, or other tools +- Virtual screening workflows +- Ensemble docking procedures + +### Assets (`assets/`) + +**`batch_template.csv`**: Template for batch processing +- Pre-formatted CSV with required columns +- Example entries showing different input types +- Ready to customize with actual data + +**`custom_inference_config.yaml`**: Configuration template +- Annotated YAML with all parameters +- Four preset configurations for common use cases +- Detailed comments explaining each parameter +- Ready to customize and use + +## Best Practices + +1. **Always verify environment** with `setup_check.py` before starting large jobs +2. **Validate batch CSVs** with `prepare_batch_csv.py` to catch errors early +3. **Start with defaults** then tune parameters based on system-specific needs +4. **Generate multiple samples** (10-40) for robust predictions +5. **Visual inspection** of top poses before downstream analysis +6. **Combine with scoring** functions for affinity assessment +7. **Use confidence scores** for initial ranking, not final decisions +8. **Pre-compute embeddings** for virtual screening campaigns +9. **Document parameters** used for reproducibility +10. 
**Validate results** experimentally when possible + +## Citations + +When using DiffDock, cite the appropriate papers: + +**DiffDock-L (current default model):** +``` +Stärk et al. (2024) "DiffDock-L: Improving Molecular Docking with Diffusion Models" +arXiv:2402.18396 +``` + +**Original DiffDock:** +``` +Corso et al. (2023) "DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking" +ICLR 2023, arXiv:2210.01776 +``` + +## Additional Resources + +- **GitHub Repository**: https://github.com/gcorso/DiffDock +- **Online Demo**: https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web +- **DiffDock-L Paper**: https://arxiv.org/abs/2402.18396 +- **Original Paper**: https://arxiv.org/abs/2210.01776 diff --git a/scientific-packages/diffdock/assets/batch_template.csv b/scientific-packages/diffdock/assets/batch_template.csv new file mode 100644 index 0000000..fc1990e --- /dev/null +++ b/scientific-packages/diffdock/assets/batch_template.csv @@ -0,0 +1,4 @@ +complex_name,protein_path,ligand_description,protein_sequence +example_1,protein1.pdb,CC(=O)Oc1ccccc1C(=O)O, +example_2,,COc1ccc(C#N)cc1,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK +example_3,protein3.pdb,ligand3.sdf, diff --git a/scientific-packages/diffdock/assets/custom_inference_config.yaml b/scientific-packages/diffdock/assets/custom_inference_config.yaml new file mode 100644 index 0000000..de5f0a9 --- /dev/null +++ b/scientific-packages/diffdock/assets/custom_inference_config.yaml @@ -0,0 +1,90 @@ +# DiffDock Custom Inference Configuration Template +# Copy and modify this file to customize inference parameters + +# Model paths (usually don't need to change these) +model_dir: ./workdir/v1.1/score_model +confidence_model_dir: ./workdir/v1.1/confidence_model +ckpt: best_ema_inference_epoch_model.pt +confidence_ckpt: best_model_epoch75.pt + +# Model version flags +old_score_model: false # Set to true to use original DiffDock instead of DiffDock-L +old_filtering_model: true + +# Inference steps +inference_steps: 20 # Increase for potentially better accuracy (e.g., 25-30) +actual_steps: 19 +no_final_step_noise: true + +# Sampling parameters +samples_per_complex: 10 # Increase for difficult cases (e.g., 20-40) +sigma_schedule: expbeta +initial_noise_std_proportion: 1.46 + +# Temperature controls - Adjust these to balance exploration vs accuracy +# Higher values = more diverse predictions, lower values = more focused predictions + +# Sampling temperatures +temp_sampling_tr: 1.17 # Translation sampling temperature +temp_sampling_rot: 2.06 # Rotation sampling temperature +temp_sampling_tor: 7.04 # Torsion sampling temperature (increase for flexible ligands) + +# Psi angle temperatures +temp_psi_tr: 0.73 +temp_psi_rot: 0.90 +temp_psi_tor: 0.59 + +# Sigma data temperatures +temp_sigma_data_tr: 0.93 +temp_sigma_data_rot: 0.75 +temp_sigma_data_tor: 0.69 + +# Feature flags +no_model: false +no_random: false +ode: false # Set to true to use ODE solver instead of SDE +different_schedules: false +limit_failures: 5 + +# Output settings +# save_visualisation: true # Uncomment to save SDF files + +# ============================================================================ +# Configuration Presets for Common Use Cases +# ============================================================================ + +# PRESET 1: High Accuracy (slower, more thorough) +# 
samples_per_complex: 30 +# inference_steps: 25 +# temp_sampling_tr: 1.0 +# temp_sampling_rot: 1.8 +# temp_sampling_tor: 6.5 + +# PRESET 2: Fast Screening (faster, less thorough) +# samples_per_complex: 5 +# inference_steps: 15 +# temp_sampling_tr: 1.3 +# temp_sampling_rot: 2.2 +# temp_sampling_tor: 7.5 + +# PRESET 3: Flexible Ligands (more conformational diversity) +# samples_per_complex: 20 +# inference_steps: 20 +# temp_sampling_tr: 1.2 +# temp_sampling_rot: 2.1 +# temp_sampling_tor: 8.5 # Increased torsion temperature + +# PRESET 4: Rigid Ligands (more focused predictions) +# samples_per_complex: 10 +# inference_steps: 20 +# temp_sampling_tr: 1.1 +# temp_sampling_rot: 2.0 +# temp_sampling_tor: 6.0 # Decreased torsion temperature + +# ============================================================================ +# Usage Example +# ============================================================================ +# python -m inference \ +# --config custom_inference_config.yaml \ +# --protein_ligand_csv input.csv \ +# --out_dir results/ diff --git a/scientific-packages/diffdock/references/confidence_and_limitations.md b/scientific-packages/diffdock/references/confidence_and_limitations.md new file mode 100644 index 0000000..5610c58 --- /dev/null +++ b/scientific-packages/diffdock/references/confidence_and_limitations.md @@ -0,0 +1,182 @@ +# DiffDock Confidence Scores and Limitations + +This document provides detailed guidance on interpreting DiffDock confidence scores and understanding the tool's limitations. + +## Confidence Score Interpretation + +DiffDock generates a confidence score for each predicted binding pose. This score indicates the model's certainty about the prediction. + +### Score Ranges + +| Score Range | Confidence Level | Interpretation | +|------------|------------------|----------------| +| **> 0** | High confidence | Strong prediction, likely accurate binding pose | +| **-1.5 to 0** | Moderate confidence | Reasonable prediction, may need validation | +| **< -1.5** | Low confidence | Uncertain prediction, requires careful validation | + +### Important Notes on Confidence Scores + +1. **Not Binding Affinity**: Confidence scores reflect prediction certainty, NOT binding affinity strength + - High confidence = model is confident about the structure + - Does NOT indicate strong/weak binding affinity + +2. **Context-Dependent**: Confidence scores should be adjusted based on system complexity: + - **Lower expectations** for: + - Large ligands (>500 Da) + - Protein complexes with many chains + - Unbound protein conformations (may require conformational changes) + - Novel protein families not well-represented in training data + + - **Higher expectations** for: + - Drug-like small molecules (150-500 Da) + - Single-chain proteins or well-defined binding sites + - Proteins similar to those in training data (PDBBind, BindingMOAD) + +3. 
**Multiple Predictions**: DiffDock generates multiple samples per complex (default: 10) + - Review top-ranked predictions (by confidence) + - Consider clustering similar poses + - High-confidence consensus across multiple samples strengthens prediction + +## What DiffDock Predicts + +### ✅ DiffDock DOES Predict +- **Binding poses**: 3D spatial orientation of ligand in protein binding site +- **Confidence scores**: Model's certainty about predictions +- **Multiple conformations**: Various possible binding modes + +### ❌ DiffDock DOES NOT Predict +- **Binding affinity**: Strength of protein-ligand interaction (ΔG, Kd, Ki) +- **Binding kinetics**: On/off rates, residence time +- **ADMET properties**: Absorption, distribution, metabolism, excretion, toxicity +- **Selectivity**: Relative binding to different targets + +## Scope and Limitations + +### Designed For +- **Small molecule docking**: Organic compounds typically 100-1000 Da +- **Protein targets**: Single or multi-chain proteins +- **Small peptides**: Short peptide ligands (< ~20 residues) +- **Small nucleic acids**: Short oligonucleotides + +### NOT Designed For +- **Large biomolecules**: Full protein-protein interactions + - Use DiffDock-PP, AlphaFold-Multimer, or RoseTTAFold2NA instead +- **Large peptides/proteins**: >20 residues as ligands +- **Covalent docking**: Irreversible covalent bond formation +- **Metalloprotein specifics**: May not accurately handle metal coordination +- **Membrane proteins**: Not specifically trained on membrane-embedded proteins + +### Training Data Considerations + +DiffDock was trained on: +- **PDBBind**: Diverse protein-ligand complexes +- **BindingMOAD**: Multi-domain protein structures + +**Implications**: +- Best performance on proteins/ligands similar to training data +- May underperform on: + - Novel protein families + - Unusual ligand chemotypes + - Allosteric sites not well-represented in training data + +## Validation and Complementary Tools + +### Recommended Workflow + +1. **Generate poses with DiffDock** + - Use confidence scores for initial ranking + - Consider multiple high-confidence predictions + +2. **Visual Inspection** + - Examine protein-ligand interactions in molecular viewer + - Check for reasonable: + - Hydrogen bonds + - Hydrophobic interactions + - Steric complementarity + - Electrostatic interactions + +3. **Scoring and Refinement** (choose one or more): + - **GNINA**: Deep learning-based scoring function + - **Molecular mechanics**: Energy minimization and refinement + - **MM/GBSA or MM/PBSA**: Binding free energy estimation + - **Free energy calculations**: FEP or TI for accurate affinity prediction + +4. **Experimental Validation** + - Biochemical assays (IC50, Kd measurements) + - Structural validation (X-ray crystallography, cryo-EM) + +### Tools for Binding Affinity Assessment + +DiffDock should be combined with these tools for affinity prediction: + +- **GNINA**: Fast, accurate scoring function + - Github: github.com/gnina/gnina + +- **AutoDock Vina**: Classical docking and scoring + - Website: vina.scripps.edu + +- **Free Energy Calculations**: + - OpenMM + OpenFE + - GROMACS + ABFE/RBFE protocols + +- **MM/GBSA Tools**: + - MMPBSA.py (AmberTools) + - gmx_MMPBSA + +## Performance Optimization + +### For Best Results + +1. **Protein Preparation**: + - Remove water molecules far from binding site + - Resolve missing residues if possible + - Consider protonation states at physiological pH + +2. 
**Ligand Input**: + - Provide reasonable 3D conformers when using structure files + - Use canonical SMILES for consistent results + - Pre-process with RDKit if needed + +3. **Computational Resources**: + - GPU strongly recommended (10-100x speedup) + - First run pre-computes lookup tables (takes a few minutes) + - Batch processing more efficient than single predictions + +4. **Parameter Tuning**: + - Increase `samples_per_complex` for difficult cases (20-40) + - Adjust temperature parameters for diversity/accuracy trade-off + - Use pre-computed ESM embeddings for repeated predictions + +## Common Issues and Troubleshooting + +### Low Confidence Scores +- **Large/flexible ligands**: Consider splitting into fragments or use alternative methods +- **Multiple binding sites**: May predict multiple locations with distributed confidence +- **Protein flexibility**: Consider using ensemble of protein conformations + +### Unrealistic Predictions +- **Clashes**: May indicate need for protein preparation or refinement +- **Surface binding**: Check if true binding site is blocked or unclear +- **Unusual poses**: Consider increasing samples to explore more conformations + +### Slow Performance +- **Use GPU**: Essential for reasonable runtime +- **Pre-compute embeddings**: Reuse ESM embeddings for same protein +- **Batch processing**: More efficient than sequential individual predictions +- **Reduce samples**: Lower `samples_per_complex` for quick screening + +## Citation and Further Reading + +For methodology details and benchmarking results, see: + +1. **Original DiffDock Paper** (ICLR 2023): + - "DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking" + - Corso et al., arXiv:2210.01776 + +2. **DiffDock-L Paper** (2024): + - Enhanced model with improved generalization + - Stärk et al., arXiv:2402.18396 + +3. **PoseBusters Benchmark**: + - Rigorous docking evaluation framework + - Used for DiffDock validation diff --git a/scientific-packages/diffdock/references/parameters_reference.md b/scientific-packages/diffdock/references/parameters_reference.md new file mode 100644 index 0000000..86f822c --- /dev/null +++ b/scientific-packages/diffdock/references/parameters_reference.md @@ -0,0 +1,163 @@ +# DiffDock Configuration Parameters Reference + +This document provides comprehensive details on all DiffDock configuration parameters and command-line options. 
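Before tuning any of the options documented below, it can help to confirm what the shipped defaults actually are. The following is a minimal sketch, not part of DiffDock itself: it assumes PyYAML is installed, that the command is run from the repository root where `default_inference_args.yaml` lives, and that the key names match those shown in the configuration template (`samples_per_complex`, `inference_steps`, `temp_sampling_tor`).

```python
# Hypothetical helper: print a few defaults from the inference config.
# Assumes PyYAML is available and default_inference_args.yaml is in the
# current working directory (the DiffDock repository root).
import yaml

with open("default_inference_args.yaml") as f:
    defaults = yaml.safe_load(f)

# Key names follow the configuration template; adjust if your config differs.
for key in ("samples_per_complex", "inference_steps", "temp_sampling_tor"):
    print(f"{key}: {defaults.get(key)}")
```

Comparing these values against the tables below makes it easier to see which parameters a custom configuration actually overrides.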
+ +## Model & Checkpoint Settings + +### Model Paths +- **`--model_dir`**: Directory containing the score model checkpoint + - Default: `./workdir/v1.1/score_model` + - DiffDock-L model (current default) + +- **`--confidence_model_dir`**: Directory containing the confidence model checkpoint + - Default: `./workdir/v1.1/confidence_model` + +- **`--ckpt`**: Name of the score model checkpoint file + - Default: `best_ema_inference_epoch_model.pt` + +- **`--confidence_ckpt`**: Name of the confidence model checkpoint file + - Default: `best_model_epoch75.pt` + +### Model Version Flags +- **`--old_score_model`**: Use original DiffDock model instead of DiffDock-L + - Default: `false` (uses DiffDock-L) + +- **`--old_filtering_model`**: Use legacy confidence filtering approach + - Default: `true` + +## Input/Output Options + +### Input Specification +- **`--protein_path`**: Path to protein PDB file + - Example: `--protein_path protein.pdb` + - Alternative to `--protein_sequence` + +- **`--protein_sequence`**: Amino acid sequence for ESMFold folding + - Automatically generates protein structure from sequence + - Alternative to `--protein_path` + +- **`--ligand`**: Ligand specification (SMILES string or file path) + - SMILES string: `--ligand "COc(cc1)ccc1C#N"` + - File path: `--ligand ligand.sdf` or `.mol2` + +- **`--protein_ligand_csv`**: CSV file for batch processing + - Required columns: `complex_name`, `protein_path`, `ligand_description`, `protein_sequence` + - Example: `--protein_ligand_csv data/protein_ligand_example.csv` + +### Output Control +- **`--out_dir`**: Output directory for predictions + - Example: `--out_dir results/user_predictions/` + +- **`--save_visualisation`**: Export predicted molecules as SDF files + - Enables visualization of results + +## Inference Parameters + +### Diffusion Steps +- **`--inference_steps`**: Number of planned inference iterations + - Default: `20` + - Higher values may improve accuracy but increase runtime + +- **`--actual_steps`**: Actual diffusion steps executed + - Default: `19` + +- **`--no_final_step_noise`**: Omit noise at the final diffusion step + - Default: `true` + +### Sampling Settings +- **`--samples_per_complex`**: Number of samples to generate per complex + - Default: `10` + - More samples provide better coverage but increase computation + +- **`--sigma_schedule`**: Noise schedule type + - Default: `expbeta` (exponential-beta) + +- **`--initial_noise_std_proportion`**: Initial noise standard deviation scaling + - Default: `1.46` + +### Temperature Parameters + +#### Sampling Temperatures (Controls diversity of predictions) +- **`--temp_sampling_tr`**: Translation sampling temperature + - Default: `1.17` + +- **`--temp_sampling_rot`**: Rotation sampling temperature + - Default: `2.06` + +- **`--temp_sampling_tor`**: Torsion sampling temperature + - Default: `7.04` + +#### Psi Angle Temperatures +- **`--temp_psi_tr`**: Translation psi temperature + - Default: `0.73` + +- **`--temp_psi_rot`**: Rotation psi temperature + - Default: `0.90` + +- **`--temp_psi_tor`**: Torsion psi temperature + - Default: `0.59` + +#### Sigma Data Temperatures +- **`--temp_sigma_data_tr`**: Translation data distribution scaling + - Default: `0.93` + +- **`--temp_sigma_data_rot`**: Rotation data distribution scaling + - Default: `0.75` + +- **`--temp_sigma_data_tor`**: Torsion data distribution scaling + - Default: `0.69` + +## Processing Options + +### Performance +- **`--batch_size`**: Processing batch size + - Default: `10` + - Larger values increase throughput 
but require more memory + +- **`--tqdm`**: Enable progress bar visualization + - Useful for monitoring long-running jobs + +### Protein Structure +- **`--chain_cutoff`**: Maximum number of protein chains to process + - Example: `--chain_cutoff 10` + - Useful for large multi-chain complexes + +- **`--esm_embeddings_path`**: Path to pre-computed ESM2 protein embeddings + - Speeds up inference by reusing embeddings + - Optional optimization + +### Dataset Options +- **`--split`**: Dataset split to use (train/test/val) + - Used for evaluation on standard benchmarks + +## Advanced Flags + +### Debugging & Testing +- **`--no_model`**: Disable model inference (debugging) + - Default: `false` + +- **`--no_random`**: Disable randomization + - Default: `false` + - Useful for reproducibility testing + +### Alternative Sampling +- **`--ode`**: Use ODE solver instead of SDE + - Default: `false` + - Alternative sampling approach + +- **`--different_schedules`**: Use different noise schedules per component + - Default: `false` + +### Error Handling +- **`--limit_failures`**: Maximum allowed failures before stopping + - Default: `5` + +## Configuration File + +All parameters can be specified in a YAML configuration file (typically `default_inference_args.yaml`) or overridden via command line: + +```bash +python -m inference --config default_inference_args.yaml --samples_per_complex 20 +``` + +Command-line arguments take precedence over configuration file values. diff --git a/scientific-packages/diffdock/references/workflows_examples.md b/scientific-packages/diffdock/references/workflows_examples.md new file mode 100644 index 0000000..abedd68 --- /dev/null +++ b/scientific-packages/diffdock/references/workflows_examples.md @@ -0,0 +1,392 @@ +# DiffDock Workflows and Examples + +This document provides practical workflows and usage examples for common DiffDock tasks. + +## Installation and Setup + +### Conda Installation (Recommended) + +```bash +# Clone repository +git clone https://github.com/gcorso/DiffDock.git +cd DiffDock + +# Create conda environment +conda env create --file environment.yml +conda activate diffdock +``` + +### Docker Installation + +```bash +# Pull Docker image +docker pull rbgcsail/diffdock + +# Run container with GPU support +docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock + +# Inside container, activate environment +micromamba activate diffdock +``` + +### First Run +The first execution pre-computes SO(2) and SO(3) lookup tables, taking a few minutes. Subsequent runs start immediately. + +## Workflow 1: Single Protein-Ligand Docking + +### Using PDB File and SMILES String + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_path examples/protein.pdb \ + --ligand "COc1ccc(C(=O)Nc2ccccc2)cc1" \ + --out_dir results/single_docking/ +``` + +**Output Structure**: +``` +results/single_docking/ +├── index_0_rank_1.sdf # Top-ranked prediction +├── index_0_rank_2.sdf # Second-ranked prediction +├── ... 
+├── index_0_rank_10.sdf # 10th prediction (if samples_per_complex=10) +└── confidence_scores.txt # Scores for all predictions +``` + +### Using Ligand Structure File + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_path protein.pdb \ + --ligand ligand.sdf \ + --out_dir results/ligand_file/ +``` + +**Supported ligand formats**: SDF, MOL2, or any format readable by RDKit + +## Workflow 2: Protein Sequence to Structure Docking + +### Using ESMFold for Protein Folding + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_sequence "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK" \ + --ligand "CC(C)Cc1ccc(cc1)C(C)C(=O)O" \ + --out_dir results/sequence_docking/ +``` + +**Use Cases**: +- Protein structure not available in PDB +- Modeling mutations or variants +- De novo protein design validation + +**Note**: ESMFold folding adds computation time (30s-5min depending on sequence length) + +## Workflow 3: Batch Processing Multiple Complexes + +### Prepare CSV File + +Create `complexes.csv` with required columns: + +```csv +complex_name,protein_path,ligand_description,protein_sequence +complex1,proteins/protein1.pdb,CC(=O)Oc1ccccc1C(=O)O, +complex2,,COc1ccc(C#N)cc1,MSKGEELFTGVVPILVELDGDVNGHKF... +complex3,proteins/protein3.pdb,ligands/ligand3.sdf, +``` + +**Column Descriptions**: +- `complex_name`: Unique identifier for the complex +- `protein_path`: Path to PDB file (leave empty if using sequence) +- `ligand_description`: SMILES string or path to ligand file +- `protein_sequence`: Amino acid sequence (leave empty if using PDB) + +### Run Batch Docking + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv complexes.csv \ + --out_dir results/batch_predictions/ \ + --batch_size 10 +``` + +**Output Structure**: +``` +results/batch_predictions/ +├── complex1/ +│ ├── rank_1.sdf +│ ├── rank_2.sdf +│ └── ... +├── complex2/ +│ ├── rank_1.sdf +│ └── ... +└── complex3/ + └── ... 
+``` + +## Workflow 4: High-Throughput Virtual Screening + +### Setup for Screening Large Ligand Libraries + +```python +# generate_screening_csv.py +import pandas as pd + +# Load ligand library +ligands = pd.read_csv("ligand_library.csv") # Contains SMILES + +# Create DiffDock input +screening_data = { + "complex_name": [f"screen_{i}" for i in range(len(ligands))], + "protein_path": ["target_protein.pdb"] * len(ligands), + "ligand_description": ligands["smiles"].tolist(), + "protein_sequence": [""] * len(ligands) +} + +df = pd.DataFrame(screening_data) +df.to_csv("screening_input.csv", index=False) +``` + +### Run Screening + +```bash +# Pre-compute ESM embeddings for faster screening +python datasets/esm_embedding_preparation.py \ + --protein_ligand_csv screening_input.csv \ + --out_file protein_embeddings.pt + +# Run docking with pre-computed embeddings +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv screening_input.csv \ + --esm_embeddings_path protein_embeddings.pt \ + --out_dir results/virtual_screening/ \ + --batch_size 32 +``` + +### Post-Processing: Extract Top Hits + +```python +# analyze_screening_results.py +import os +import pandas as pd + +results = [] +results_dir = "results/virtual_screening/" + +for complex_dir in os.listdir(results_dir): + confidence_file = os.path.join(results_dir, complex_dir, "confidence_scores.txt") + if os.path.exists(confidence_file): + with open(confidence_file) as f: + scores = [float(line.strip()) for line in f] + top_score = max(scores) + results.append({"complex": complex_dir, "top_confidence": top_score}) + +# Sort by confidence +df = pd.DataFrame(results) +df_sorted = df.sort_values("top_confidence", ascending=False) + +# Get top 100 hits +top_hits = df_sorted.head(100) +top_hits.to_csv("top_hits.csv", index=False) +``` + +## Workflow 5: Ensemble Docking with Protein Flexibility + +### Prepare Protein Ensemble + +```python +# For proteins with known flexibility, use multiple conformations +# Example: Using MD snapshots or crystal structures + +# create_ensemble_csv.py +import pandas as pd + +conformations = [ + "protein_conf1.pdb", + "protein_conf2.pdb", + "protein_conf3.pdb", + "protein_conf4.pdb" +] + +ligand = "CC(C)Cc1ccc(cc1)C(C)C(=O)O" + +data = { + "complex_name": [f"ensemble_{i}" for i in range(len(conformations))], + "protein_path": conformations, + "ligand_description": [ligand] * len(conformations), + "protein_sequence": [""] * len(conformations) +} + +pd.DataFrame(data).to_csv("ensemble_input.csv", index=False) +``` + +### Run Ensemble Docking + +```bash +python -m inference \ + --config default_inference_args.yaml \ + --protein_ligand_csv ensemble_input.csv \ + --out_dir results/ensemble_docking/ \ + --samples_per_complex 20 # More samples per conformation +``` + +## Workflow 6: Integration with Downstream Analysis + +### Example: DiffDock + GNINA Rescoring + +```bash +# 1. Run DiffDock +python -m inference \ + --config default_inference_args.yaml \ + --protein_path protein.pdb \ + --ligand "CC(=O)OC1=CC=CC=C1C(=O)O" \ + --out_dir results/diffdock_poses/ \ + --save_visualisation + +# 2. 
Rescore with GNINA +for pose in results/diffdock_poses/*.sdf; do + gnina -r protein.pdb -l "$pose" --score_only -o "${pose%.sdf}_gnina.sdf" +done +``` + +### Example: DiffDock + OpenMM Energy Minimization + +```python +# minimize_poses.py +from openmm import app, LangevinIntegrator, Platform +from openmm.app import ForceField, Modeller, PDBFile +from rdkit import Chem +import os + +# Load protein +protein = PDBFile('protein.pdb') +forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml') + +# Process each DiffDock pose +pose_dir = 'results/diffdock_poses/' +for pose_file in os.listdir(pose_dir): + if pose_file.endswith('.sdf'): + # Load ligand + mol = Chem.SDMolSupplier(os.path.join(pose_dir, pose_file))[0] + + # Combine protein + ligand + modeller = Modeller(protein.topology, protein.positions) + # ... add ligand to modeller ... + + # Create system and minimize + system = forcefield.createSystem(modeller.topology) + integrator = LangevinIntegrator(300, 1.0, 0.002) + simulation = app.Simulation(modeller.topology, system, integrator) + simulation.minimizeEnergy(maxIterations=1000) + + # Save minimized structure + positions = simulation.context.getState(getPositions=True).getPositions() + PDBFile.writeFile(simulation.topology, positions, + open(f"minimized_{pose_file}.pdb", 'w')) +``` + +## Workflow 7: Using the Graphical Interface + +### Launch Web Interface + +```bash +python app/main.py +``` + +### Access Interface +Navigate to `http://localhost:7860` in web browser + +### Features +- Upload protein PDB or enter sequence +- Input ligand SMILES or upload structure +- Adjust inference parameters via GUI +- Visualize results interactively +- Download predictions directly + +### Online Alternative +Use the Hugging Face Spaces demo without local installation: +- URL: https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web + +## Advanced Configuration + +### Custom Inference Settings + +Create custom YAML configuration: + +```yaml +# custom_inference.yaml +# Model settings +model_dir: ./workdir/v1.1/score_model +confidence_model_dir: ./workdir/v1.1/confidence_model + +# Sampling parameters +samples_per_complex: 20 # More samples for better coverage +inference_steps: 25 # More steps for accuracy + +# Temperature adjustments (increase for more diversity) +temp_sampling_tr: 1.3 +temp_sampling_rot: 2.2 +temp_sampling_tor: 7.5 + +# Output +save_visualisation: true +``` + +Use custom configuration: + +```bash +python -m inference \ + --config custom_inference.yaml \ + --protein_path protein.pdb \ + --ligand "CC(=O)OC1=CC=CC=C1C(=O)O" \ + --out_dir results/custom_config/ +``` + +## Troubleshooting Common Issues + +### Issue: Out of Memory Errors + +**Solution**: Reduce batch size +```bash +python -m inference ... --batch_size 2 +``` + +### Issue: Slow Performance + +**Solution**: Ensure GPU usage +```python +import torch +print(torch.cuda.is_available()) # Should return True +``` + +### Issue: Poor Predictions for Large Ligands + +**Solution**: Increase sampling diversity +```bash +python -m inference ... --samples_per_complex 40 --temp_sampling_tor 9.0 +``` + +### Issue: Protein with Many Chains + +**Solution**: Limit chains or isolate binding site +```bash +python -m inference ... --chain_cutoff 4 +``` + +Or pre-process PDB to include only relevant chains. + +## Best Practices Summary + +1. **Start Simple**: Test with single complex before batch processing +2. **GPU Essential**: Use GPU for reasonable performance +3. 
**Multiple Samples**: Generate 10-40 samples for robust predictions +4. **Validate Results**: Use molecular visualization and complementary scoring +5. **Consider Confidence**: Use confidence scores for initial ranking, not final decisions +6. **Iterate Parameters**: Adjust temperature/steps for specific systems +7. **Pre-compute Embeddings**: For repeated use of same protein +8. **Combine Tools**: Integrate with scoring functions and energy minimization diff --git a/scientific-packages/diffdock/scripts/analyze_results.py b/scientific-packages/diffdock/scripts/analyze_results.py new file mode 100755 index 0000000..4eec2dd --- /dev/null +++ b/scientific-packages/diffdock/scripts/analyze_results.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +""" +DiffDock Results Analysis Script + +This script analyzes DiffDock prediction results, extracting confidence scores, +ranking predictions, and generating summary reports. + +Usage: + python analyze_results.py results/output_dir/ + python analyze_results.py results/ --top 50 --threshold 0.0 + python analyze_results.py results/ --export summary.csv +""" + +import argparse +import os +import sys +import json +from pathlib import Path +from collections import defaultdict +import re + + +def parse_confidence_scores(results_dir): + """ + Parse confidence scores from DiffDock output directory. + + Args: + results_dir: Path to DiffDock results directory + + Returns: + dict: Dictionary mapping complex names to their predictions and scores + """ + results = {} + results_path = Path(results_dir) + + # Check if this is a single complex or batch results + sdf_files = list(results_path.glob("*.sdf")) + + if sdf_files: + # Single complex output + results['single_complex'] = parse_single_complex(results_path) + else: + # Batch output - multiple subdirectories + for subdir in results_path.iterdir(): + if subdir.is_dir(): + complex_results = parse_single_complex(subdir) + if complex_results: + results[subdir.name] = complex_results + + return results + + +def parse_single_complex(complex_dir): + """Parse results for a single complex.""" + predictions = [] + + # Look for SDF files with rank information + for sdf_file in complex_dir.glob("*.sdf"): + filename = sdf_file.name + + # Extract rank from filename (e.g., "rank_1.sdf" or "index_0_rank_1.sdf") + rank_match = re.search(r'rank_(\d+)', filename) + if rank_match: + rank = int(rank_match.group(1)) + + # Try to extract confidence score from filename or separate file + confidence = extract_confidence_score(sdf_file, complex_dir) + + predictions.append({ + 'rank': rank, + 'file': sdf_file.name, + 'path': str(sdf_file), + 'confidence': confidence + }) + + # Sort by rank + predictions.sort(key=lambda x: x['rank']) + + return {'predictions': predictions} if predictions else None + + +def extract_confidence_score(sdf_file, complex_dir): + """ + Extract confidence score for a prediction. + + Tries multiple methods: + 1. Read from confidence_scores.txt file + 2. Parse from SDF file properties + 3. 
Extract from filename if present + """ + # Method 1: confidence_scores.txt + confidence_file = complex_dir / "confidence_scores.txt" + if confidence_file.exists(): + try: + with open(confidence_file) as f: + lines = f.readlines() + # Extract rank from filename + rank_match = re.search(r'rank_(\d+)', sdf_file.name) + if rank_match: + rank = int(rank_match.group(1)) + if rank <= len(lines): + return float(lines[rank - 1].strip()) + except Exception: + pass + + # Method 2: Parse from SDF file + try: + with open(sdf_file) as f: + content = f.read() + # Look for confidence score in SDF properties + conf_match = re.search(r'confidence[:\s]+(-?\d+\.?\d*)', content, re.IGNORECASE) + if conf_match: + return float(conf_match.group(1)) + except Exception: + pass + + # Method 3: Filename (e.g., "rank_1_conf_0.95.sdf") + conf_match = re.search(r'conf_(-?\d+\.?\d*)', sdf_file.name) + if conf_match: + return float(conf_match.group(1)) + + return None + + +def classify_confidence(score): + """Classify confidence score into categories.""" + if score is None: + return "Unknown" + elif score > 0: + return "High" + elif score > -1.5: + return "Moderate" + else: + return "Low" + + +def print_summary(results, top_n=None, min_confidence=None): + """Print a formatted summary of results.""" + + print("\n" + "="*80) + print("DiffDock Results Summary") + print("="*80) + + all_predictions = [] + + for complex_name, data in results.items(): + predictions = data.get('predictions', []) + + print(f"\n{complex_name}") + print("-" * 80) + + if not predictions: + print(" No predictions found") + continue + + # Filter by confidence if specified + filtered_predictions = predictions + if min_confidence is not None: + filtered_predictions = [p for p in predictions if p['confidence'] is not None and p['confidence'] >= min_confidence] + + # Limit to top N if specified + if top_n is not None: + filtered_predictions = filtered_predictions[:top_n] + + for pred in filtered_predictions: + confidence = pred['confidence'] + confidence_class = classify_confidence(confidence) + + conf_str = f"{confidence:>7.3f}" if confidence is not None else " N/A" + print(f" Rank {pred['rank']:2d}: Confidence = {conf_str} ({confidence_class:8s}) | {pred['file']}") + + # Add to all predictions for overall statistics + if confidence is not None: + all_predictions.append((complex_name, pred['rank'], confidence)) + + # Show statistics for this complex + if filtered_predictions and any(p['confidence'] is not None for p in filtered_predictions): + confidences = [p['confidence'] for p in filtered_predictions if p['confidence'] is not None] + print(f"\n Statistics: {len(filtered_predictions)} predictions") + print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}") + print(f" Max confidence: {max(confidences):.3f}") + print(f" Min confidence: {min(confidences):.3f}") + + # Overall statistics + if all_predictions: + print("\n" + "="*80) + print("Overall Statistics") + print("="*80) + + confidences = [conf for _, _, conf in all_predictions] + print(f" Total predictions: {len(all_predictions)}") + print(f" Total complexes: {len(results)}") + print(f" Mean confidence: {sum(confidences)/len(confidences):.3f}") + print(f" Max confidence: {max(confidences):.3f}") + print(f" Min confidence: {min(confidences):.3f}") + + # Confidence distribution + high = sum(1 for c in confidences if c > 0) + moderate = sum(1 for c in confidences if -1.5 < c <= 0) + low = sum(1 for c in confidences if c <= -1.5) + + print(f"\n Confidence distribution:") + print(f" High (> 0): 
{high:4d} ({100*high/len(confidences):5.1f}%)") + print(f" Moderate (-1.5 to 0): {moderate:4d} ({100*moderate/len(confidences):5.1f}%)") + print(f" Low (< -1.5): {low:4d} ({100*low/len(confidences):5.1f}%)") + + print("\n" + "="*80) + + +def export_to_csv(results, output_path): + """Export results to CSV file.""" + import csv + + with open(output_path, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['complex_name', 'rank', 'confidence', 'confidence_class', 'file_path']) + + for complex_name, data in results.items(): + predictions = data.get('predictions', []) + for pred in predictions: + confidence = pred['confidence'] + confidence_class = classify_confidence(confidence) + conf_value = confidence if confidence is not None else '' + + writer.writerow([ + complex_name, + pred['rank'], + conf_value, + confidence_class, + pred['path'] + ]) + + print(f"✓ Exported results to: {output_path}") + + +def get_top_predictions(results, n=10, sort_by='confidence'): + """Get top N predictions across all complexes.""" + all_predictions = [] + + for complex_name, data in results.items(): + predictions = data.get('predictions', []) + for pred in predictions: + if pred['confidence'] is not None: + all_predictions.append({ + 'complex': complex_name, + **pred + }) + + # Sort by confidence (descending) + all_predictions.sort(key=lambda x: x['confidence'], reverse=True) + + return all_predictions[:n] + + +def print_top_predictions(results, n=10): + """Print top N predictions across all complexes.""" + top_preds = get_top_predictions(results, n) + + print("\n" + "="*80) + print(f"Top {n} Predictions Across All Complexes") + print("="*80) + + for i, pred in enumerate(top_preds, 1): + confidence_class = classify_confidence(pred['confidence']) + print(f"{i:2d}. 
{pred['complex']:30s} | Rank {pred['rank']:2d} | " + f"Confidence: {pred['confidence']:7.3f} ({confidence_class})") + + print("="*80) + + +def main(): + parser = argparse.ArgumentParser( + description='Analyze DiffDock prediction results', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Analyze all results in directory + python analyze_results.py results/output_dir/ + + # Show only top 5 predictions per complex + python analyze_results.py results/ --top 5 + + # Filter by confidence threshold + python analyze_results.py results/ --threshold 0.0 + + # Export to CSV + python analyze_results.py results/ --export summary.csv + + # Show top 20 predictions across all complexes + python analyze_results.py results/ --best 20 + """ + ) + + parser.add_argument('results_dir', help='Path to DiffDock results directory') + parser.add_argument('--top', '-t', type=int, + help='Show only top N predictions per complex') + parser.add_argument('--threshold', type=float, + help='Minimum confidence threshold') + parser.add_argument('--export', '-e', metavar='FILE', + help='Export results to CSV file') + parser.add_argument('--best', '-b', type=int, metavar='N', + help='Show top N predictions across all complexes') + + args = parser.parse_args() + + # Validate results directory + if not os.path.exists(args.results_dir): + print(f"Error: Results directory not found: {args.results_dir}") + return 1 + + # Parse results + print(f"Analyzing results in: {args.results_dir}") + results = parse_confidence_scores(args.results_dir) + + if not results: + print("No DiffDock results found in directory") + return 1 + + # Print summary + print_summary(results, top_n=args.top, min_confidence=args.threshold) + + # Print top predictions across all complexes + if args.best: + print_top_predictions(results, args.best) + + # Export to CSV if requested + if args.export: + export_to_csv(results, args.export) + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/diffdock/scripts/prepare_batch_csv.py b/scientific-packages/diffdock/scripts/prepare_batch_csv.py new file mode 100755 index 0000000..24f4ef4 --- /dev/null +++ b/scientific-packages/diffdock/scripts/prepare_batch_csv.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +""" +DiffDock Batch CSV Preparation and Validation Script + +This script helps prepare and validate CSV files for DiffDock batch processing. +It checks for required columns, validates file paths, and ensures SMILES strings +are properly formatted. + +Usage: + python prepare_batch_csv.py input.csv --validate + python prepare_batch_csv.py --create --output batch_input.csv +""" + +import argparse +import os +import sys +import pandas as pd +from pathlib import Path + +try: + from rdkit import Chem + from rdkit import RDLogger + RDLogger.DisableLog('rdApp.*') + RDKIT_AVAILABLE = True +except ImportError: + RDKIT_AVAILABLE = False + print("Warning: RDKit not available. 
SMILES validation will be skipped.") + + +def validate_smiles(smiles_string): + """Validate a SMILES string using RDKit.""" + if not RDKIT_AVAILABLE: + return True, "RDKit not available for validation" + + try: + mol = Chem.MolFromSmiles(smiles_string) + if mol is None: + return False, "Invalid SMILES structure" + return True, "Valid SMILES" + except Exception as e: + return False, str(e) + + +def validate_file_path(file_path, base_dir=None): + """Validate that a file path exists.""" + if pd.isna(file_path) or file_path == "": + return True, "Empty (will use protein_sequence)" + + # Handle relative paths + if base_dir: + full_path = Path(base_dir) / file_path + else: + full_path = Path(file_path) + + if full_path.exists(): + return True, f"File exists: {full_path}" + else: + return False, f"File not found: {full_path}" + + +def validate_csv(csv_path, base_dir=None): + """ + Validate a DiffDock batch input CSV file. + + Args: + csv_path: Path to CSV file + base_dir: Base directory for relative paths (default: CSV directory) + + Returns: + bool: True if validation passes + list: List of validation messages + """ + messages = [] + valid = True + + # Read CSV + try: + df = pd.read_csv(csv_path) + messages.append(f"✓ Successfully read CSV with {len(df)} rows") + except Exception as e: + messages.append(f"✗ Error reading CSV: {e}") + return False, messages + + # Check required columns + required_cols = ['complex_name', 'protein_path', 'ligand_description', 'protein_sequence'] + missing_cols = [col for col in required_cols if col not in df.columns] + + if missing_cols: + messages.append(f"✗ Missing required columns: {', '.join(missing_cols)}") + valid = False + else: + messages.append("✓ All required columns present") + + # Set base directory + if base_dir is None: + base_dir = Path(csv_path).parent + + # Validate each row + for idx, row in df.iterrows(): + row_msgs = [] + + # Check complex name + if pd.isna(row['complex_name']) or row['complex_name'] == "": + row_msgs.append("Missing complex_name") + valid = False + + # Check that either protein_path or protein_sequence is provided + has_protein_path = not pd.isna(row['protein_path']) and row['protein_path'] != "" + has_protein_seq = not pd.isna(row['protein_sequence']) and row['protein_sequence'] != "" + + if not has_protein_path and not has_protein_seq: + row_msgs.append("Must provide either protein_path or protein_sequence") + valid = False + elif has_protein_path and has_protein_seq: + row_msgs.append("Warning: Both protein_path and protein_sequence provided, will use protein_path") + + # Validate protein path if provided + if has_protein_path: + file_valid, msg = validate_file_path(row['protein_path'], base_dir) + if not file_valid: + row_msgs.append(f"Protein file issue: {msg}") + valid = False + + # Validate ligand description + if pd.isna(row['ligand_description']) or row['ligand_description'] == "": + row_msgs.append("Missing ligand_description") + valid = False + else: + ligand_desc = row['ligand_description'] + # Check if it's a file path or SMILES + if os.path.exists(ligand_desc) or "/" in ligand_desc or "\\" in ligand_desc: + # Likely a file path + file_valid, msg = validate_file_path(ligand_desc, base_dir) + if not file_valid: + row_msgs.append(f"Ligand file issue: {msg}") + valid = False + else: + # Likely a SMILES string + smiles_valid, msg = validate_smiles(ligand_desc) + if not smiles_valid: + row_msgs.append(f"SMILES issue: {msg}") + valid = False + + if row_msgs: + messages.append(f"\nRow {idx + 1} ({row.get('complex_name', 
'unnamed')}):") + for msg in row_msgs: + messages.append(f" - {msg}") + + # Summary + messages.append(f"\n{'='*60}") + if valid: + messages.append("✓ CSV validation PASSED - ready for DiffDock") + else: + messages.append("✗ CSV validation FAILED - please fix issues above") + + return valid, messages + + +def create_template_csv(output_path, num_examples=3): + """Create a template CSV file with example entries.""" + + examples = { + 'complex_name': ['example1', 'example2', 'example3'][:num_examples], + 'protein_path': ['protein1.pdb', '', 'protein3.pdb'][:num_examples], + 'ligand_description': [ + 'CC(=O)Oc1ccccc1C(=O)O', # Aspirin SMILES + 'COc1ccc(C#N)cc1', # Example SMILES + 'ligand.sdf' # Example file path + ][:num_examples], + 'protein_sequence': [ + '', # Empty - using PDB file + 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK', # GFP sequence + '' # Empty - using PDB file + ][:num_examples] + } + + df = pd.DataFrame(examples) + df.to_csv(output_path, index=False) + + return df + + +def main(): + parser = argparse.ArgumentParser( + description='Prepare and validate DiffDock batch CSV files', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Validate existing CSV + python prepare_batch_csv.py input.csv --validate + + # Create template CSV + python prepare_batch_csv.py --create --output batch_template.csv + + # Create template with 5 example rows + python prepare_batch_csv.py --create --output template.csv --num-examples 5 + + # Validate with custom base directory for relative paths + python prepare_batch_csv.py input.csv --validate --base-dir /path/to/data/ + """ + ) + + parser.add_argument('csv_file', nargs='?', help='CSV file to validate') + parser.add_argument('--validate', action='store_true', + help='Validate the CSV file') + parser.add_argument('--create', action='store_true', + help='Create a template CSV file') + parser.add_argument('--output', '-o', help='Output path for template CSV') + parser.add_argument('--num-examples', type=int, default=3, + help='Number of example rows in template (default: 3)') + parser.add_argument('--base-dir', help='Base directory for relative file paths') + + args = parser.parse_args() + + # Create template + if args.create: + output_path = args.output or 'diffdock_batch_template.csv' + df = create_template_csv(output_path, args.num_examples) + print(f"✓ Created template CSV: {output_path}") + print(f"\nTemplate contents:") + print(df.to_string(index=False)) + print(f"\nEdit this file with your protein-ligand pairs and run with:") + print(f" python -m inference --config default_inference_args.yaml \\") + print(f" --protein_ligand_csv {output_path} --out_dir results/") + return 0 + + # Validate CSV + if args.validate or args.csv_file: + if not args.csv_file: + print("Error: CSV file required for validation") + parser.print_help() + return 1 + + if not os.path.exists(args.csv_file): + print(f"Error: CSV file not found: {args.csv_file}") + return 1 + + print(f"Validating: {args.csv_file}") + print("="*60) + + valid, messages = validate_csv(args.csv_file, args.base_dir) + + for msg in messages: + print(msg) + + return 0 if valid else 1 + + # No action specified + parser.print_help() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/diffdock/scripts/setup_check.py 
b/scientific-packages/diffdock/scripts/setup_check.py new file mode 100755 index 0000000..950c36a --- /dev/null +++ b/scientific-packages/diffdock/scripts/setup_check.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +""" +DiffDock Environment Setup Checker + +This script verifies that the DiffDock environment is properly configured +and all dependencies are available. + +Usage: + python setup_check.py + python setup_check.py --verbose +""" + +import argparse +import sys +import os +from pathlib import Path + + +def check_python_version(): + """Check Python version.""" + import sys + version = sys.version_info + + print("Checking Python version...") + if version.major == 3 and version.minor >= 8: + print(f" ✓ Python {version.major}.{version.minor}.{version.micro}") + return True + else: + print(f" ✗ Python {version.major}.{version.minor}.{version.micro} " + f"(requires Python 3.8 or higher)") + return False + + +def check_package(package_name, import_name=None, version_attr='__version__'): + """Check if a Python package is installed.""" + if import_name is None: + import_name = package_name + + try: + module = __import__(import_name) + version = getattr(module, version_attr, 'unknown') + print(f" ✓ {package_name:20s} (version: {version})") + return True + except ImportError: + print(f" ✗ {package_name:20s} (not installed)") + return False + + +def check_pytorch(): + """Check PyTorch installation and CUDA availability.""" + print("\nChecking PyTorch...") + try: + import torch + print(f" ✓ PyTorch version: {torch.__version__}") + + # Check CUDA + if torch.cuda.is_available(): + print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}") + print(f" - CUDA version: {torch.version.cuda}") + print(f" - Number of GPUs: {torch.cuda.device_count()}") + return True, True + else: + print(f" ⚠ CUDA not available (will run on CPU)") + return True, False + except ImportError: + print(f" ✗ PyTorch not installed") + return False, False + + +def check_pytorch_geometric(): + """Check PyTorch Geometric installation.""" + print("\nChecking PyTorch Geometric...") + packages = [ + ('torch-geometric', 'torch_geometric'), + ('torch-scatter', 'torch_scatter'), + ('torch-sparse', 'torch_sparse'), + ('torch-cluster', 'torch_cluster'), + ] + + all_ok = True + for pkg_name, import_name in packages: + if not check_package(pkg_name, import_name): + all_ok = False + + return all_ok + + +def check_core_dependencies(): + """Check core DiffDock dependencies.""" + print("\nChecking core dependencies...") + + dependencies = [ + ('numpy', 'numpy'), + ('scipy', 'scipy'), + ('pandas', 'pandas'), + ('rdkit', 'rdkit', 'rdBase.__version__'), + ('biopython', 'Bio', '__version__'), + ('pytorch-lightning', 'pytorch_lightning'), + ('PyYAML', 'yaml'), + ] + + all_ok = True + for dep in dependencies: + pkg_name = dep[0] + import_name = dep[1] + version_attr = dep[2] if len(dep) > 2 else '__version__' + + if not check_package(pkg_name, import_name, version_attr): + all_ok = False + + return all_ok + + +def check_esm(): + """Check ESM (protein language model) installation.""" + print("\nChecking ESM (for protein sequence folding)...") + try: + import esm + print(f" ✓ ESM installed (version: {esm.__version__ if hasattr(esm, '__version__') else 'unknown'})") + return True + except ImportError: + print(f" ⚠ ESM not installed (needed for protein sequence folding)") + print(f" Install with: pip install fair-esm") + return False + + +def check_diffdock_installation(): + """Check if DiffDock is properly installed/cloned.""" + print("\nChecking 
DiffDock installation...") + + # Look for key files + key_files = [ + 'inference.py', + 'default_inference_args.yaml', + 'environment.yml', + ] + + found_files = [] + missing_files = [] + + for filename in key_files: + if os.path.exists(filename): + found_files.append(filename) + else: + missing_files.append(filename) + + if found_files: + print(f" ✓ Found DiffDock files in current directory:") + for f in found_files: + print(f" - {f}") + else: + print(f" ⚠ DiffDock files not found in current directory") + print(f" Current directory: {os.getcwd()}") + print(f" Make sure you're in the DiffDock repository root") + + # Check for model checkpoints + model_dir = Path('./workdir/v1.1/score_model') + confidence_dir = Path('./workdir/v1.1/confidence_model') + + if model_dir.exists() and confidence_dir.exists(): + print(f" ✓ Model checkpoints found") + else: + print(f" ⚠ Model checkpoints not found in ./workdir/v1.1/") + print(f" Models will be downloaded on first run") + + return len(found_files) > 0 + + +def print_installation_instructions(): + """Print installation instructions if setup is incomplete.""" + print("\n" + "="*80) + print("Installation Instructions") + print("="*80) + + print(""" +If DiffDock is not installed, follow these steps: + +1. Clone the repository: + git clone https://github.com/gcorso/DiffDock.git + cd DiffDock + +2. Create conda environment: + conda env create --file environment.yml + conda activate diffdock + +3. Verify installation: + python setup_check.py + +For Docker installation: + docker pull rbgcsail/diffdock + docker run -it --gpus all --entrypoint /bin/bash rbgcsail/diffdock + micromamba activate diffdock + +For more information, visit: https://github.com/gcorso/DiffDock + """) + + +def print_performance_notes(has_cuda): + """Print performance notes based on available hardware.""" + print("\n" + "="*80) + print("Performance Notes") + print("="*80) + + if has_cuda: + print(""" +✓ GPU detected - DiffDock will run efficiently + +Expected performance: + - First run: ~2-5 minutes (pre-computing SO(2)/SO(3) tables) + - Subsequent runs: ~10-60 seconds per complex (depending on settings) + - Batch processing: Highly efficient with GPU + """) + else: + print(""" +⚠ No GPU detected - DiffDock will run on CPU + +Expected performance: + - CPU inference is SIGNIFICANTLY slower than GPU + - Single complex: Several minutes to hours + - Batch processing: Not recommended on CPU + +Recommendation: Use GPU for practical applications + - Cloud options: Google Colab, AWS, or other cloud GPU services + - Local: Install CUDA-capable GPU + """) + + +def main(): + parser = argparse.ArgumentParser( + description='Check DiffDock environment setup', + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument('--verbose', '-v', action='store_true', + help='Show detailed version information') + + args = parser.parse_args() + + print("="*80) + print("DiffDock Environment Setup Checker") + print("="*80) + + checks = [] + + # Run all checks + checks.append(("Python version", check_python_version())) + + pytorch_ok, has_cuda = check_pytorch() + checks.append(("PyTorch", pytorch_ok)) + + checks.append(("PyTorch Geometric", check_pytorch_geometric())) + checks.append(("Core dependencies", check_core_dependencies())) + checks.append(("ESM", check_esm())) + checks.append(("DiffDock files", check_diffdock_installation())) + + # Summary + print("\n" + "="*80) + print("Summary") + print("="*80) + + all_passed = all(result for _, result in checks) + + for check_name, result in 
checks: + status = "✓ PASS" if result else "✗ FAIL" + print(f" {status:8s} - {check_name}") + + if all_passed: + print("\n✓ All checks passed! DiffDock is ready to use.") + print_performance_notes(has_cuda) + return 0 + else: + print("\n✗ Some checks failed. Please install missing dependencies.") + print_installation_instructions() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scientific-packages/etetoolkit/SKILL.md b/scientific-packages/etetoolkit/SKILL.md new file mode 100644 index 0000000..a696171 --- /dev/null +++ b/scientific-packages/etetoolkit/SKILL.md @@ -0,0 +1,617 @@ +--- +name: etetoolkit +description: Comprehensive toolkit for phylogenetic and hierarchical tree analysis using the ETE (Environment for Tree Exploration) Python library. This skill should be used when working with phylogenetic trees, gene trees, species trees, clustering dendrograms, or any hierarchical tree structures. Applies to tasks involving tree manipulation (pruning, rerooting, format conversion), evolutionary analysis (orthology detection, duplication/speciation events), tree comparison (Robinson-Foulds distance), NCBI taxonomy integration, tree visualization (PDF, SVG, PNG output), and clustering analysis with heatmaps. +--- + +# ETE Toolkit Skill + +## Overview + +Provide comprehensive support for phylogenetic and hierarchical tree analysis using the ETE (Environment for Tree Exploration) toolkit. Enable tree manipulation, evolutionary analysis, visualization, and integration with biological databases for phylogenomic research and clustering analysis. + +## Core Capabilities + +### 1. Tree Manipulation and Analysis + +Load, manipulate, and analyze hierarchical tree structures with support for: + +- **Tree I/O**: Read and write Newick, NHX, PhyloXML, and NeXML formats +- **Tree traversal**: Navigate trees using preorder, postorder, or levelorder strategies +- **Topology modification**: Prune, root, collapse nodes, resolve polytomies +- **Distance calculations**: Compute branch lengths and topological distances between nodes +- **Tree comparison**: Calculate Robinson-Foulds distances and identify topological differences + +**Common patterns:** + +```python +from ete3 import Tree + +# Load tree from file +tree = Tree("tree.nw", format=1) + +# Basic statistics +print(f"Leaves: {len(tree)}") +print(f"Total nodes: {len(list(tree.traverse()))}") + +# Prune to taxa of interest +taxa_to_keep = ["species1", "species2", "species3"] +tree.prune(taxa_to_keep, preserve_branch_length=True) + +# Midpoint root +midpoint = tree.get_midpoint_outgroup() +tree.set_outgroup(midpoint) + +# Save modified tree +tree.write(outfile="rooted_tree.nw") +``` + +Use `scripts/tree_operations.py` for command-line tree manipulation: + +```bash +# Display tree statistics +python scripts/tree_operations.py stats tree.nw + +# Convert format +python scripts/tree_operations.py convert tree.nw output.nw --in-format 0 --out-format 1 + +# Reroot tree +python scripts/tree_operations.py reroot tree.nw rooted.nw --midpoint + +# Prune to specific taxa +python scripts/tree_operations.py prune tree.nw pruned.nw --keep-taxa "sp1,sp2,sp3" + +# Show ASCII visualization +python scripts/tree_operations.py ascii tree.nw +``` + +### 2. 
Phylogenetic Analysis + +Analyze gene trees with evolutionary event detection: + +- **Sequence alignment integration**: Link trees to multiple sequence alignments (FASTA, Phylip) +- **Species naming**: Automatic or custom species extraction from gene names +- **Evolutionary events**: Detect duplication and speciation events using Species Overlap or tree reconciliation +- **Orthology detection**: Identify orthologs and paralogs based on evolutionary events +- **Gene family analysis**: Split trees by duplications, collapse lineage-specific expansions + +**Workflow for gene tree analysis:** + +```python +from ete3 import PhyloTree + +# Load gene tree with alignment +tree = PhyloTree("gene_tree.nw", alignment="alignment.fasta") + +# Set species naming function +def get_species(gene_name): + return gene_name.split("_")[0] + +tree.set_species_naming_function(get_species) + +# Detect evolutionary events +events = tree.get_descendant_evol_events() + +# Analyze events +for node in tree.traverse(): + if hasattr(node, "evoltype"): + if node.evoltype == "D": + print(f"Duplication at {node.name}") + elif node.evoltype == "S": + print(f"Speciation at {node.name}") + +# Extract ortholog groups +ortho_groups = tree.get_speciation_trees() +for i, ortho_tree in enumerate(ortho_groups): + ortho_tree.write(outfile=f"ortholog_group_{i}.nw") +``` + +**Finding orthologs and paralogs:** + +```python +# Find orthologs to query gene +query = tree & "species1_gene1" + +orthologs = [] +paralogs = [] + +for event in events: + if query in event.in_seqs: + if event.etype == "S": + orthologs.extend([s for s in event.out_seqs if s != query]) + elif event.etype == "D": + paralogs.extend([s for s in event.out_seqs if s != query]) +``` + +### 3. NCBI Taxonomy Integration + +Integrate taxonomic information from NCBI Taxonomy database: + +- **Database access**: Automatic download and local caching of NCBI taxonomy (~300MB) +- **Taxid/name translation**: Convert between taxonomic IDs and scientific names +- **Lineage retrieval**: Get complete evolutionary lineages +- **Taxonomy trees**: Build species trees connecting specified taxa +- **Tree annotation**: Automatically annotate trees with taxonomic information + +**Building taxonomy-based trees:** + +```python +from ete3 import NCBITaxa + +ncbi = NCBITaxa() + +# Build tree from species names +species = ["Homo sapiens", "Pan troglodytes", "Mus musculus"] +name2taxid = ncbi.get_name_translator(species) +taxids = [name2taxid[sp][0] for sp in species] + +# Get minimal tree connecting taxa +tree = ncbi.get_topology(taxids) + +# Annotate nodes with taxonomy info +for node in tree.traverse(): + if hasattr(node, "sci_name"): + print(f"{node.sci_name} - Rank: {node.rank} - TaxID: {node.taxid}") +``` + +**Annotating existing trees:** + +```python +# Get taxonomy info for tree leaves +for leaf in tree: + species = extract_species_from_name(leaf.name) + taxid = ncbi.get_name_translator([species])[species][0] + + # Get lineage + lineage = ncbi.get_lineage(taxid) + ranks = ncbi.get_rank(lineage) + names = ncbi.get_taxid_translator(lineage) + + # Add to node + leaf.add_feature("taxid", taxid) + leaf.add_feature("lineage", [names[t] for t in lineage]) +``` + +### 4. 
Tree Visualization + +Create publication-quality tree visualizations: + +- **Output formats**: PNG (raster), PDF, and SVG (vector) for publications +- **Layout modes**: Rectangular and circular tree layouts +- **Interactive GUI**: Explore trees interactively with zoom, pan, and search +- **Custom styling**: NodeStyle for node appearance (colors, shapes, sizes) +- **Faces**: Add graphical elements (text, images, charts, heatmaps) to nodes +- **Layout functions**: Dynamic styling based on node properties + +**Basic visualization workflow:** + +```python +from ete3 import Tree, TreeStyle, NodeStyle + +tree = Tree("tree.nw") + +# Configure tree style +ts = TreeStyle() +ts.show_leaf_name = True +ts.show_branch_support = True +ts.scale = 50 # pixels per branch length unit + +# Style nodes +for node in tree.traverse(): + nstyle = NodeStyle() + + if node.is_leaf(): + nstyle["fgcolor"] = "blue" + nstyle["size"] = 8 + else: + # Color by support + if node.support > 0.9: + nstyle["fgcolor"] = "darkgreen" + else: + nstyle["fgcolor"] = "red" + nstyle["size"] = 5 + + node.set_style(nstyle) + +# Render to file +tree.render("tree.pdf", tree_style=ts) +tree.render("tree.png", w=800, h=600, units="px", dpi=300) +``` + +Use `scripts/quick_visualize.py` for rapid visualization: + +```bash +# Basic visualization +python scripts/quick_visualize.py tree.nw output.pdf + +# Circular layout with custom styling +python scripts/quick_visualize.py tree.nw output.pdf --mode c --color-by-support + +# High-resolution PNG +python scripts/quick_visualize.py tree.nw output.png --width 1200 --height 800 --units px --dpi 300 + +# Custom title and styling +python scripts/quick_visualize.py tree.nw output.pdf --title "Species Phylogeny" --show-support +``` + +**Advanced visualization with faces:** + +```python +from ete3 import Tree, TreeStyle, TextFace, CircleFace + +tree = Tree("tree.nw") + +# Add features to nodes +for leaf in tree: + leaf.add_feature("habitat", "marine" if "fish" in leaf.name else "land") + +# Layout function +def layout(node): + if node.is_leaf(): + # Add colored circle + color = "blue" if node.habitat == "marine" else "green" + circle = CircleFace(radius=5, color=color) + node.add_face(circle, column=0, position="aligned") + + # Add label + label = TextFace(node.name, fsize=10) + node.add_face(label, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("annotated_tree.pdf", tree_style=ts) +``` + +### 5. 
Clustering Analysis + +Analyze hierarchical clustering results with data integration: + +- **ClusterTree**: Specialized class for clustering dendrograms +- **Data matrix linking**: Connect tree leaves to numerical profiles +- **Cluster metrics**: Silhouette coefficient, Dunn index, inter/intra-cluster distances +- **Validation**: Test cluster quality with different distance metrics +- **Heatmap visualization**: Display data matrices alongside trees + +**Clustering workflow:** + +```python +from ete3 import ClusterTree + +# Load tree with data matrix +matrix = """#Names\tSample1\tSample2\tSample3 +Gene1\t1.5\t2.3\t0.8 +Gene2\t0.9\t1.1\t1.8 +Gene3\t2.1\t2.5\t0.5""" + +tree = ClusterTree("((Gene1,Gene2),Gene3);", text_array=matrix) + +# Evaluate cluster quality +for node in tree.traverse(): + if not node.is_leaf(): + silhouette = node.get_silhouette() + dunn = node.get_dunn() + + print(f"Cluster: {node.name}") + print(f" Silhouette: {silhouette:.3f}") + print(f" Dunn index: {dunn:.3f}") + +# Visualize with heatmap +tree.show("heatmap") +``` + +### 6. Tree Comparison + +Quantify topological differences between trees: + +- **Robinson-Foulds distance**: Standard metric for tree comparison +- **Normalized RF**: Scale-invariant distance (0.0 to 1.0) +- **Partition analysis**: Identify unique and shared bipartitions +- **Consensus trees**: Analyze support across multiple trees +- **Batch comparison**: Compare multiple trees pairwise + +**Compare two trees:** + +```python +from ete3 import Tree + +tree1 = Tree("tree1.nw") +tree2 = Tree("tree2.nw") + +# Calculate RF distance +rf, max_rf, common_leaves, parts_t1, parts_t2 = tree1.robinson_foulds(tree2) + +print(f"RF distance: {rf}/{max_rf}") +print(f"Normalized RF: {rf/max_rf:.3f}") +print(f"Common leaves: {len(common_leaves)}") + +# Find unique partitions +unique_t1 = parts_t1 - parts_t2 +unique_t2 = parts_t2 - parts_t1 + +print(f"Unique to tree1: {len(unique_t1)}") +print(f"Unique to tree2: {len(unique_t2)}") +``` + +**Compare multiple trees:** + +```python +import numpy as np + +trees = [Tree(f"tree{i}.nw") for i in range(4)] + +# Create distance matrix +n = len(trees) +dist_matrix = np.zeros((n, n)) + +for i in range(n): + for j in range(i+1, n): + rf, max_rf, _, _, _ = trees[i].robinson_foulds(trees[j]) + norm_rf = rf / max_rf if max_rf > 0 else 0 + dist_matrix[i, j] = norm_rf + dist_matrix[j, i] = norm_rf +``` + +## Installation and Setup + +Install ETE toolkit: + +```bash +# Basic installation +pip install ete3 + +# With external dependencies for rendering (optional but recommended) +# On macOS: +brew install qt@5 + +# On Ubuntu/Debian: +sudo apt-get install python3-pyqt5 python3-pyqt5.qtsvg + +# For full features including GUI +pip install ete3[gui] +``` + +**First-time NCBI Taxonomy setup:** + +The first time NCBITaxa is instantiated, it automatically downloads the NCBI taxonomy database (~300MB) to `~/.etetoolkit/taxa.sqlite`. This happens only once: + +```python +from ete3 import NCBITaxa +ncbi = NCBITaxa() # Downloads database on first run +``` + +Update taxonomy database: + +```python +ncbi.update_taxonomy_database() # Download latest NCBI data +``` + +## Common Use Cases + +### Use Case 1: Phylogenomic Pipeline + +Complete workflow from gene tree to ortholog identification: + +```python +from ete3 import PhyloTree, NCBITaxa + +# 1. Load gene tree with alignment +tree = PhyloTree("gene_tree.nw", alignment="alignment.fasta") + +# 2. Configure species naming +tree.set_species_naming_function(lambda x: x.split("_")[0]) + +# 3. 
Detect evolutionary events +tree.get_descendant_evol_events() + +# 4. Annotate with taxonomy +ncbi = NCBITaxa() +for leaf in tree: + if leaf.species in species_to_taxid: + taxid = species_to_taxid[leaf.species] + lineage = ncbi.get_lineage(taxid) + leaf.add_feature("lineage", lineage) + +# 5. Extract ortholog groups +ortho_groups = tree.get_speciation_trees() + +# 6. Save and visualize +for i, ortho in enumerate(ortho_groups): + ortho.write(outfile=f"ortho_{i}.nw") +``` + +### Use Case 2: Tree Preprocessing and Formatting + +Batch process trees for analysis: + +```bash +# Convert format +python scripts/tree_operations.py convert input.nw output.nw --in-format 0 --out-format 1 + +# Root at midpoint +python scripts/tree_operations.py reroot input.nw rooted.nw --midpoint + +# Prune to focal taxa +python scripts/tree_operations.py prune rooted.nw pruned.nw --keep-taxa taxa_list.txt + +# Get statistics +python scripts/tree_operations.py stats pruned.nw +``` + +### Use Case 3: Publication-Quality Figures + +Create styled visualizations: + +```python +from ete3 import Tree, TreeStyle, NodeStyle, TextFace + +tree = Tree("tree.nw") + +# Define clade colors +clade_colors = { + "Mammals": "red", + "Birds": "blue", + "Fish": "green" +} + +def layout(node): + # Highlight clades + if node.is_leaf(): + for clade, color in clade_colors.items(): + if clade in node.name: + nstyle = NodeStyle() + nstyle["fgcolor"] = color + nstyle["size"] = 8 + node.set_style(nstyle) + else: + # Add support values + if node.support > 0.95: + support = TextFace(f"{node.support:.2f}", fsize=8) + node.add_face(support, column=0, position="branch-top") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_scale = True + +# Render for publication +tree.render("figure.pdf", w=200, units="mm", tree_style=ts) +tree.render("figure.svg", tree_style=ts) # Editable vector +``` + +### Use Case 4: Automated Tree Analysis + +Process multiple trees systematically: + +```python +from ete3 import Tree +import os + +input_dir = "trees" +output_dir = "processed" + +for filename in os.listdir(input_dir): + if filename.endswith(".nw"): + tree = Tree(os.path.join(input_dir, filename)) + + # Standardize: midpoint root, resolve polytomies + midpoint = tree.get_midpoint_outgroup() + tree.set_outgroup(midpoint) + tree.resolve_polytomy(recursive=True) + + # Filter low support branches + for node in tree.traverse(): + if hasattr(node, 'support') and node.support < 0.5: + if not node.is_leaf() and not node.is_root(): + node.delete() + + # Save processed tree + output_file = os.path.join(output_dir, f"processed_{filename}") + tree.write(outfile=output_file) +``` + +## Reference Documentation + +For comprehensive API documentation, code examples, and detailed guides, refer to the following resources in the `references/` directory: + +- **`api_reference.md`**: Complete API documentation for all ETE classes and methods (Tree, PhyloTree, ClusterTree, NCBITaxa), including parameters, return types, and code examples +- **`workflows.md`**: Common workflow patterns organized by task (tree operations, phylogenetic analysis, tree comparison, taxonomy integration, clustering analysis) +- **`visualization.md`**: Comprehensive visualization guide covering TreeStyle, NodeStyle, Faces, layout functions, and advanced visualization techniques + +Load these references when detailed information is needed: + +```python +# To use API reference +# Read references/api_reference.md for complete method signatures and parameters + +# To implement workflows +# Read 
references/workflows.md for step-by-step workflow examples + +# To create visualizations +# Read references/visualization.md for styling and rendering options +``` + +## Troubleshooting + +**Import errors:** + +```bash +# If "ModuleNotFoundError: No module named 'ete3'" +pip install ete3 + +# For GUI and rendering issues +pip install ete3[gui] +``` + +**Rendering issues:** + +If `tree.render()` or `tree.show()` fails with Qt-related errors, install system dependencies: + +```bash +# macOS +brew install qt@5 + +# Ubuntu/Debian +sudo apt-get install python3-pyqt5 python3-pyqt5.qtsvg +``` + +**NCBI Taxonomy database:** + +If database download fails or becomes corrupted: + +```python +from ete3 import NCBITaxa +ncbi = NCBITaxa() +ncbi.update_taxonomy_database() # Redownload database +``` + +**Memory issues with large trees:** + +For very large trees (>10,000 leaves), use iterators instead of list comprehensions: + +```python +# Memory-efficient iteration +for leaf in tree.iter_leaves(): + process(leaf) + +# Instead of +for leaf in tree.get_leaves(): # Loads all into memory + process(leaf) +``` + +## Newick Format Reference + +ETE supports multiple Newick format specifications (0-100): + +- **Format 0**: Flexible with branch lengths (default) +- **Format 1**: With internal node names +- **Format 2**: With bootstrap/support values +- **Format 5**: Internal node names + branch lengths +- **Format 8**: All features (names, distances, support) +- **Format 9**: Leaf names only +- **Format 100**: Topology only + +Specify format when reading/writing: + +```python +tree = Tree("tree.nw", format=1) +tree.write(outfile="output.nw", format=5) +``` + +NHX (New Hampshire eXtended) format preserves custom features: + +```python +tree.write(outfile="tree.nhx", features=["habitat", "temperature", "depth"]) +``` + +## Best Practices + +1. **Preserve branch lengths**: Use `preserve_branch_length=True` when pruning for phylogenetic analysis +2. **Cache content**: Use `get_cached_content()` for repeated access to node contents on large trees +3. **Use iterators**: Employ `iter_*` methods for memory-efficient processing of large trees +4. **Choose appropriate traversal**: Postorder for bottom-up analysis, preorder for top-down +5. **Validate monophyly**: Always check returned clade type (monophyletic/paraphyletic/polyphyletic) +6. **Vector formats for publication**: Use PDF or SVG for publication figures (scalable, editable) +7. **Interactive testing**: Use `tree.show()` to test visualizations before rendering to file +8. **PhyloTree for phylogenetics**: Use PhyloTree class for gene trees and evolutionary analysis +9. **Copy method selection**: "newick" for speed, "cpickle" for full fidelity, "deepcopy" for complex objects +10. **NCBI query caching**: Store NCBI taxonomy query results to avoid repeated database access diff --git a/scientific-packages/etetoolkit/references/api_reference.md b/scientific-packages/etetoolkit/references/api_reference.md new file mode 100644 index 0000000..73c8fdf --- /dev/null +++ b/scientific-packages/etetoolkit/references/api_reference.md @@ -0,0 +1,583 @@ +# ETE Toolkit API Reference + +## Overview + +ETE (Environment for Tree Exploration) is a Python toolkit for phylogenetic tree manipulation, analysis, and visualization. This reference covers the main classes and methods. + +## Core Classes + +### TreeNode (alias: Tree) + +The fundamental class representing tree structures with hierarchical node organization. 
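+
+Before the constructor details, note that every node of a tree is itself a `TreeNode`, so any subtree exposes the same API as the whole tree. A minimal sketch (the Newick string and node names are purely illustrative):
+
+```python
+from ete3 import Tree
+
+# format=1 keeps internal node names when parsing
+t = Tree("((A:1,B:1)Internal:0.5,C:2)Root;", format=1)
+
+node = t & "Internal"          # fetch a node by name (shortcut for search_nodes)
+print(isinstance(node, Tree))  # True: every node is a TreeNode/Tree instance
+print(node.write(format=1))    # write just this subtree as Newick
+```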
+ +**Constructor:** +```python +from ete3 import Tree +t = Tree(newick=None, format=0, dist=None, support=None, name=None) +``` + +**Parameters:** +- `newick`: Newick string or file path +- `format`: Newick format (0-100). Common formats: + - `0`: Flexible format with branch lengths and names + - `1`: With internal node names + - `2`: With bootstrap/support values + - `5`: Internal node names and branch lengths + - `8`: All features (names, distances, support) + - `9`: Leaf names only + - `100`: Topology only +- `dist`: Branch length to parent (default: 1.0) +- `support`: Bootstrap/confidence value (default: 1.0) +- `name`: Node identifier + +### PhyloTree + +Specialized class for phylogenetic analysis, extending TreeNode. + +**Constructor:** +```python +from ete3 import PhyloTree +t = PhyloTree(newick=None, alignment=None, alg_format='fasta', + sp_naming_function=None, format=0) +``` + +**Additional Parameters:** +- `alignment`: Path to alignment file or alignment string +- `alg_format`: 'fasta' or 'phylip' +- `sp_naming_function`: Custom function to extract species from node names + +### ClusterTree + +Class for hierarchical clustering analysis. + +**Constructor:** +```python +from ete3 import ClusterTree +t = ClusterTree(newick, text_array=None) +``` + +**Parameters:** +- `text_array`: Tab-delimited matrix with column headers and row names + +### NCBITaxa + +Class for NCBI taxonomy database operations. + +**Constructor:** +```python +from ete3 import NCBITaxa +ncbi = NCBITaxa(dbfile=None) +``` + +First instantiation downloads ~300MB NCBI taxonomy database to `~/.etetoolkit/taxa.sqlite`. + +## Node Properties + +### Basic Attributes + +| Property | Type | Description | Default | +|----------|------|-------------|---------| +| `name` | str | Node identifier | "NoName" | +| `dist` | float | Branch length to parent | 1.0 | +| `support` | float | Bootstrap/confidence value | 1.0 | +| `up` | TreeNode | Parent node reference | None | +| `children` | list | Child nodes | [] | + +### Custom Features + +Add any custom data to nodes: +```python +node.add_feature("custom_name", value) +node.add_features(feature1=value1, feature2=value2) +``` + +Access features: +```python +value = node.custom_name +# or +value = getattr(node, "custom_name", default_value) +``` + +## Navigation & Traversal + +### Basic Navigation + +```python +# Check node type +node.is_leaf() # Returns True if terminal node +node.is_root() # Returns True if root node +len(node) # Number of leaves under node + +# Get relatives +parent = node.up +children = node.children +root = node.get_tree_root() +``` + +### Traversal Strategies + +```python +# Three traversal strategies +for node in tree.traverse("preorder"): # Root → Left → Right + print(node.name) + +for node in tree.traverse("postorder"): # Left → Right → Root + print(node.name) + +for node in tree.traverse("levelorder"): # Level by level + print(node.name) + +# Exclude root +for node in tree.iter_descendants("postorder"): + print(node.name) +``` + +### Getting Nodes + +```python +# Get all leaves +leaves = tree.get_leaves() +for leaf in tree: # Shortcut iteration + print(leaf.name) + +# Get all descendants +descendants = tree.get_descendants() + +# Get ancestors +ancestors = node.get_ancestors() + +# Get specific nodes by attribute +nodes = tree.search_nodes(name="NodeA") +node = tree & "NodeA" # Shortcut syntax + +# Get leaves by name +leaves = tree.get_leaves_by_name("LeafA") + +# Get common ancestor +ancestor = tree.get_common_ancestor("LeafA", "LeafB", "LeafC") + +# 
Custom filtering +filtered = [n for n in tree.traverse() if n.dist > 0.5 and n.is_leaf()] +``` + +### Iterator Methods (Memory Efficient) + +```python +# For large trees, use iterators +for match in tree.iter_search_nodes(name="X"): + if some_condition: + break # Stop early + +for leaf in tree.iter_leaves(): + process(leaf) + +for descendant in node.iter_descendants(): + process(descendant) +``` + +## Tree Construction & Modification + +### Creating Trees from Scratch + +```python +# Empty tree +t = Tree() + +# Add children +child1 = t.add_child(name="A", dist=1.0) +child2 = t.add_child(name="B", dist=2.0) + +# Add siblings +sister = child1.add_sister(name="C", dist=1.5) + +# Populate with random topology +t.populate(10) # Creates 10 random leaves +t.populate(5, names_library=["A", "B", "C", "D", "E"]) +``` + +### Removing & Deleting Nodes + +```python +# Detach: removes entire subtree +node.detach() +# or +parent.remove_child(node) + +# Delete: removes node, reconnects children to parent +node.delete() +# or +parent.remove_child(node) +``` + +### Pruning + +Keep only specified leaves: +```python +# Keep only these leaves, remove all others +tree.prune(["A", "B", "C"]) + +# Preserve original branch lengths +tree.prune(["A", "B", "C"], preserve_branch_length=True) +``` + +### Tree Concatenation + +```python +# Attach one tree as child of another +t1 = Tree("(A,(B,C));") +t2 = Tree("((D,E),(F,G));") +A = t1 & "A" +A.add_child(t2) +``` + +### Tree Copying + +```python +# Four copy methods +copy1 = tree.copy() # Default: cpickle (preserves types) +copy2 = tree.copy("newick") # Fastest: basic topology +copy3 = tree.copy("newick-extended") # Includes custom features as text +copy4 = tree.copy("deepcopy") # Slowest: handles complex objects +``` + +## Tree Operations + +### Rooting + +```python +# Set outgroup (reroot tree) +outgroup_node = tree & "OutgroupLeaf" +tree.set_outgroup(outgroup_node) + +# Midpoint rooting +midpoint = tree.get_midpoint_outgroup() +tree.set_outgroup(midpoint) + +# Unroot tree +tree.unroot() +``` + +### Resolving Polytomies + +```python +# Resolve multifurcations to bifurcations +tree.resolve_polytomy(recursive=False) # Single node only +tree.resolve_polytomy(recursive=True) # Entire tree +``` + +### Ladderize + +```python +# Sort branches by size +tree.ladderize() +tree.ladderize(direction=1) # Ascending order +``` + +### Convert to Ultrametric + +```python +# Make all leaves equidistant from root +tree.convert_to_ultrametric() +tree.convert_to_ultrametric(tree_length=100) # Specific total length +``` + +## Distance & Comparison + +### Distance Calculations + +```python +# Branch length distance between nodes +dist = tree.get_distance("A", "B") +dist = nodeA.get_distance(nodeB) + +# Topology-only distance (count nodes) +dist = tree.get_distance("A", "B", topology_only=True) + +# Farthest node +farthest, distance = node.get_farthest_node() +farthest_leaf, distance = node.get_farthest_leaf() +``` + +### Monophyly Testing + +```python +# Check if values form monophyletic group +is_mono, clade_type, base_node = tree.check_monophyly( + values=["A", "B", "C"], + target_attr="name" +) +# Returns: (bool, "monophyletic"|"paraphyletic"|"polyphyletic", node) + +# Get all monophyletic clades +monophyletic_nodes = tree.get_monophyletic( + values=["A", "B", "C"], + target_attr="name" +) +``` + +### Tree Comparison + +```python +# Robinson-Foulds distance +rf, max_rf, common_leaves, parts_t1, parts_t2 = t1.robinson_foulds(t2) +print(f"RF distance: {rf}/{max_rf}") + +# Normalized RF 
distance +result = t1.compare(t2) +norm_rf = result["norm_rf"] # 0.0 to 1.0 +ref_edges = result["ref_edges_in_source"] +``` + +## Input/Output + +### Reading Trees + +```python +# From string +t = Tree("(A:1,(B:1,(C:1,D:1):0.5):0.5);") + +# From file +t = Tree("tree.nw") + +# With format +t = Tree("tree.nw", format=1) +``` + +### Writing Trees + +```python +# To string +newick = tree.write() +newick = tree.write(format=1) +newick = tree.write(format=1, features=["support", "custom_feature"]) + +# To file +tree.write(outfile="output.nw") +tree.write(format=5, outfile="output.nw", features=["name", "dist"]) + +# Custom leaf function (for collapsing) +def is_leaf(node): + return len(node) <= 3 # Treat small clades as leaves + +newick = tree.write(is_leaf_fn=is_leaf) +``` + +### Tree Rendering + +```python +# Show interactive GUI +tree.show() + +# Render to file (PNG, PDF, SVG) +tree.render("tree.png") +tree.render("tree.pdf", w=200, units="mm") +tree.render("tree.svg", dpi=300) + +# ASCII representation +print(tree) +print(tree.get_ascii(show_internal=True, compact=False)) +``` + +## Performance Optimization + +### Caching Content + +For frequent access to node contents: +```python +# Cache all node contents +node2content = tree.get_cached_content() + +# Fast lookup +for node in tree.traverse(): + leaves = node2content[node] + print(f"Node has {len(leaves)} leaves") +``` + +### Precomputing Distances + +```python +# For multiple distance queries +node2dist = {} +for node in tree.traverse(): + node2dist[node] = node.get_distance(tree) +``` + +## PhyloTree-Specific Methods + +### Sequence Alignment + +```python +# Link alignment +tree.link_to_alignment("alignment.fasta", alg_format="fasta") + +# Access sequences +for leaf in tree: + print(f"{leaf.name}: {leaf.sequence}") +``` + +### Species Naming + +```python +# Default: first 3 letters +# Custom function +def get_species(node_name): + return node_name.split("_")[0] + +tree.set_species_naming_function(get_species) + +# Manual setting +for leaf in tree: + leaf.species = extract_species(leaf.name) +``` + +### Evolutionary Events + +```python +# Detect duplication/speciation events +events = tree.get_descendant_evol_events() + +for node in tree.traverse(): + if hasattr(node, "evoltype"): + print(f"{node.name}: {node.evoltype}") # "D" or "S" + +# With species tree +species_tree = Tree("(human, (chimp, gorilla));") +events = tree.get_descendant_evol_events(species_tree=species_tree) +``` + +### Gene Tree Operations + +```python +# Get species trees from duplicated gene families +species_trees = tree.get_speciation_trees() + +# Split by duplication events +subtrees = tree.split_by_dups() + +# Collapse lineage-specific expansions +tree.collapse_lineage_specific_expansions() +``` + +## NCBITaxa Methods + +### Database Operations + +```python +from ete3 import NCBITaxa +ncbi = NCBITaxa() + +# Update database +ncbi.update_taxonomy_database() +``` + +### Querying Taxonomy + +```python +# Get taxid from name +taxid = ncbi.get_name_translator(["Homo sapiens"]) +# Returns: {'Homo sapiens': [9606]} + +# Get name from taxid +names = ncbi.get_taxid_translator([9606, 9598]) +# Returns: {9606: 'Homo sapiens', 9598: 'Pan troglodytes'} + +# Get rank +rank = ncbi.get_rank([9606]) +# Returns: {9606: 'species'} + +# Get lineage +lineage = ncbi.get_lineage(9606) +# Returns: [1, 131567, 2759, ..., 9606] + +# Get descendants +descendants = ncbi.get_descendant_taxa("Primates") +descendants = ncbi.get_descendant_taxa("Primates", collapse_subspecies=True) +``` + +### 
Building Taxonomy Trees + +```python +# Get minimal tree connecting taxa +tree = ncbi.get_topology([9606, 9598, 9593]) # Human, chimp, gorilla + +# Annotate tree with taxonomy +tree.annotate_ncbi_taxa() + +# Access taxonomy info +for node in tree.traverse(): + print(f"{node.sci_name} ({node.taxid}) - Rank: {node.rank}") +``` + +## ClusterTree Methods + +### Linking to Data + +```python +# Link matrix to tree +tree.link_to_arraytable(matrix_string) + +# Access profiles +for leaf in tree: + print(leaf.profile) # Numerical array +``` + +### Cluster Metrics + +```python +# Get silhouette coefficient +silhouette = tree.get_silhouette() + +# Get Dunn index +dunn = tree.get_dunn() + +# Inter/intra cluster distances +inter = node.intercluster_dist +intra = node.intracluster_dist + +# Standard deviation +dev = node.deviation +``` + +### Distance Metrics + +Supported metrics: +- `"euclidean"`: Euclidean distance +- `"pearson"`: Pearson correlation +- `"spearman"`: Spearman rank correlation + +```python +tree.dist_to(node2, metric="pearson") +``` + +## Common Error Handling + +```python +# Check if tree is empty +if tree.children: + print("Tree has children") + +# Check if node exists +nodes = tree.search_nodes(name="X") +if nodes: + node = nodes[0] + +# Safe feature access +value = getattr(node, "feature_name", default_value) + +# Check format compatibility +try: + tree.write(format=1) +except: + print("Tree lacks internal node names") +``` + +## Best Practices + +1. **Use appropriate traversal**: Postorder for bottom-up, preorder for top-down +2. **Cache for repeated access**: Use `get_cached_content()` for frequent queries +3. **Use iterators for large trees**: Memory-efficient processing +4. **Preserve branch lengths**: Use `preserve_branch_length=True` when pruning +5. **Choose copy method wisely**: "newick" for speed, "cpickle" for full fidelity +6. **Validate monophyly**: Check returned clade type (monophyletic/paraphyletic/polyphyletic) +7. **Use PhyloTree for phylogenetics**: Specialized methods for evolutionary analysis +8. **Cache NCBI queries**: Store results to avoid repeated database access diff --git a/scientific-packages/etetoolkit/references/visualization.md b/scientific-packages/etetoolkit/references/visualization.md new file mode 100644 index 0000000..84825e9 --- /dev/null +++ b/scientific-packages/etetoolkit/references/visualization.md @@ -0,0 +1,783 @@ +# ETE Toolkit Visualization Guide + +Complete guide to tree visualization with ETE Toolkit. + +## Table of Contents +1. [Rendering Basics](#rendering-basics) +2. [TreeStyle Configuration](#treestyle-configuration) +3. [Node Styling](#node-styling) +4. [Faces](#faces) +5. [Layout Functions](#layout-functions) +6. 
[Advanced Visualization](#advanced-visualization) + +--- + +## Rendering Basics + +### Output Formats + +ETE supports three main output formats: + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# PNG (raster, good for presentations) +tree.render("output.png", w=800, h=600, units="px", dpi=300) + +# PDF (vector, good for publications) +tree.render("output.pdf", w=200, units="mm") + +# SVG (vector, editable) +tree.render("output.svg") +``` + +### Units and Dimensions + +```python +# Pixels +tree.render("tree.png", w=1200, h=800, units="px") + +# Millimeters +tree.render("tree.pdf", w=210, h=297, units="mm") # A4 size + +# Inches +tree.render("tree.pdf", w=8.5, h=11, units="in") # US Letter + +# Auto-size (aspect ratio preserved) +tree.render("tree.pdf", w=200, units="mm") # Height auto-calculated +``` + +### Interactive Visualization + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# Launch GUI +# - Zoom with mouse wheel +# - Pan by dragging +# - Search with Ctrl+F +# - Export from menu +# - Edit node properties +tree.show() +``` + +--- + +## TreeStyle Configuration + +### Basic TreeStyle Options + +```python +from ete3 import Tree, TreeStyle + +tree = Tree("tree.nw") +ts = TreeStyle() + +# Display options +ts.show_leaf_name = True # Show leaf names +ts.show_branch_length = True # Show branch lengths +ts.show_branch_support = True # Show support values +ts.show_scale = True # Show scale bar + +# Branch length scaling +ts.scale = 50 # Pixels per branch length unit +ts.min_leaf_separation = 10 # Minimum space between leaves (pixels) + +# Layout orientation +ts.rotation = 0 # 0=left-to-right, 90=top-to-bottom +ts.branch_vertical_margin = 10 # Vertical spacing between branches + +# Tree shape +ts.mode = "r" # "r"=rectangular (default), "c"=circular + +tree.render("tree.pdf", tree_style=ts) +``` + +### Circular Trees + +```python +from ete3 import Tree, TreeStyle + +tree = Tree("tree.nw") +ts = TreeStyle() + +# Circular mode +ts.mode = "c" +ts.arc_start = 0 # Starting angle (degrees) +ts.arc_span = 360 # Angular span (degrees, 360=full circle) + +# For semicircle +ts.arc_start = -180 +ts.arc_span = 180 + +tree.render("circular_tree.pdf", tree_style=ts) +``` + +### Title and Legend + +```python +from ete3 import Tree, TreeStyle, TextFace + +tree = Tree("tree.nw") +ts = TreeStyle() + +# Add title +title = TextFace("Phylogenetic Tree of Species", fsize=20, bold=True) +ts.title.add_face(title, column=0) + +# Add legend +ts.legend.add_face(TextFace("Red nodes: High support", fsize=10), column=0) +ts.legend.add_face(TextFace("Blue nodes: Low support", fsize=10), column=0) + +# Legend position +ts.legend_position = 1 # 1=top-right, 2=top-left, 3=bottom-left, 4=bottom-right + +tree.render("tree_with_legend.pdf", tree_style=ts) +``` + +### Custom Background + +```python +from ete3 import Tree, TreeStyle + +tree = Tree("tree.nw") +ts = TreeStyle() + +# Background color +ts.bgcolor = "#f0f0f0" # Light gray background + +# Tree border +ts.show_border = True + +tree.render("tree_background.pdf", tree_style=ts) +``` + +--- + +## Node Styling + +### NodeStyle Properties + +```python +from ete3 import Tree, NodeStyle + +tree = Tree("tree.nw") + +for node in tree.traverse(): + nstyle = NodeStyle() + + # Node size and shape + nstyle["size"] = 10 # Node size in pixels + nstyle["shape"] = "circle" # "circle", "square", "sphere" + + # Colors + nstyle["fgcolor"] = "blue" # Foreground color (node itself) + nstyle["bgcolor"] = "lightblue" # Background color (only for sphere) + + # Line 
style for branches + nstyle["hz_line_type"] = 0 # 0=solid, 1=dashed, 2=dotted + nstyle["vt_line_type"] = 0 # Vertical line type + nstyle["hz_line_color"] = "black" # Horizontal line color + nstyle["vt_line_color"] = "black" # Vertical line color + nstyle["hz_line_width"] = 2 # Line width in pixels + nstyle["vt_line_width"] = 2 + + node.set_style(nstyle) + +tree.render("styled_tree.pdf") +``` + +### Conditional Styling + +```python +from ete3 import Tree, NodeStyle + +tree = Tree("tree.nw") + +# Style based on node properties +for node in tree.traverse(): + nstyle = NodeStyle() + + if node.is_leaf(): + # Leaf node style + nstyle["size"] = 8 + nstyle["fgcolor"] = "darkgreen" + nstyle["shape"] = "circle" + else: + # Internal node style based on support + if node.support > 0.9: + nstyle["size"] = 6 + nstyle["fgcolor"] = "red" + nstyle["shape"] = "sphere" + else: + nstyle["size"] = 4 + nstyle["fgcolor"] = "gray" + nstyle["shape"] = "circle" + + # Style branches by length + if node.dist > 1.0: + nstyle["hz_line_width"] = 3 + nstyle["hz_line_color"] = "blue" + else: + nstyle["hz_line_width"] = 1 + nstyle["hz_line_color"] = "black" + + node.set_style(nstyle) + +tree.render("conditional_styled_tree.pdf") +``` + +### Hiding Nodes + +```python +from ete3 import Tree, NodeStyle + +tree = Tree("tree.nw") + +# Hide specific nodes +for node in tree.traverse(): + if node.support < 0.5: # Hide low support nodes + nstyle = NodeStyle() + nstyle["draw_descendants"] = False # Don't draw this node's subtree + nstyle["size"] = 0 # Make node invisible + node.set_style(nstyle) + +tree.render("filtered_tree.pdf") +``` + +--- + +## Faces + +Faces are graphical elements attached to nodes. They appear at specific positions around nodes. + +### Face Positions + +- `"branch-right"`: Right side of branch (after node) +- `"branch-top"`: Above branch +- `"branch-bottom"`: Below branch +- `"aligned"`: Aligned column at tree edge (for leaves) + +### TextFace + +```python +from ete3 import Tree, TreeStyle, TextFace + +tree = Tree("tree.nw") + +def layout(node): + if node.is_leaf(): + # Add species name + name_face = TextFace(node.name, fsize=12, fgcolor="black") + node.add_face(name_face, column=0, position="branch-right") + + # Add additional text + info_face = TextFace(f"Length: {node.dist:.3f}", fsize=8, fgcolor="gray") + node.add_face(info_face, column=1, position="branch-right") + else: + # Add support value + if node.support: + support_face = TextFace(f"{node.support:.2f}", fsize=8, fgcolor="red") + node.add_face(support_face, column=0, position="branch-top") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False # We're adding custom names + +tree.render("tree_textfaces.pdf", tree_style=ts) +``` + +### AttrFace + +Display node attributes directly: + +```python +from ete3 import Tree, TreeStyle, AttrFace + +tree = Tree("tree.nw") + +# Add custom attributes +for leaf in tree: + leaf.add_feature("habitat", "aquatic" if "fish" in leaf.name else "terrestrial") + leaf.add_feature("temperature", 20) + +def layout(node): + if node.is_leaf(): + # Display attribute directly + habitat_face = AttrFace("habitat", fsize=10) + node.add_face(habitat_face, column=0, position="aligned") + + temp_face = AttrFace("temperature", fsize=10) + node.add_face(temp_face, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout + +tree.render("tree_attrfaces.pdf", tree_style=ts) +``` + +### CircleFace + +```python +from ete3 import Tree, TreeStyle, CircleFace, TextFace + +tree = Tree("tree.nw") + +# Annotate with 
habitat +for leaf in tree: + leaf.add_feature("habitat", "marine" if "fish" in leaf.name else "land") + +def layout(node): + if node.is_leaf(): + # Colored circle based on habitat + color = "blue" if node.habitat == "marine" else "green" + circle = CircleFace(radius=5, color=color, style="circle") + node.add_face(circle, column=0, position="aligned") + + # Label + name = TextFace(node.name, fsize=10) + node.add_face(name, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("tree_circles.pdf", tree_style=ts) +``` + +### ImgFace + +Add images to nodes: + +```python +from ete3 import Tree, TreeStyle, ImgFace, TextFace + +tree = Tree("tree.nw") + +def layout(node): + if node.is_leaf(): + # Add species image + img_path = f"images/{node.name}.png" # Path to image + try: + img_face = ImgFace(img_path, width=50, height=50) + node.add_face(img_face, column=0, position="aligned") + except: + pass # Skip if image doesn't exist + + # Add name + name_face = TextFace(node.name, fsize=10) + node.add_face(name_face, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("tree_images.pdf", tree_style=ts) +``` + +### BarChartFace + +```python +from ete3 import Tree, TreeStyle, BarChartFace, TextFace + +tree = Tree("tree.nw") + +# Add data for bar charts +for leaf in tree: + leaf.add_feature("values", [1.2, 2.3, 0.5, 1.8]) # Multiple values + +def layout(node): + if node.is_leaf(): + # Add bar chart + chart = BarChartFace( + node.values, + width=100, + height=40, + colors=["red", "blue", "green", "orange"], + labels=["A", "B", "C", "D"] + ) + node.add_face(chart, column=0, position="aligned") + + # Add name + name = TextFace(node.name, fsize=10) + node.add_face(name, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("tree_barcharts.pdf", tree_style=ts) +``` + +### PieChartFace + +```python +from ete3 import Tree, TreeStyle, PieChartFace, TextFace + +tree = Tree("tree.nw") + +# Add data +for leaf in tree: + leaf.add_feature("proportions", [25, 35, 40]) # Percentages + +def layout(node): + if node.is_leaf(): + # Add pie chart + pie = PieChartFace( + node.proportions, + width=30, + height=30, + colors=["red", "blue", "green"] + ) + node.add_face(pie, column=0, position="aligned") + + name = TextFace(node.name, fsize=10) + node.add_face(name, column=1, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("tree_piecharts.pdf", tree_style=ts) +``` + +### SequenceFace (for alignments) + +```python +from ete3 import PhyloTree, TreeStyle, SeqMotifFace + +tree = PhyloTree("tree.nw") +tree.link_to_alignment("alignment.fasta") + +def layout(node): + if node.is_leaf(): + # Display sequence + seq_face = SeqMotifFace(node.sequence, seq_format="seq") + node.add_face(seq_face, column=0, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = True + +tree.render("tree_alignment.pdf", tree_style=ts) +``` + +--- + +## Layout Functions + +Layout functions are Python functions that modify node appearance during rendering. 
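+
+A detail that often matters in practice: faces attached with `add_face()` outside a layout function are stored on the node and reappear in every rendering, while faces added inside a layout function are created fresh at render time and are not stored. A minimal sketch of the difference (the output path is illustrative):
+
+```python
+from ete3 import Tree, TreeStyle, TextFace
+
+t = Tree("((A,B),C);")
+
+# Static face: attached once, stored on the node, drawn in every rendering
+(t & "A").add_face(TextFace("permanent label"), column=0, position="branch-right")
+
+# Dynamic faces: created inside the layout function at render time,
+# so they do not accumulate across repeated render()/show() calls
+def layout(node):
+    if node.is_leaf():
+        node.add_face(TextFace(node.name), column=1, position="branch-right")
+
+ts = TreeStyle()
+ts.layout_fn = layout
+t.render("layout_demo.pdf", tree_style=ts)
+```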
+ +### Basic Layout Function + +```python +from ete3 import Tree, TreeStyle, TextFace + +tree = Tree("tree.nw") + +def my_layout(node): + """Called for every node before rendering""" + + if node.is_leaf(): + # Add text to leaves + name_face = TextFace(node.name.upper(), fsize=12, fgcolor="blue") + node.add_face(name_face, column=0, position="branch-right") + else: + # Add support to internal nodes + if node.support: + support_face = TextFace(f"BS: {node.support:.0f}", fsize=8) + node.add_face(support_face, column=0, position="branch-top") + +# Apply layout function +ts = TreeStyle() +ts.layout_fn = my_layout +ts.show_leaf_name = False + +tree.render("tree_custom_layout.pdf", tree_style=ts) +``` + +### Dynamic Styling in Layout + +```python +from ete3 import Tree, TreeStyle, NodeStyle, TextFace + +tree = Tree("tree.nw") + +def layout(node): + # Modify node style dynamically + nstyle = NodeStyle() + + # Color by clade + if "clade_A" in [l.name for l in node.get_leaves()]: + nstyle["bgcolor"] = "lightblue" + elif "clade_B" in [l.name for l in node.get_leaves()]: + nstyle["bgcolor"] = "lightgreen" + + node.set_style(nstyle) + + # Add faces based on features + if hasattr(node, "annotation"): + text = TextFace(node.annotation, fsize=8) + node.add_face(text, column=0, position="branch-top") + +ts = TreeStyle() +ts.layout_fn = layout + +tree.render("tree_dynamic.pdf", tree_style=ts) +``` + +### Multiple Column Layout + +```python +from ete3 import Tree, TreeStyle, TextFace, CircleFace + +tree = Tree("tree.nw") + +# Add features +for leaf in tree: + leaf.add_feature("habitat", "aquatic") + leaf.add_feature("temp", 20) + leaf.add_feature("depth", 100) + +def layout(node): + if node.is_leaf(): + # Column 0: Name + name = TextFace(node.name, fsize=10) + node.add_face(name, column=0, position="aligned") + + # Column 1: Habitat indicator + color = "blue" if node.habitat == "aquatic" else "brown" + circle = CircleFace(radius=5, color=color) + node.add_face(circle, column=1, position="aligned") + + # Column 2: Temperature + temp = TextFace(f"{node.temp}°C", fsize=8) + node.add_face(temp, column=2, position="aligned") + + # Column 3: Depth + depth = TextFace(f"{node.depth}m", fsize=8) + node.add_face(depth, column=3, position="aligned") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +tree.render("tree_columns.pdf", tree_style=ts) +``` + +--- + +## Advanced Visualization + +### Highlighting Clades + +```python +from ete3 import Tree, TreeStyle, NodeStyle, TextFace + +tree = Tree("tree.nw") + +# Define clades to highlight +clade_members = { + "Clade_A": ["species1", "species2", "species3"], + "Clade_B": ["species4", "species5"] +} + +def layout(node): + # Check if node is ancestor of specific clade + node_leaves = set([l.name for l in node.get_leaves()]) + + for clade_name, members in clade_members.items(): + if set(members).issubset(node_leaves): + # This node is ancestor of the clade + nstyle = NodeStyle() + nstyle["bgcolor"] = "yellow" + nstyle["size"] = 0 + + # Add label + if set(members) == node_leaves: # Exact match + label = TextFace(clade_name, fsize=14, bold=True, fgcolor="red") + node.add_face(label, column=0, position="branch-top") + + node.set_style(nstyle) + break + +ts = TreeStyle() +ts.layout_fn = layout + +tree.render("tree_highlighted_clades.pdf", tree_style=ts) +``` + +### Collapsing Clades + +```python +from ete3 import Tree, TreeStyle, TextFace, NodeStyle + +tree = Tree("tree.nw") + +# Define which clades to collapse +clades_to_collapse = ["clade1_species1", 
"clade1_species2"] + +def layout(node): + if not node.is_leaf(): + node_leaves = [l.name for l in node.get_leaves()] + + # Check if this is a clade we want to collapse + if all(l in clades_to_collapse for l in node_leaves): + # Collapse by hiding descendants + nstyle = NodeStyle() + nstyle["draw_descendants"] = False + nstyle["size"] = 20 + nstyle["fgcolor"] = "steelblue" + nstyle["shape"] = "sphere" + node.set_style(nstyle) + + # Add label showing what's collapsed + label = TextFace(f"[{len(node_leaves)} species]", fsize=10) + node.add_face(label, column=0, position="branch-right") + +ts = TreeStyle() +ts.layout_fn = layout + +tree.render("tree_collapsed.pdf", tree_style=ts) +``` + +### Heat Map Visualization + +```python +from ete3 import Tree, TreeStyle, RectFace, TextFace +import numpy as np + +tree = Tree("tree.nw") + +# Generate random data for heatmap +for leaf in tree: + leaf.add_feature("data", np.random.rand(10)) # 10 data points + +def layout(node): + if node.is_leaf(): + # Add name + name = TextFace(node.name, fsize=8) + node.add_face(name, column=0, position="aligned") + + # Add heatmap cells + for i, value in enumerate(node.data): + # Color based on value + intensity = int(255 * value) + color = f"#{255-intensity:02x}{intensity:02x}00" # Green-red gradient + + rect = RectFace(width=20, height=15, fgcolor=color, bgcolor=color) + node.add_face(rect, column=i+1, position="aligned") + +# Add column headers +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False + +# Add header +for i in range(10): + header = TextFace(f"C{i+1}", fsize=8, fgcolor="gray") + ts.aligned_header.add_face(header, column=i+1) + +tree.render("tree_heatmap.pdf", tree_style=ts) +``` + +### Phylogenetic Events Visualization + +```python +from ete3 import PhyloTree, TreeStyle, TextFace, NodeStyle + +tree = PhyloTree("gene_tree.nw") +tree.set_species_naming_function(lambda x: x.split("_")[0]) +tree.get_descendant_evol_events() + +def layout(node): + # Style based on evolutionary event + if hasattr(node, "evoltype"): + nstyle = NodeStyle() + + if node.evoltype == "D": # Duplication + nstyle["fgcolor"] = "red" + nstyle["size"] = 10 + nstyle["shape"] = "square" + + label = TextFace("DUP", fsize=8, fgcolor="red", bold=True) + node.add_face(label, column=0, position="branch-top") + + elif node.evoltype == "S": # Speciation + nstyle["fgcolor"] = "blue" + nstyle["size"] = 6 + nstyle["shape"] = "circle" + + node.set_style(nstyle) + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = True + +tree.render("gene_tree_events.pdf", tree_style=ts) +``` + +### Custom Tree with Legend + +```python +from ete3 import Tree, TreeStyle, TextFace, CircleFace, NodeStyle + +tree = Tree("tree.nw") + +# Categorize species +for leaf in tree: + if "fish" in leaf.name.lower(): + leaf.add_feature("category", "fish") + elif "bird" in leaf.name.lower(): + leaf.add_feature("category", "bird") + else: + leaf.add_feature("category", "mammal") + +category_colors = { + "fish": "blue", + "bird": "green", + "mammal": "red" +} + +def layout(node): + if node.is_leaf(): + # Color by category + nstyle = NodeStyle() + nstyle["fgcolor"] = category_colors[node.category] + nstyle["size"] = 10 + node.set_style(nstyle) + +ts = TreeStyle() +ts.layout_fn = layout + +# Add legend +ts.legend.add_face(TextFace("Legend:", fsize=12, bold=True), column=0) +for category, color in category_colors.items(): + circle = CircleFace(radius=5, color=color) + ts.legend.add_face(circle, column=0) + label = TextFace(f" {category.capitalize()}", fsize=10) + 
ts.legend.add_face(label, column=1) + +ts.legend_position = 1 + +tree.render("tree_with_legend.pdf", tree_style=ts) +``` + +--- + +## Best Practices + +1. **Use layout functions** for complex visualizations - they're called during rendering +2. **Set `show_leaf_name = False`** when using custom name faces +3. **Use aligned position** for columnar data at leaf level +4. **Choose appropriate units**: pixels for screen, mm/inches for print +5. **Use vector formats (PDF/SVG)** for publications +6. **Precompute styling** when possible - layout functions should be fast +7. **Test interactively** with `show()` before rendering to file +8. **Use NodeStyle for permanent** changes, layout functions for rendering-time changes +9. **Align faces in columns** for clean, organized appearance +10. **Add legends** to explain colors and symbols used diff --git a/scientific-packages/etetoolkit/references/workflows.md b/scientific-packages/etetoolkit/references/workflows.md new file mode 100644 index 0000000..43b4938 --- /dev/null +++ b/scientific-packages/etetoolkit/references/workflows.md @@ -0,0 +1,774 @@ +# ETE Toolkit Common Workflows + +This document provides complete workflows for common tasks using the ETE Toolkit. + +## Table of Contents +1. [Basic Tree Operations](#basic-tree-operations) +2. [Phylogenetic Analysis](#phylogenetic-analysis) +3. [Tree Comparison](#tree-comparison) +4. [Taxonomy Integration](#taxonomy-integration) +5. [Clustering Analysis](#clustering-analysis) +6. [Tree Visualization](#tree-visualization) + +--- + +## Basic Tree Operations + +### Loading and Exploring a Tree + +```python +from ete3 import Tree + +# Load tree from file +tree = Tree("my_tree.nw", format=1) + +# Display ASCII representation +print(tree.get_ascii(show_internal=True)) + +# Get basic statistics +print(f"Number of leaves: {len(tree)}") +print(f"Total nodes: {len(list(tree.traverse()))}") +print(f"Tree depth: {tree.get_farthest_leaf()[1]}") + +# List all leaf names +for leaf in tree: + print(leaf.name) +``` + +### Extracting and Saving Subtrees + +```python +from ete3 import Tree + +tree = Tree("full_tree.nw") + +# Get subtree rooted at specific node +node = tree.search_nodes(name="MyNode")[0] +subtree = node.copy() + +# Save subtree to file +subtree.write(outfile="subtree.nw", format=1) + +# Extract monophyletic clade +species_of_interest = ["species1", "species2", "species3"] +ancestor = tree.get_common_ancestor(species_of_interest) +clade = ancestor.copy() +clade.write(outfile="clade.nw") +``` + +### Pruning Trees to Specific Taxa + +```python +from ete3 import Tree + +tree = Tree("large_tree.nw") + +# Keep only taxa of interest +taxa_to_keep = ["taxon1", "taxon2", "taxon3", "taxon4"] +tree.prune(taxa_to_keep, preserve_branch_length=True) + +# Save pruned tree +tree.write(outfile="pruned_tree.nw") +``` + +### Rerooting Trees + +```python +from ete3 import Tree + +tree = Tree("unrooted_tree.nw") + +# Method 1: Root by outgroup +outgroup = tree & "Outgroup_species" +tree.set_outgroup(outgroup) + +# Method 2: Midpoint rooting +midpoint = tree.get_midpoint_outgroup() +tree.set_outgroup(midpoint) + +# Save rooted tree +tree.write(outfile="rooted_tree.nw") +``` + +### Annotating Nodes with Custom Data + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# Add features to nodes based on metadata +metadata = { + "species1": {"habitat": "marine", "temperature": 20}, + "species2": {"habitat": "freshwater", "temperature": 15}, +} + +for leaf in tree: + if leaf.name in metadata: + 
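+        # ** unpacks this leaf's metadata dict so each key/value pair becomes a node feature (e.g. leaf.habitat)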
leaf.add_features(**metadata[leaf.name]) + +# Query annotated features +for leaf in tree: + if hasattr(leaf, "habitat"): + print(f"{leaf.name}: {leaf.habitat}, {leaf.temperature}°C") + +# Save with custom features (NHX format) +tree.write(outfile="annotated_tree.nhx", features=["habitat", "temperature"]) +``` + +### Modifying Tree Topology + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# Remove a clade +node_to_remove = tree & "unwanted_clade" +node_to_remove.detach() + +# Collapse a node (delete but keep children) +node_to_collapse = tree & "low_support_node" +node_to_collapse.delete() + +# Add a new species to existing clade +target_clade = tree & "target_node" +new_leaf = target_clade.add_child(name="new_species", dist=0.5) + +# Resolve polytomies +tree.resolve_polytomy(recursive=True) + +# Save modified tree +tree.write(outfile="modified_tree.nw") +``` + +--- + +## Phylogenetic Analysis + +### Complete Gene Tree Analysis with Alignment + +```python +from ete3 import PhyloTree + +# Load gene tree and link alignment +tree = PhyloTree("gene_tree.nw", format=1) +tree.link_to_alignment("alignment.fasta", alg_format="fasta") + +# Set species naming function (e.g., gene_species format) +def extract_species(node_name): + return node_name.split("_")[0] + +tree.set_species_naming_function(extract_species) + +# Access sequences +for leaf in tree: + print(f"{leaf.name} ({leaf.species})") + print(f"Sequence: {leaf.sequence[:50]}...") +``` + +### Detecting Duplication and Speciation Events + +```python +from ete3 import PhyloTree, Tree + +# Load gene tree +gene_tree = PhyloTree("gene_tree.nw") + +# Set species naming +gene_tree.set_species_naming_function(lambda x: x.split("_")[0]) + +# Option 1: Species Overlap algorithm (no species tree needed) +events = gene_tree.get_descendant_evol_events() + +# Option 2: Tree reconciliation (requires species tree) +species_tree = Tree("species_tree.nw") +events = gene_tree.get_descendant_evol_events(species_tree=species_tree) + +# Analyze events +duplications = 0 +speciations = 0 + +for node in gene_tree.traverse(): + if hasattr(node, "evoltype"): + if node.evoltype == "D": + duplications += 1 + print(f"Duplication at node {node.name}") + elif node.evoltype == "S": + speciations += 1 + +print(f"\nTotal duplications: {duplications}") +print(f"Total speciations: {speciations}") +``` + +### Extracting Orthologs and Paralogs + +```python +from ete3 import PhyloTree + +gene_tree = PhyloTree("gene_tree.nw") +gene_tree.set_species_naming_function(lambda x: x.split("_")[0]) + +# Detect evolutionary events +events = gene_tree.get_descendant_evol_events() + +# Find all orthologs to a query gene +query_gene = gene_tree & "species1_gene1" + +orthologs = [] +paralogs = [] + +for event in events: + if query_gene in event.in_seqs: + if event.etype == "S": # Speciation + orthologs.extend([s for s in event.out_seqs if s != query_gene]) + elif event.etype == "D": # Duplication + paralogs.extend([s for s in event.out_seqs if s != query_gene]) + +print(f"Orthologs of {query_gene.name}:") +for ortholog in set(orthologs): + print(f" {ortholog.name}") + +print(f"\nParalogs of {query_gene.name}:") +for paralog in set(paralogs): + print(f" {paralog.name}") +``` + +### Splitting Gene Families by Duplication Events + +```python +from ete3 import PhyloTree + +gene_tree = PhyloTree("gene_family.nw") +gene_tree.set_species_naming_function(lambda x: x.split("_")[0]) +gene_tree.get_descendant_evol_events() + +# Split into individual gene families +subfamilies = 
gene_tree.split_by_dups() + +print(f"Gene family split into {len(subfamilies)} subfamilies") + +for i, subtree in enumerate(subfamilies): + subtree.write(outfile=f"subfamily_{i}.nw") + species = set([leaf.species for leaf in subtree]) + print(f"Subfamily {i}: {len(subtree)} genes from {len(species)} species") +``` + +### Collapsing Lineage-Specific Expansions + +```python +from ete3 import PhyloTree + +gene_tree = PhyloTree("expanded_tree.nw") +gene_tree.set_species_naming_function(lambda x: x.split("_")[0]) + +# Collapse lineage-specific duplications +gene_tree.collapse_lineage_specific_expansions() + +print("After collapsing expansions:") +print(gene_tree.get_ascii()) + +gene_tree.write(outfile="collapsed_tree.nw") +``` + +### Testing Monophyly + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# Test if a group is monophyletic +target_species = ["species1", "species2", "species3"] +is_mono, clade_type, base_node = tree.check_monophyly( + values=target_species, + target_attr="name" +) + +if is_mono: + print(f"Group is monophyletic") + print(f"MRCA: {base_node.name}") +elif clade_type == "paraphyletic": + print(f"Group is paraphyletic") +elif clade_type == "polyphyletic": + print(f"Group is polyphyletic") + +# Get all monophyletic clades of a specific type +# Annotate leaves first +for leaf in tree: + if leaf.name.startswith("species"): + leaf.add_feature("type", "typeA") + else: + leaf.add_feature("type", "typeB") + +mono_clades = tree.get_monophyletic(values=["typeA"], target_attr="type") +print(f"Found {len(mono_clades)} monophyletic clades of typeA") +``` + +--- + +## Tree Comparison + +### Computing Robinson-Foulds Distance + +```python +from ete3 import Tree + +tree1 = Tree("tree1.nw") +tree2 = Tree("tree2.nw") + +# Compute RF distance +rf, max_rf, common_leaves, parts_t1, parts_t2 = tree1.robinson_foulds(tree2) + +print(f"Robinson-Foulds distance: {rf}") +print(f"Maximum RF distance: {max_rf}") +print(f"Normalized RF: {rf/max_rf:.3f}") +print(f"Common leaves: {len(common_leaves)}") + +# Find unique partitions +unique_in_t1 = parts_t1 - parts_t2 +unique_in_t2 = parts_t2 - parts_t1 + +print(f"\nPartitions unique to tree1: {len(unique_in_t1)}") +print(f"Partitions unique to tree2: {len(unique_in_t2)}") +``` + +### Comparing Multiple Trees + +```python +from ete3 import Tree +import numpy as np + +# Load multiple trees +tree_files = ["tree1.nw", "tree2.nw", "tree3.nw", "tree4.nw"] +trees = [Tree(f) for f in tree_files] + +# Create distance matrix +n = len(trees) +dist_matrix = np.zeros((n, n)) + +for i in range(n): + for j in range(i+1, n): + rf, max_rf, _, _, _ = trees[i].robinson_foulds(trees[j]) + norm_rf = rf / max_rf if max_rf > 0 else 0 + dist_matrix[i, j] = norm_rf + dist_matrix[j, i] = norm_rf + +print("Normalized RF distance matrix:") +print(dist_matrix) + +# Find most similar pair +min_dist = float('inf') +best_pair = None + +for i in range(n): + for j in range(i+1, n): + if dist_matrix[i, j] < min_dist: + min_dist = dist_matrix[i, j] + best_pair = (i, j) + +print(f"\nMost similar trees: {tree_files[best_pair[0]]} and {tree_files[best_pair[1]]}") +print(f"Distance: {min_dist:.3f}") +``` + +### Finding Consensus Topology + +```python +from ete3 import Tree + +# Load multiple bootstrap trees +bootstrap_trees = [Tree(f"bootstrap_{i}.nw") for i in range(100)] + +# Get reference tree (first tree) +ref_tree = bootstrap_trees[0].copy() + +# Count bipartitions +bipartition_counts = {} + +for tree in bootstrap_trees: + rf, max_rf, common, parts_ref, parts_tree = 
ref_tree.robinson_foulds(tree) + for partition in parts_tree: + bipartition_counts[partition] = bipartition_counts.get(partition, 0) + 1 + +# Filter by support threshold +threshold = 70 # 70% support +supported_bipartitions = { + k: v for k, v in bipartition_counts.items() + if (v / len(bootstrap_trees)) * 100 >= threshold +} + +print(f"Bipartitions with >{threshold}% support: {len(supported_bipartitions)}") +``` + +--- + +## Taxonomy Integration + +### Building Species Trees from NCBI Taxonomy + +```python +from ete3 import NCBITaxa + +ncbi = NCBITaxa() + +# Define species of interest +species = ["Homo sapiens", "Pan troglodytes", "Gorilla gorilla", + "Mus musculus", "Rattus norvegicus"] + +# Get taxids +name2taxid = ncbi.get_name_translator(species) +taxids = [name2taxid[sp][0] for sp in species] + +# Build tree +tree = ncbi.get_topology(taxids) + +# Annotate with taxonomy info +for node in tree.traverse(): + if hasattr(node, "sci_name"): + print(f"{node.sci_name} - Rank: {node.rank} - TaxID: {node.taxid}") + +# Save tree +tree.write(outfile="species_tree.nw") +``` + +### Annotating Existing Tree with NCBI Taxonomy + +```python +from ete3 import Tree, NCBITaxa + +tree = Tree("species_tree.nw") +ncbi = NCBITaxa() + +# Map leaf names to species names (adjust as needed) +leaf_to_species = { + "Hsap_gene1": "Homo sapiens", + "Ptro_gene1": "Pan troglodytes", + "Mmur_gene1": "Microcebus murinus", +} + +# Get taxids +all_species = list(set(leaf_to_species.values())) +name2taxid = ncbi.get_name_translator(all_species) + +# Annotate leaves +for leaf in tree: + if leaf.name in leaf_to_species: + species_name = leaf_to_species[leaf.name] + taxid = name2taxid[species_name][0] + + # Add taxonomy info + leaf.add_feature("species", species_name) + leaf.add_feature("taxid", taxid) + + # Get full lineage + lineage = ncbi.get_lineage(taxid) + names = ncbi.get_taxid_translator(lineage) + leaf.add_feature("lineage", [names[t] for t in lineage]) + + print(f"{leaf.name}: {species_name} (taxid: {taxid})") +``` + +### Querying NCBI Taxonomy + +```python +from ete3 import NCBITaxa + +ncbi = NCBITaxa() + +# Get all primates +primates_taxid = ncbi.get_name_translator(["Primates"])["Primates"][0] +all_primates = ncbi.get_descendant_taxa(primates_taxid, collapse_subspecies=True) + +print(f"Total primate species: {len(all_primates)}") + +# Get names for subset +taxid2name = ncbi.get_taxid_translator(all_primates[:10]) +for taxid, name in taxid2name.items(): + rank = ncbi.get_rank([taxid])[taxid] + print(f"{name} ({rank})") + +# Get lineage for specific species +human_taxid = 9606 +lineage = ncbi.get_lineage(human_taxid) +ranks = ncbi.get_rank(lineage) +names = ncbi.get_taxid_translator(lineage) + +print("\nHuman lineage:") +for taxid in lineage: + print(f"{ranks[taxid]:15s} {names[taxid]}") +``` + +--- + +## Clustering Analysis + +### Analyzing Hierarchical Clustering Results + +```python +from ete3 import ClusterTree + +# Load clustering tree with data matrix +matrix = """#Names\tSample1\tSample2\tSample3\tSample4 +Gene1\t1.5\t2.3\t0.8\t1.2 +Gene2\t0.9\t1.1\t1.8\t2.1 +Gene3\t2.1\t2.5\t0.5\t0.9 +Gene4\t0.7\t0.9\t2.2\t2.4""" + +tree = ClusterTree("((Gene1,Gene2),(Gene3,Gene4));", text_array=matrix) + +# Calculate cluster quality metrics +for node in tree.traverse(): + if not node.is_leaf(): + # Silhouette coefficient + silhouette = node.get_silhouette() + + # Dunn index + dunn = node.get_dunn() + + # Distances + inter = node.intercluster_dist + intra = node.intracluster_dist + + print(f"Node: {node.name}") + print(f" 
Silhouette: {silhouette:.3f}") + print(f" Dunn index: {dunn:.3f}") + print(f" Intercluster distance: {inter:.3f}") + print(f" Intracluster distance: {intra:.3f}") +``` + +### Validating Clusters + +```python +from ete3 import ClusterTree + +matrix = """#Names\tCol1\tCol2\tCol3 +ItemA\t1.2\t0.5\t0.8 +ItemB\t1.3\t0.6\t0.9 +ItemC\t0.1\t2.5\t2.3 +ItemD\t0.2\t2.6\t2.4""" + +tree = ClusterTree("((ItemA,ItemB),(ItemC,ItemD));", text_array=matrix) + +# Test different distance metrics +metrics = ["euclidean", "pearson", "spearman"] + +for metric in metrics: + print(f"\nUsing {metric} distance:") + + for node in tree.traverse(): + if not node.is_leaf(): + silhouette = node.get_silhouette(distance=metric) + + # Positive silhouette = good clustering + # Negative silhouette = poor clustering + quality = "good" if silhouette > 0 else "poor" + + print(f" Cluster {node.name}: {silhouette:.3f} ({quality})") +``` + +--- + +## Tree Visualization + +### Basic Tree Rendering + +```python +from ete3 import Tree, TreeStyle + +tree = Tree("tree.nw") + +# Create tree style +ts = TreeStyle() +ts.show_leaf_name = True +ts.show_branch_length = True +ts.show_branch_support = True +ts.scale = 50 # pixels per branch length unit + +# Render to file +tree.render("tree_output.pdf", tree_style=ts) +tree.render("tree_output.png", tree_style=ts, w=800, h=600, units="px") +tree.render("tree_output.svg", tree_style=ts) +``` + +### Customizing Node Appearance + +```python +from ete3 import Tree, TreeStyle, NodeStyle + +tree = Tree("tree.nw") + +# Define node styles +for node in tree.traverse(): + nstyle = NodeStyle() + + if node.is_leaf(): + nstyle["fgcolor"] = "blue" + nstyle["size"] = 10 + else: + nstyle["fgcolor"] = "red" + nstyle["size"] = 5 + + if node.support > 0.9: + nstyle["shape"] = "sphere" + else: + nstyle["shape"] = "circle" + + node.set_style(nstyle) + +# Render +ts = TreeStyle() +tree.render("styled_tree.pdf", tree_style=ts) +``` + +### Adding Faces to Nodes + +```python +from ete3 import Tree, TreeStyle, TextFace, CircleFace, AttrFace + +tree = Tree("tree.nw") + +# Add features to nodes +for leaf in tree: + leaf.add_feature("habitat", "marine" if "fish" in leaf.name else "terrestrial") + leaf.add_feature("temp", 20) + +# Layout function to add faces +def layout(node): + if node.is_leaf(): + # Add text face + name_face = TextFace(node.name, fsize=10) + node.add_face(name_face, column=0, position="branch-right") + + # Add colored circle based on habitat + color = "blue" if node.habitat == "marine" else "green" + circle_face = CircleFace(radius=5, color=color) + node.add_face(circle_face, column=1, position="branch-right") + + # Add attribute face + temp_face = AttrFace("temp", fsize=8) + node.add_face(temp_face, column=2, position="branch-right") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = False # We're adding custom names + +tree.render("tree_with_faces.pdf", tree_style=ts) +``` + +### Circular Tree Layout + +```python +from ete3 import Tree, TreeStyle + +tree = Tree("tree.nw") + +ts = TreeStyle() +ts.mode = "c" # Circular mode +ts.arc_start = 0 # Degrees +ts.arc_span = 360 # Full circle +ts.show_leaf_name = True + +tree.render("circular_tree.pdf", tree_style=ts) +``` + +### Interactive Exploration + +```python +from ete3 import Tree + +tree = Tree("tree.nw") + +# Launch GUI (allows zooming, searching, modifying) +# Changes persist after closing +tree.show() + +# Can save changes made in GUI +tree.write(outfile="modified_tree.nw") +``` + +--- + +## Advanced Workflows + +### Complete Phylogenomic 
Pipeline + +```python +from ete3 import PhyloTree, NCBITaxa, TreeStyle + +# 1. Load gene tree +gene_tree = PhyloTree("gene_tree.nw", alignment="alignment.fasta") + +# 2. Set species naming +gene_tree.set_species_naming_function(lambda x: x.split("_")[0]) + +# 3. Detect evolutionary events +gene_tree.get_descendant_evol_events() + +# 4. Annotate with NCBI taxonomy +ncbi = NCBITaxa() +species_set = set([leaf.species for leaf in gene_tree]) +name2taxid = ncbi.get_name_translator(list(species_set)) + +for leaf in gene_tree: + if leaf.species in name2taxid: + taxid = name2taxid[leaf.species][0] + lineage = ncbi.get_lineage(taxid) + names = ncbi.get_taxid_translator(lineage) + leaf.add_feature("lineage", [names[t] for t in lineage]) + +# 5. Identify and save ortholog groups +ortho_groups = gene_tree.get_speciation_trees() + +for i, ortho_tree in enumerate(ortho_groups): + ortho_tree.write(outfile=f"ortholog_group_{i}.nw") + +# 6. Visualize with evolutionary events marked +def layout(node): + from ete3 import TextFace + if hasattr(node, "evoltype"): + if node.evoltype == "D": + dup_face = TextFace("DUPLICATION", fsize=8, fgcolor="red") + node.add_face(dup_face, column=0, position="branch-top") + +ts = TreeStyle() +ts.layout_fn = layout +ts.show_leaf_name = True +gene_tree.render("annotated_gene_tree.pdf", tree_style=ts) + +print(f"Pipeline complete. Found {len(ortho_groups)} ortholog groups.") +``` + +### Batch Processing Multiple Trees + +```python +from ete3 import Tree +import os + +input_dir = "input_trees" +output_dir = "processed_trees" +os.makedirs(output_dir, exist_ok=True) + +for filename in os.listdir(input_dir): + if filename.endswith(".nw"): + # Load tree + tree = Tree(os.path.join(input_dir, filename)) + + # Process: root, prune, annotate + midpoint = tree.get_midpoint_outgroup() + tree.set_outgroup(midpoint) + + # Filter by branch length + to_remove = [] + for node in tree.traverse(): + if node.dist < 0.001 and not node.is_root(): + to_remove.append(node) + + for node in to_remove: + node.delete() + + # Save processed tree + output_file = os.path.join(output_dir, f"processed_{filename}") + tree.write(outfile=output_file) + + print(f"Processed {filename}") +``` diff --git a/scientific-packages/etetoolkit/scripts/quick_visualize.py b/scientific-packages/etetoolkit/scripts/quick_visualize.py new file mode 100755 index 0000000..757baa7 --- /dev/null +++ b/scientific-packages/etetoolkit/scripts/quick_visualize.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Quick tree visualization script with common customization options. + +Provides command-line interface for rapid tree visualization with +customizable styles, layouts, and output formats. +""" + +import argparse +import sys +from pathlib import Path + +try: + from ete3 import Tree, TreeStyle, NodeStyle +except ImportError: + print("Error: ete3 not installed. 
Install with: pip install ete3") + sys.exit(1) + + +def create_tree_style(args): + """Create TreeStyle based on arguments.""" + ts = TreeStyle() + + # Basic display options + ts.show_leaf_name = args.show_names + ts.show_branch_length = args.show_lengths + ts.show_branch_support = args.show_support + ts.show_scale = args.show_scale + + # Layout + ts.mode = args.mode + ts.rotation = args.rotation + + # Circular tree options + if args.mode == "c": + ts.arc_start = args.arc_start + ts.arc_span = args.arc_span + + # Spacing + ts.branch_vertical_margin = args.vertical_margin + if args.scale_factor: + ts.scale = args.scale_factor + + # Title + if args.title: + from ete3 import TextFace + title_face = TextFace(args.title, fsize=16, bold=True) + ts.title.add_face(title_face, column=0) + + return ts + + +def apply_node_styling(tree, args): + """Apply styling to tree nodes.""" + for node in tree.traverse(): + nstyle = NodeStyle() + + if node.is_leaf(): + # Leaf style + nstyle["fgcolor"] = args.leaf_color + nstyle["size"] = args.leaf_size + else: + # Internal node style + nstyle["fgcolor"] = args.internal_color + nstyle["size"] = args.internal_size + + # Color by support if enabled + if args.color_by_support and hasattr(node, 'support') and node.support: + if node.support >= 0.9: + nstyle["fgcolor"] = "darkgreen" + elif node.support >= 0.7: + nstyle["fgcolor"] = "orange" + else: + nstyle["fgcolor"] = "red" + + node.set_style(nstyle) + + +def visualize_tree(tree_file, output, args): + """Load tree, apply styles, and render.""" + try: + tree = Tree(str(tree_file), format=args.format) + except Exception as e: + print(f"Error loading tree: {e}") + sys.exit(1) + + # Apply styling + apply_node_styling(tree, args) + + # Create tree style + ts = create_tree_style(args) + + # Render + try: + # Determine output parameters based on format + output_path = str(output) + + render_args = {"tree_style": ts} + + if args.width: + render_args["w"] = args.width + if args.height: + render_args["h"] = args.height + if args.units: + render_args["units"] = args.units + if args.dpi: + render_args["dpi"] = args.dpi + + tree.render(output_path, **render_args) + print(f"Tree rendered successfully to: {output}") + + except Exception as e: + print(f"Error rendering tree: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Quick tree visualization with ETE toolkit", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic visualization + %(prog)s tree.nw output.pdf + + # Circular tree + %(prog)s tree.nw output.pdf --mode c + + # Large tree with custom sizing + %(prog)s tree.nw output.png --width 1200 --height 800 --units px --dpi 300 + + # Hide names, show support, color by support + %(prog)s tree.nw output.pdf --no-names --show-support --color-by-support + + # Custom title + %(prog)s tree.nw output.pdf --title "Phylogenetic Tree of Species" + + # Semicircular layout + %(prog)s tree.nw output.pdf --mode c --arc-start -90 --arc-span 180 + """ + ) + + parser.add_argument("input", help="Input tree file (Newick format)") + parser.add_argument("output", help="Output image file (png, pdf, or svg)") + + # Tree format + parser.add_argument("--format", type=int, default=0, + help="Newick format number (default: 0)") + + # Display options + display = parser.add_argument_group("Display options") + display.add_argument("--no-names", dest="show_names", action="store_false", + help="Don't show leaf names") + display.add_argument("--show-lengths", action="store_true", + 
help="Show branch lengths") + display.add_argument("--show-support", action="store_true", + help="Show support values") + display.add_argument("--show-scale", action="store_true", + help="Show scale bar") + + # Layout options + layout = parser.add_argument_group("Layout options") + layout.add_argument("--mode", choices=["r", "c"], default="r", + help="Tree mode: r=rectangular, c=circular (default: r)") + layout.add_argument("--rotation", type=int, default=0, + help="Tree rotation in degrees (default: 0)") + layout.add_argument("--arc-start", type=int, default=0, + help="Circular tree start angle (default: 0)") + layout.add_argument("--arc-span", type=int, default=360, + help="Circular tree arc span (default: 360)") + + # Styling options + styling = parser.add_argument_group("Styling options") + styling.add_argument("--leaf-color", default="blue", + help="Leaf node color (default: blue)") + styling.add_argument("--leaf-size", type=int, default=6, + help="Leaf node size (default: 6)") + styling.add_argument("--internal-color", default="gray", + help="Internal node color (default: gray)") + styling.add_argument("--internal-size", type=int, default=4, + help="Internal node size (default: 4)") + styling.add_argument("--color-by-support", action="store_true", + help="Color internal nodes by support value") + + # Size and spacing + size = parser.add_argument_group("Size and spacing") + size.add_argument("--width", type=int, help="Output width") + size.add_argument("--height", type=int, help="Output height") + size.add_argument("--units", choices=["px", "mm", "in"], + help="Size units (px, mm, in)") + size.add_argument("--dpi", type=int, help="DPI for raster output") + size.add_argument("--scale-factor", type=int, + help="Branch length scale factor (pixels per unit)") + size.add_argument("--vertical-margin", type=int, default=10, + help="Vertical margin between branches (default: 10)") + + # Other options + parser.add_argument("--title", help="Tree title") + + args = parser.parse_args() + + # Validate output format + output_path = Path(args.output) + valid_extensions = {".png", ".pdf", ".svg"} + if output_path.suffix.lower() not in valid_extensions: + print(f"Error: Output must be PNG, PDF, or SVG file") + sys.exit(1) + + # Visualize + visualize_tree(args.input, args.output, args) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/etetoolkit/scripts/tree_operations.py b/scientific-packages/etetoolkit/scripts/tree_operations.py new file mode 100755 index 0000000..6c53da3 --- /dev/null +++ b/scientific-packages/etetoolkit/scripts/tree_operations.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +""" +Tree operations helper script for common ETE toolkit tasks. + +Provides command-line interface for basic tree operations like: +- Format conversion +- Rooting (outgroup, midpoint) +- Pruning +- Basic statistics +- ASCII visualization +""" + +import argparse +import sys +from pathlib import Path + +try: + from ete3 import Tree +except ImportError: + print("Error: ete3 not installed. 
Install with: pip install ete3") + sys.exit(1) + + +def load_tree(tree_file, format_num=0): + """Load tree from file.""" + try: + return Tree(str(tree_file), format=format_num) + except Exception as e: + print(f"Error loading tree: {e}") + sys.exit(1) + + +def convert_format(tree_file, output, in_format=0, out_format=1): + """Convert tree between Newick formats.""" + tree = load_tree(tree_file, in_format) + tree.write(outfile=str(output), format=out_format) + print(f"Converted {tree_file} (format {in_format}) → {output} (format {out_format})") + + +def reroot_tree(tree_file, output, outgroup=None, midpoint=False, format_num=0): + """Reroot tree by outgroup or midpoint.""" + tree = load_tree(tree_file, format_num) + + if midpoint: + midpoint_node = tree.get_midpoint_outgroup() + tree.set_outgroup(midpoint_node) + print(f"Rerooted tree using midpoint method") + elif outgroup: + try: + outgroup_node = tree & outgroup + tree.set_outgroup(outgroup_node) + print(f"Rerooted tree using outgroup: {outgroup}") + except Exception as e: + print(f"Error: Could not find outgroup '{outgroup}': {e}") + sys.exit(1) + else: + print("Error: Must specify either --outgroup or --midpoint") + sys.exit(1) + + tree.write(outfile=str(output), format=format_num) + print(f"Saved rerooted tree to: {output}") + + +def prune_tree(tree_file, output, keep_taxa, preserve_length=True, format_num=0): + """Prune tree to keep only specified taxa.""" + tree = load_tree(tree_file, format_num) + + # Read taxa list + taxa_file = Path(keep_taxa) + if taxa_file.exists(): + with open(taxa_file) as f: + taxa = [line.strip() for line in f if line.strip()] + else: + taxa = [t.strip() for t in keep_taxa.split(",")] + + print(f"Pruning tree to {len(taxa)} taxa") + + try: + tree.prune(taxa, preserve_branch_length=preserve_length) + tree.write(outfile=str(output), format=format_num) + print(f"Pruned tree saved to: {output}") + print(f"Retained {len(tree)} leaves") + except Exception as e: + print(f"Error pruning tree: {e}") + sys.exit(1) + + +def tree_stats(tree_file, format_num=0): + """Display tree statistics.""" + tree = load_tree(tree_file, format_num) + + print(f"\n=== Tree Statistics ===") + print(f"File: {tree_file}") + print(f"Number of leaves: {len(tree)}") + print(f"Total nodes: {len(list(tree.traverse()))}") + + farthest_leaf, distance = tree.get_farthest_leaf() + print(f"Tree depth: {distance:.4f}") + print(f"Farthest leaf: {farthest_leaf.name}") + + # Branch length statistics + branch_lengths = [node.dist for node in tree.traverse() if not node.is_root()] + if branch_lengths: + print(f"\nBranch length statistics:") + print(f" Mean: {sum(branch_lengths)/len(branch_lengths):.4f}") + print(f" Min: {min(branch_lengths):.4f}") + print(f" Max: {max(branch_lengths):.4f}") + + # Support values + supports = [node.support for node in tree.traverse() if not node.is_leaf() and hasattr(node, 'support')] + if supports: + print(f"\nSupport value statistics:") + print(f" Mean: {sum(supports)/len(supports):.2f}") + print(f" Min: {min(supports):.2f}") + print(f" Max: {max(supports):.2f}") + + print() + + +def show_ascii(tree_file, format_num=0, show_internal=True): + """Display tree as ASCII art.""" + tree = load_tree(tree_file, format_num) + print(tree.get_ascii(show_internal=show_internal)) + + +def list_leaves(tree_file, format_num=0): + """List all leaf names.""" + tree = load_tree(tree_file, format_num) + for leaf in tree: + print(leaf.name) + + +def main(): + parser = argparse.ArgumentParser( + description="ETE toolkit tree operations 
helper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Convert format + %(prog)s convert input.nw output.nw --in-format 0 --out-format 1 + + # Midpoint root + %(prog)s reroot input.nw output.nw --midpoint + + # Reroot with outgroup + %(prog)s reroot input.nw output.nw --outgroup "Outgroup_species" + + # Prune tree + %(prog)s prune input.nw output.nw --keep-taxa "speciesA,speciesB,speciesC" + + # Show statistics + %(prog)s stats input.nw + + # Display as ASCII + %(prog)s ascii input.nw + + # List all leaves + %(prog)s leaves input.nw + """ + ) + + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + + # Convert command + convert_parser = subparsers.add_parser("convert", help="Convert tree format") + convert_parser.add_argument("input", help="Input tree file") + convert_parser.add_argument("output", help="Output tree file") + convert_parser.add_argument("--in-format", type=int, default=0, help="Input format (default: 0)") + convert_parser.add_argument("--out-format", type=int, default=1, help="Output format (default: 1)") + + # Reroot command + reroot_parser = subparsers.add_parser("reroot", help="Reroot tree") + reroot_parser.add_argument("input", help="Input tree file") + reroot_parser.add_argument("output", help="Output tree file") + reroot_parser.add_argument("--outgroup", help="Outgroup taxon name") + reroot_parser.add_argument("--midpoint", action="store_true", help="Use midpoint rooting") + reroot_parser.add_argument("--format", type=int, default=0, help="Newick format (default: 0)") + + # Prune command + prune_parser = subparsers.add_parser("prune", help="Prune tree to specified taxa") + prune_parser.add_argument("input", help="Input tree file") + prune_parser.add_argument("output", help="Output tree file") + prune_parser.add_argument("--keep-taxa", required=True, + help="Taxa to keep (comma-separated or file path)") + prune_parser.add_argument("--no-preserve-length", action="store_true", + help="Don't preserve branch lengths") + prune_parser.add_argument("--format", type=int, default=0, help="Newick format (default: 0)") + + # Stats command + stats_parser = subparsers.add_parser("stats", help="Display tree statistics") + stats_parser.add_argument("input", help="Input tree file") + stats_parser.add_argument("--format", type=int, default=0, help="Newick format (default: 0)") + + # ASCII command + ascii_parser = subparsers.add_parser("ascii", help="Display tree as ASCII art") + ascii_parser.add_argument("input", help="Input tree file") + ascii_parser.add_argument("--format", type=int, default=0, help="Newick format (default: 0)") + ascii_parser.add_argument("--no-internal", action="store_true", + help="Don't show internal node names") + + # Leaves command + leaves_parser = subparsers.add_parser("leaves", help="List all leaf names") + leaves_parser.add_argument("input", help="Input tree file") + leaves_parser.add_argument("--format", type=int, default=0, help="Newick format (default: 0)") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + sys.exit(1) + + # Execute command + if args.command == "convert": + convert_format(args.input, args.output, args.in_format, args.out_format) + elif args.command == "reroot": + reroot_tree(args.input, args.output, args.outgroup, args.midpoint, args.format) + elif args.command == "prune": + prune_tree(args.input, args.output, args.keep_taxa, + not args.no_preserve_length, args.format) + elif args.command == "stats": + tree_stats(args.input, args.format) + elif 
args.command == "ascii": + show_ascii(args.input, args.format, not args.no_internal) + elif args.command == "leaves": + list_leaves(args.input, args.format) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/flowio/SKILL.md b/scientific-packages/flowio/SKILL.md new file mode 100644 index 0000000..7a7d0e6 --- /dev/null +++ b/scientific-packages/flowio/SKILL.md @@ -0,0 +1,602 @@ +--- +name: flowio +description: Toolkit for working with Flow Cytometry Standard (FCS) files in Python. Use this skill when reading, parsing, creating, or exporting FCS files (versions 2.0, 3.0, 3.1), extracting flow cytometry metadata, accessing event data, handling multi-dataset FCS files, or converting between FCS formats. Essential for flow cytometry data processing, channel analysis, and cytometry file manipulation tasks. +--- + +# FlowIO: Flow Cytometry Standard File Handler + +## Overview + +FlowIO is a lightweight Python library for reading and writing Flow Cytometry Standard (FCS) files. It excels at parsing FCS metadata, extracting event data, and creating new FCS files with minimal dependencies. The library supports FCS versions 2.0, 3.0, and 3.1, making it ideal for backend services, data pipelines, and basic cytometry file operations. + +## When to Use This Skill + +Apply this skill when working with: + +- FCS files requiring parsing or metadata extraction +- Flow cytometry data needing conversion to NumPy arrays +- Event data requiring export to FCS format +- Multi-dataset FCS files needing separation +- Channel information extraction (scatter, fluorescence, time) +- Cytometry file validation or inspection +- Pre-processing workflows before advanced analysis + +**Related Tools:** For advanced flow cytometry analysis including compensation, gating, and FlowJo/GatingML support, recommend FlowKit library as a companion to FlowIO. + +## Installation + +```bash +pip install flowio +``` + +Requires Python 3.9 or later. + +## Quick Start + +### Basic File Reading + +```python +from flowio import FlowData + +# Read FCS file +flow_data = FlowData('experiment.fcs') + +# Access basic information +print(f"FCS Version: {flow_data.version}") +print(f"Events: {flow_data.event_count}") +print(f"Channels: {flow_data.pnn_labels}") + +# Get event data as NumPy array +events = flow_data.as_array() # Shape: (events, channels) +``` + +### Creating FCS Files + +```python +import numpy as np +from flowio import create_fcs + +# Prepare data +data = np.array([[100, 200, 50], [150, 180, 60]]) # 2 events, 3 channels +channels = ['FSC-A', 'SSC-A', 'FL1-A'] + +# Create FCS file +create_fcs('output.fcs', data, channels) +``` + +## Core Workflows + +### Reading and Parsing FCS Files + +The FlowData class provides the primary interface for reading FCS files. + +**Standard Reading:** + +```python +from flowio import FlowData + +# Basic reading +flow = FlowData('sample.fcs') + +# Access attributes +version = flow.version # '3.0', '3.1', etc. 
+event_count = flow.event_count # Number of events +channel_count = flow.channel_count # Number of channels +pnn_labels = flow.pnn_labels # Short channel names +pns_labels = flow.pns_labels # Descriptive stain names + +# Get event data +events = flow.as_array() # Preprocessed (gain, log scaling applied) +raw_events = flow.as_array(preprocess=False) # Raw data +``` + +**Memory-Efficient Metadata Reading:** + +When only metadata is needed (no event data): + +```python +# Only parse TEXT segment, skip DATA and ANALYSIS +flow = FlowData('sample.fcs', only_text=True) + +# Access metadata +metadata = flow.text # Dictionary of TEXT segment keywords +print(metadata.get('$DATE')) # Acquisition date +print(metadata.get('$CYT')) # Instrument name +``` + +**Handling Problematic Files:** + +Some FCS files have offset discrepancies or errors: + +```python +# Ignore offset discrepancies between HEADER and TEXT sections +flow = FlowData('problematic.fcs', ignore_offset_discrepancy=True) + +# Use HEADER offsets instead of TEXT offsets +flow = FlowData('problematic.fcs', use_header_offsets=True) + +# Ignore offset errors entirely +flow = FlowData('problematic.fcs', ignore_offset_error=True) +``` + +**Excluding Null Channels:** + +```python +# Exclude specific channels during parsing +flow = FlowData('sample.fcs', null_channel_list=['Time', 'Null']) +``` + +### Extracting Metadata and Channel Information + +FCS files contain rich metadata in the TEXT segment. + +**Common Metadata Keywords:** + +```python +flow = FlowData('sample.fcs') + +# File-level metadata +text_dict = flow.text +acquisition_date = text_dict.get('$DATE', 'Unknown') +instrument = text_dict.get('$CYT', 'Unknown') +data_type = flow.data_type # 'I', 'F', 'D', 'A' + +# Channel metadata +for i in range(flow.channel_count): + pnn = flow.pnn_labels[i] # Short name (e.g., 'FSC-A') + pns = flow.pns_labels[i] # Descriptive name (e.g., 'Forward Scatter') + pnr = flow.pnr_values[i] # Range/max value + print(f"Channel {i}: {pnn} ({pns}), Range: {pnr}") +``` + +**Channel Type Identification:** + +FlowIO automatically categorizes channels: + +```python +# Get indices by channel type +scatter_idx = flow.scatter_indices # [0, 1] for FSC, SSC +fluoro_idx = flow.fluoro_indices # [2, 3, 4] for FL channels +time_idx = flow.time_index # Index of time channel (or None) + +# Access specific channel types +events = flow.as_array() +scatter_data = events[:, scatter_idx] +fluorescence_data = events[:, fluoro_idx] +``` + +**ANALYSIS Segment:** + +If present, access processed results: + +```python +if flow.analysis: + analysis_keywords = flow.analysis # Dictionary of ANALYSIS keywords + print(analysis_keywords) +``` + +### Creating New FCS Files + +Generate FCS files from NumPy arrays or other data sources. 
+ +**Basic Creation:** + +```python +import numpy as np +from flowio import create_fcs + +# Create event data (rows=events, columns=channels) +events = np.random.rand(10000, 5) * 1000 + +# Define channel names +channel_names = ['FSC-A', 'SSC-A', 'FL1-A', 'FL2-A', 'Time'] + +# Create FCS file +create_fcs('output.fcs', events, channel_names) +``` + +**With Descriptive Channel Names:** + +```python +# Add optional descriptive names (PnS) +channel_names = ['FSC-A', 'SSC-A', 'FL1-A', 'FL2-A', 'Time'] +descriptive_names = ['Forward Scatter', 'Side Scatter', 'FITC', 'PE', 'Time'] + +create_fcs('output.fcs', + events, + channel_names, + opt_channel_names=descriptive_names) +``` + +**With Custom Metadata:** + +```python +# Add TEXT segment metadata +metadata = { + '$SRC': 'Python script', + '$DATE': '19-OCT-2025', + '$CYT': 'Synthetic Instrument', + '$INST': 'Laboratory A' +} + +create_fcs('output.fcs', + events, + channel_names, + opt_channel_names=descriptive_names, + metadata=metadata) +``` + +**Note:** FlowIO exports as FCS 3.1 with single-precision floating-point data. + +### Exporting Modified Data + +Modify existing FCS files and re-export them. + +**Approach 1: Using write_fcs() Method:** + +```python +from flowio import FlowData + +# Read original file +flow = FlowData('original.fcs') + +# Write with updated metadata +flow.write_fcs('modified.fcs', metadata={'$SRC': 'Modified data'}) +``` + +**Approach 2: Extract, Modify, and Recreate:** + +For modifying event data: + +```python +from flowio import FlowData, create_fcs + +# Read and extract data +flow = FlowData('original.fcs') +events = flow.as_array(preprocess=False) + +# Modify event data +events[:, 0] = events[:, 0] * 1.5 # Scale first channel + +# Create new FCS file with modified data +create_fcs('modified.fcs', + events, + flow.pnn_labels, + opt_channel_names=flow.pns_labels, + metadata=flow.text) +``` + +### Handling Multi-Dataset FCS Files + +Some FCS files contain multiple datasets in a single file. + +**Detecting Multi-Dataset Files:** + +```python +from flowio import FlowData, MultipleDataSetsError + +try: + flow = FlowData('sample.fcs') +except MultipleDataSetsError: + print("File contains multiple datasets") + # Use read_multiple_data_sets() instead +``` + +**Reading All Datasets:** + +```python +from flowio import read_multiple_data_sets + +# Read all datasets from file +datasets = read_multiple_data_sets('multi_dataset.fcs') + +print(f"Found {len(datasets)} datasets") + +# Process each dataset +for i, dataset in enumerate(datasets): + print(f"\nDataset {i}:") + print(f" Events: {dataset.event_count}") + print(f" Channels: {dataset.pnn_labels}") + + # Get event data for this dataset + events = dataset.as_array() + print(f" Shape: {events.shape}") + print(f" Mean values: {events.mean(axis=0)}") +``` + +**Reading Specific Dataset:** + +```python +from flowio import FlowData + +# Read first dataset (nextdata_offset=0) +first_dataset = FlowData('multi.fcs', nextdata_offset=0) + +# Read second dataset using NEXTDATA offset from first +next_offset = int(first_dataset.text['$NEXTDATA']) +if next_offset > 0: + second_dataset = FlowData('multi.fcs', nextdata_offset=next_offset) +``` + +## Data Preprocessing + +FlowIO applies standard FCS preprocessing transformations when `preprocess=True`. + +**Preprocessing Steps:** + +1. **Gain Scaling:** Multiply values by PnG (gain) keyword +2. **Logarithmic Transformation:** Apply PnE exponential transformation if present + - Formula: `value = a * 10^(b * raw_value)` where PnE = "a,b" +3. 
**Time Scaling:** Convert time values to appropriate units + +**Controlling Preprocessing:** + +```python +# Preprocessed data (default) +preprocessed = flow.as_array(preprocess=True) + +# Raw data (no transformations) +raw = flow.as_array(preprocess=False) +``` + +## Error Handling + +Handle common FlowIO exceptions appropriately. + +```python +from flowio import ( + FlowData, + FCSParsingError, + DataOffsetDiscrepancyError, + MultipleDataSetsError +) + +try: + flow = FlowData('sample.fcs') + events = flow.as_array() + +except FCSParsingError as e: + print(f"Failed to parse FCS file: {e}") + # Try with relaxed parsing + flow = FlowData('sample.fcs', ignore_offset_error=True) + +except DataOffsetDiscrepancyError as e: + print(f"Offset discrepancy detected: {e}") + # Use ignore_offset_discrepancy parameter + flow = FlowData('sample.fcs', ignore_offset_discrepancy=True) + +except MultipleDataSetsError as e: + print(f"Multiple datasets detected: {e}") + # Use read_multiple_data_sets instead + from flowio import read_multiple_data_sets + datasets = read_multiple_data_sets('sample.fcs') + +except Exception as e: + print(f"Unexpected error: {e}") +``` + +## Common Use Cases + +### Inspecting FCS File Contents + +Quick exploration of FCS file structure: + +```python +from flowio import FlowData + +flow = FlowData('unknown.fcs') + +print("=" * 50) +print(f"File: {flow.name}") +print(f"Version: {flow.version}") +print(f"Size: {flow.file_size:,} bytes") +print("=" * 50) + +print(f"\nEvents: {flow.event_count:,}") +print(f"Channels: {flow.channel_count}") + +print("\nChannel Information:") +for i, (pnn, pns) in enumerate(zip(flow.pnn_labels, flow.pns_labels)): + ch_type = "scatter" if i in flow.scatter_indices else \ + "fluoro" if i in flow.fluoro_indices else \ + "time" if i == flow.time_index else "other" + print(f" [{i}] {pnn:10s} | {pns:30s} | {ch_type}") + +print("\nKey Metadata:") +for key in ['$DATE', '$BTIM', '$ETIM', '$CYT', '$INST', '$SRC']: + value = flow.text.get(key, 'N/A') + print(f" {key:15s}: {value}") +``` + +### Batch Processing Multiple Files + +Process a directory of FCS files: + +```python +from pathlib import Path +from flowio import FlowData +import pandas as pd + +# Find all FCS files +fcs_files = list(Path('data/').glob('*.fcs')) + +# Extract summary information +summaries = [] +for fcs_path in fcs_files: + try: + flow = FlowData(str(fcs_path), only_text=True) + summaries.append({ + 'filename': fcs_path.name, + 'version': flow.version, + 'events': flow.event_count, + 'channels': flow.channel_count, + 'date': flow.text.get('$DATE', 'N/A') + }) + except Exception as e: + print(f"Error processing {fcs_path.name}: {e}") + +# Create summary DataFrame +df = pd.DataFrame(summaries) +print(df) +``` + +### Converting FCS to CSV + +Export event data to CSV format: + +```python +from flowio import FlowData +import pandas as pd + +# Read FCS file +flow = FlowData('sample.fcs') + +# Convert to DataFrame +df = pd.DataFrame( + flow.as_array(), + columns=flow.pnn_labels +) + +# Add metadata as attributes +df.attrs['fcs_version'] = flow.version +df.attrs['instrument'] = flow.text.get('$CYT', 'Unknown') + +# Export to CSV +df.to_csv('output.csv', index=False) +print(f"Exported {len(df)} events to CSV") +``` + +### Filtering Events and Re-exporting + +Apply filters and save filtered data: + +```python +from flowio import FlowData, create_fcs +import numpy as np + +# Read original file +flow = FlowData('sample.fcs') +events = flow.as_array(preprocess=False) + +# Apply filtering (example: 
threshold on first channel) +fsc_idx = 0 +threshold = 500 +mask = events[:, fsc_idx] > threshold +filtered_events = events[mask] + +print(f"Original events: {len(events)}") +print(f"Filtered events: {len(filtered_events)}") + +# Create new FCS file with filtered data +create_fcs('filtered.fcs', + filtered_events, + flow.pnn_labels, + opt_channel_names=flow.pns_labels, + metadata={**flow.text, '$SRC': 'Filtered data'}) +``` + +### Extracting Specific Channels + +Extract and process specific channels: + +```python +from flowio import FlowData +import numpy as np + +flow = FlowData('sample.fcs') +events = flow.as_array() + +# Extract fluorescence channels only +fluoro_indices = flow.fluoro_indices +fluoro_data = events[:, fluoro_indices] +fluoro_names = [flow.pnn_labels[i] for i in fluoro_indices] + +print(f"Fluorescence channels: {fluoro_names}") +print(f"Shape: {fluoro_data.shape}") + +# Calculate statistics per channel +for i, name in enumerate(fluoro_names): + channel_data = fluoro_data[:, i] + print(f"\n{name}:") + print(f" Mean: {channel_data.mean():.2f}") + print(f" Median: {np.median(channel_data):.2f}") + print(f" Std Dev: {channel_data.std():.2f}") +``` + +## Best Practices + +1. **Memory Efficiency:** Use `only_text=True` when event data is not needed +2. **Error Handling:** Wrap file operations in try-except blocks for robust code +3. **Multi-Dataset Detection:** Check for MultipleDataSetsError and use appropriate function +4. **Preprocessing Control:** Explicitly set `preprocess` parameter based on analysis needs +5. **Offset Issues:** If parsing fails, try `ignore_offset_discrepancy=True` parameter +6. **Channel Validation:** Verify channel counts and names match expectations before processing +7. **Metadata Preservation:** When modifying files, preserve original TEXT segment keywords + +## Advanced Topics + +### Understanding FCS File Structure + +FCS files consist of four segments: + +1. **HEADER:** FCS version and byte offsets for other segments +2. **TEXT:** Key-value metadata pairs (delimiter-separated) +3. **DATA:** Raw event data (binary/float/ASCII format) +4. **ANALYSIS** (optional): Results from data processing + +Access these segments via FlowData attributes: +- `flow.header` - HEADER segment +- `flow.text` - TEXT segment keywords +- `flow.events` - DATA segment (as bytes) +- `flow.analysis` - ANALYSIS segment keywords (if present) + +### Detailed API Reference + +For comprehensive API documentation including all parameters, methods, exceptions, and FCS keyword reference, consult the detailed reference file: + +**Read:** `references/api_reference.md` + +The reference includes: +- Complete FlowData class documentation +- All utility functions (read_multiple_data_sets, create_fcs) +- Exception classes and handling +- FCS file structure details +- Common TEXT segment keywords +- Extended example workflows + +When working with complex FCS operations or encountering unusual file formats, load this reference for detailed guidance. 
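+As a quick complement to the segment overview above, the sketch below prints each of the four segments through the documented `FlowData` attributes. It is a minimal example, not part of the library docs, and assumes a local file named `sample.fcs`; point it at any real FCS file before running.
+
+```python
+from flowio import FlowData
+
+# 'sample.fcs' is a placeholder path; replace with an actual FCS file
+flow = FlowData('sample.fcs')
+
+# HEADER segment: FCS version plus byte offsets of the other segments
+print(flow.header)
+
+# TEXT segment: keyword/value metadata pairs
+print(f"TEXT keywords: {len(flow.text)}")
+
+# DATA segment: raw event data; use as_array() for an (events, channels) matrix
+print(f"Events: {flow.event_count}, channels: {flow.channel_count}")
+
+# ANALYSIS segment: optional, may be absent
+print(flow.analysis if flow.analysis else "No ANALYSIS segment present")
+```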
+ +## Integration Notes + +**NumPy Arrays:** All event data is returned as NumPy ndarrays with shape (events, channels) + +**Pandas DataFrames:** Easily convert to DataFrames for analysis: +```python +import pandas as pd +df = pd.DataFrame(flow.as_array(), columns=flow.pnn_labels) +``` + +**FlowKit Integration:** For advanced analysis (compensation, gating, FlowJo support), use FlowKit library which builds on FlowIO's parsing capabilities + +**Web Applications:** FlowIO's minimal dependencies make it ideal for web backend services processing FCS uploads + +## Troubleshooting + +**Problem:** "Offset discrepancy error" +**Solution:** Use `ignore_offset_discrepancy=True` parameter + +**Problem:** "Multiple datasets error" +**Solution:** Use `read_multiple_data_sets()` function instead of FlowData constructor + +**Problem:** Out of memory with large files +**Solution:** Use `only_text=True` for metadata-only operations, or process events in chunks + +**Problem:** Unexpected channel counts +**Solution:** Check for null channels; use `null_channel_list` parameter to exclude them + +**Problem:** Cannot modify event data in place +**Solution:** FlowIO doesn't support direct modification; extract data, modify, then use `create_fcs()` to save + +## Summary + +FlowIO provides essential FCS file handling capabilities for flow cytometry workflows. Use it for parsing, metadata extraction, and file creation. For simple file operations and data extraction, FlowIO is sufficient. For complex analysis including compensation and gating, integrate with FlowKit or other specialized tools. diff --git a/scientific-packages/flowio/references/api_reference.md b/scientific-packages/flowio/references/api_reference.md new file mode 100644 index 0000000..0d3ff4b --- /dev/null +++ b/scientific-packages/flowio/references/api_reference.md @@ -0,0 +1,372 @@ +# FlowIO API Reference + +## Overview + +FlowIO is a Python library for reading and writing Flow Cytometry Standard (FCS) files. It supports FCS versions 2.0, 3.0, and 3.1 with minimal dependencies. + +## Installation + +```bash +pip install flowio +``` + +Supports Python 3.9 and later. + +## Core Classes + +### FlowData + +The primary class for working with FCS files. 
+ +#### Constructor + +```python +FlowData(fcs_file, + ignore_offset_error=False, + ignore_offset_discrepancy=False, + use_header_offsets=False, + only_text=False, + nextdata_offset=None, + null_channel_list=None) +``` + +**Parameters:** +- `fcs_file`: File path (str), Path object, or file handle +- `ignore_offset_error` (bool): Ignore offset errors (default: False) +- `ignore_offset_discrepancy` (bool): Ignore offset discrepancies between HEADER and TEXT sections (default: False) +- `use_header_offsets` (bool): Use HEADER section offsets instead of TEXT section (default: False) +- `only_text` (bool): Only parse the TEXT segment, skip DATA and ANALYSIS (default: False) +- `nextdata_offset` (int): Byte offset for reading multi-dataset files +- `null_channel_list` (list): List of PnN labels for null channels to exclude + +#### Attributes + +**File Information:** +- `name`: Name of the FCS file +- `file_size`: Size of the file in bytes +- `version`: FCS version (e.g., '3.0', '3.1') +- `header`: Dictionary containing HEADER segment information +- `data_type`: Type of data format ('I', 'F', 'D', 'A') + +**Channel Information:** +- `channel_count`: Number of channels in the dataset +- `channels`: Dictionary mapping channel numbers to channel info +- `pnn_labels`: List of PnN (short channel name) labels +- `pns_labels`: List of PnS (descriptive stain name) labels +- `pnr_values`: List of PnR (range) values for each channel +- `fluoro_indices`: List of indices for fluorescence channels +- `scatter_indices`: List of indices for scatter channels +- `time_index`: Index of the time channel (or None) +- `null_channels`: List of null channel indices + +**Event Data:** +- `event_count`: Number of events (rows) in the dataset +- `events`: Raw event data as bytes + +**Metadata:** +- `text`: Dictionary of TEXT segment key-value pairs +- `analysis`: Dictionary of ANALYSIS segment key-value pairs (if present) + +#### Methods + +##### as_array() + +```python +as_array(preprocess=True) +``` + +Return event data as a 2-D NumPy array. + +**Parameters:** +- `preprocess` (bool): Apply gain, logarithmic, and time scaling transformations (default: True) + +**Returns:** +- NumPy ndarray with shape (event_count, channel_count) + +**Example:** +```python +flow_data = FlowData('sample.fcs') +events_array = flow_data.as_array() # Preprocessed data +raw_array = flow_data.as_array(preprocess=False) # Raw data +``` + +##### write_fcs() + +```python +write_fcs(filename, metadata=None) +``` + +Export the FlowData instance as a new FCS file. + +**Parameters:** +- `filename` (str): Output file path +- `metadata` (dict): Optional dictionary of TEXT segment keywords to add/update + +**Example:** +```python +flow_data = FlowData('sample.fcs') +flow_data.write_fcs('output.fcs', metadata={'$SRC': 'Modified data'}) +``` + +**Note:** Exports as FCS 3.1 with single-precision floating-point data. + +## Utility Functions + +### read_multiple_data_sets() + +```python +read_multiple_data_sets(fcs_file, + ignore_offset_error=False, + ignore_offset_discrepancy=False, + use_header_offsets=False) +``` + +Read all datasets from an FCS file containing multiple datasets. 
+ +**Parameters:** +- Same as FlowData constructor (except `nextdata_offset`) + +**Returns:** +- List of FlowData instances, one for each dataset + +**Example:** +```python +from flowio import read_multiple_data_sets + +datasets = read_multiple_data_sets('multi_dataset.fcs') +print(f"Found {len(datasets)} datasets") +for i, dataset in enumerate(datasets): + print(f"Dataset {i}: {dataset.event_count} events") +``` + +### create_fcs() + +```python +create_fcs(filename, + event_data, + channel_names, + opt_channel_names=None, + metadata=None) +``` + +Create a new FCS file from event data. + +**Parameters:** +- `filename` (str): Output file path +- `event_data` (ndarray): 2-D NumPy array of event data (rows=events, columns=channels) +- `channel_names` (list): List of PnN (short) channel names +- `opt_channel_names` (list): Optional list of PnS (descriptive) channel names +- `metadata` (dict): Optional dictionary of TEXT segment keywords + +**Example:** +```python +import numpy as np +from flowio import create_fcs + +# Create synthetic data +events = np.random.rand(10000, 5) +channels = ['FSC-A', 'SSC-A', 'FL1-A', 'FL2-A', 'Time'] +opt_channels = ['Forward Scatter', 'Side Scatter', 'FITC', 'PE', 'Time'] + +create_fcs('synthetic.fcs', + events, + channels, + opt_channel_names=opt_channels, + metadata={'$SRC': 'Synthetic data'}) +``` + +## Exception Classes + +### FlowIOWarning + +Generic warning class for non-critical issues. + +### PnEWarning + +Warning raised when PnE values are invalid during FCS file creation. + +### FlowIOException + +Base exception class for FlowIO errors. + +### FCSParsingError + +Raised when there are issues parsing an FCS file. + +### DataOffsetDiscrepancyError + +Raised when the HEADER and TEXT sections provide different byte offsets for data segments. + +**Workaround:** Use `ignore_offset_discrepancy=True` parameter when creating FlowData instance. + +### MultipleDataSetsError + +Raised when attempting to read a file with multiple datasets using the standard FlowData constructor. + +**Solution:** Use `read_multiple_data_sets()` function instead. + +## FCS File Structure Reference + +FCS files consist of four segments: + +1. **HEADER**: Contains FCS version and byte locations of other segments +2. **TEXT**: Key-value metadata pairs (delimited format) +3. **DATA**: Raw event data (binary, floating-point, or ASCII) +4. **ANALYSIS** (optional): Results from data processing + +### Common TEXT Segment Keywords + +- `$BEGINDATA`, `$ENDDATA`: Byte offsets for DATA segment +- `$BEGINANALYSIS`, `$ENDANALYSIS`: Byte offsets for ANALYSIS segment +- `$BYTEORD`: Byte order (1,2,3,4 for little-endian; 4,3,2,1 for big-endian) +- `$DATATYPE`: Data type ('I'=integer, 'F'=float, 'D'=double, 'A'=ASCII) +- `$MODE`: Data mode ('L'=list mode, most common) +- `$NEXTDATA`: Offset to next dataset (0 if single dataset) +- `$PAR`: Number of parameters (channels) +- `$TOT`: Total number of events +- `PnN`: Short name for parameter n +- `PnS`: Descriptive stain name for parameter n +- `PnR`: Range (max value) for parameter n +- `PnE`: Amplification exponent for parameter n (format: "a,b" where value = a * 10^(b*x)) +- `PnG`: Amplification gain for parameter n + +## Channel Types + +FlowIO automatically categorizes channels: + +- **Scatter channels**: FSC (forward scatter), SSC (side scatter) +- **Fluorescence channels**: FL1, FL2, FITC, PE, etc. 
+- **Time channel**: Usually labeled "Time" + +Access indices via: +- `flow_data.scatter_indices` +- `flow_data.fluoro_indices` +- `flow_data.time_index` + +## Data Preprocessing + +When calling `as_array(preprocess=True)`, FlowIO applies: + +1. **Gain scaling**: Multiply by PnG value +2. **Logarithmic transformation**: Apply PnE exponential transformation if present +3. **Time scaling**: Convert time values to appropriate units + +To access raw, unprocessed data: `as_array(preprocess=False)` + +## Best Practices + +1. **Memory efficiency**: Use `only_text=True` when only metadata is needed +2. **Error handling**: Wrap file operations in try-except blocks for FCSParsingError +3. **Multi-dataset files**: Always use `read_multiple_data_sets()` if unsure about dataset count +4. **Offset issues**: If encountering offset errors, try `ignore_offset_discrepancy=True` +5. **Channel selection**: Use null_channel_list to exclude unwanted channels during parsing + +## Integration with FlowKit + +For advanced flow cytometry analysis including compensation, gating, and GatingML support, consider using FlowKit library alongside FlowIO. FlowKit provides higher-level abstractions built on top of FlowIO's file parsing capabilities. + +## Example Workflows + +### Basic File Reading + +```python +from flowio import FlowData + +# Read FCS file +flow = FlowData('experiment.fcs') + +# Print basic info +print(f"Version: {flow.version}") +print(f"Events: {flow.event_count}") +print(f"Channels: {flow.channel_count}") +print(f"Channel names: {flow.pnn_labels}") + +# Get event data +events = flow.as_array() +print(f"Data shape: {events.shape}") +``` + +### Metadata Extraction + +```python +from flowio import FlowData + +flow = FlowData('sample.fcs', only_text=True) + +# Access metadata +print(f"Acquisition date: {flow.text.get('$DATE', 'N/A')}") +print(f"Instrument: {flow.text.get('$CYT', 'N/A')}") + +# Channel information +for i, (pnn, pns) in enumerate(zip(flow.pnn_labels, flow.pns_labels)): + print(f"Channel {i}: {pnn} ({pns})") +``` + +### Creating New FCS Files + +```python +import numpy as np +from flowio import create_fcs + +# Generate or process data +data = np.random.rand(5000, 3) * 1000 + +# Define channels +channels = ['FSC-A', 'SSC-A', 'FL1-A'] +stains = ['Forward Scatter', 'Side Scatter', 'GFP'] + +# Create FCS file +create_fcs('output.fcs', + data, + channels, + opt_channel_names=stains, + metadata={ + '$SRC': 'Python script', + '$DATE': '19-OCT-2025' + }) +``` + +### Processing Multi-Dataset Files + +```python +from flowio import read_multiple_data_sets + +# Read all datasets +datasets = read_multiple_data_sets('multi.fcs') + +# Process each dataset +for i, dataset in enumerate(datasets): + print(f"\nDataset {i}:") + print(f" Events: {dataset.event_count}") + print(f" Channels: {dataset.pnn_labels}") + + # Get data array + events = dataset.as_array() + mean_values = events.mean(axis=0) + print(f" Mean values: {mean_values}") +``` + +### Modifying and Re-exporting + +```python +from flowio import FlowData + +# Read original file +flow = FlowData('original.fcs') + +# Get event data +events = flow.as_array(preprocess=False) + +# Modify data (example: apply custom transformation) +events[:, 0] = events[:, 0] * 1.5 # Scale first channel + +# Note: Currently, FlowIO doesn't support direct modification of event data +# For modifications, use create_fcs() instead: +from flowio import create_fcs + +create_fcs('modified.fcs', + events, + flow.pnn_labels, + opt_channel_names=flow.pns_labels, + 
metadata=flow.text) +``` diff --git a/scientific-packages/gget/SKILL.md b/scientific-packages/gget/SKILL.md new file mode 100644 index 0000000..d5aee27 --- /dev/null +++ b/scientific-packages/gget/SKILL.md @@ -0,0 +1,870 @@ +--- +name: gget +description: Toolkit for querying genomic databases and performing bioinformatics analysis. Use this skill when working with gene sequences, protein structures, genomic databases (Ensembl, UniProt, NCBI, PDB, COSMIC, etc.), performing BLAST/BLAT searches, retrieving gene expression data, conducting enrichment analysis, predicting protein structures with AlphaFold, analyzing mutations, or any bioinformatics workflow requiring efficient database queries. This skill applies to tasks involving nucleotide/amino acid sequences, gene names, Ensembl IDs, UniProt accessions, or requests for genomic annotations, orthologs, disease associations, drug information, or single-cell RNA-seq data. +--- + +# gget + +## Overview + +gget is a command-line bioinformatics tool and Python package providing unified access to 20+ genomic databases and analysis methods. Execute queries for gene information, sequence analysis, protein structures, expression data, and disease associations through a consistent interface. All gget modules work both as command-line tools and as Python functions. + +**Important**: The databases queried by gget are continuously updated, which sometimes changes their structure. gget modules are tested automatically on a biweekly basis and updated to match new database structures when necessary. + +## Installation + +Install gget in a clean virtual environment to avoid conflicts: + +```bash +# Using uv (recommended) +uv pip install gget + +# Or using pip +pip install --upgrade gget + +# In Python/Jupyter +import gget +``` + +## Quick Start + +Basic usage pattern for all modules: + +```bash +# Command-line +gget [arguments] [options] + +# Python +gget.module(arguments, options) +``` + +Most modules return: +- **Command-line**: JSON (default) or CSV with `-csv` flag +- **Python**: DataFrame or dictionary + +Common flags across modules: +- `-o/--out`: Save results to file +- `-q/--quiet`: Suppress progress information +- `-csv`: Return CSV format (command-line only) + +## Module Categories + +### 1. Reference & Gene Information + +#### gget ref - Reference Genome Downloads + +Retrieve download links and metadata for Ensembl reference genomes. + +**Parameters**: +- `species`: Genus_species format (e.g., 'homo_sapiens', 'mus_musculus'). Shortcuts: 'human', 'mouse' +- `-w/--which`: Specify return types (gtf, cdna, dna, cds, cdrna, pep). Default: all +- `-r/--release`: Ensembl release number (default: latest) +- `-l/--list_species`: List available vertebrate species +- `-liv/--list_iv_species`: List available invertebrate species +- `-ftp`: Return only FTP links +- `-d/--download`: Download files (requires curl) + +**Examples**: +```bash +# List available species +gget ref --list_species + +# Get all reference files for human +gget ref homo_sapiens + +# Download only GTF annotation for mouse +gget ref -w gtf -d mouse +``` + +```python +# Python +gget.ref("homo_sapiens") +gget.ref("mus_musculus", which="gtf", download=True) +``` + +#### gget search - Gene Search + +Locate genes by name or description across species. 
+ +**Parameters**: +- `searchwords`: One or more search terms (case-insensitive) +- `-s/--species`: Target species (e.g., 'homo_sapiens', 'mouse') +- `-r/--release`: Ensembl release number +- `-t/--id_type`: Return 'gene' (default) or 'transcript' +- `-ao/--andor`: 'or' (default) finds ANY searchword; 'and' requires ALL +- `-l/--limit`: Maximum results to return + +**Returns**: ensembl_id, gene_name, ensembl_description, ext_ref_description, biotype, URL + +**Examples**: +```bash +# Search for GABA-related genes in human +gget search -s human gaba gamma-aminobutyric + +# Find specific gene, require all terms +gget search -s mouse -ao and pax7 transcription +``` + +```python +# Python +gget.search(["gaba", "gamma-aminobutyric"], species="homo_sapiens") +``` + +#### gget info - Gene/Transcript Information + +Retrieve comprehensive gene and transcript metadata from Ensembl, UniProt, and NCBI. + +**Parameters**: +- `ens_ids`: One or more Ensembl IDs (also supports WormBase, Flybase IDs). Limit: ~1000 IDs +- `-n/--ncbi`: Disable NCBI data retrieval +- `-u/--uniprot`: Disable UniProt data retrieval +- `-pdb`: Include PDB identifiers (increases runtime) + +**Returns**: UniProt ID, NCBI gene ID, primary gene name, synonyms, protein names, descriptions, biotype, canonical transcript + +**Examples**: +```bash +# Get info for multiple genes +gget info ENSG00000034713 ENSG00000104853 ENSG00000170296 + +# Include PDB IDs +gget info ENSG00000034713 -pdb +``` + +```python +# Python +gget.info(["ENSG00000034713", "ENSG00000104853"], pdb=True) +``` + +#### gget seq - Sequence Retrieval + +Fetch nucleotide or amino acid sequences for genes and transcripts. + +**Parameters**: +- `ens_ids`: One or more Ensembl identifiers +- `-t/--translate`: Fetch amino acid sequences instead of nucleotide +- `-iso/--isoforms`: Return all transcript variants (gene IDs only) + +**Returns**: FASTA format sequences + +**Examples**: +```bash +# Get nucleotide sequences +gget seq ENSG00000034713 ENSG00000104853 + +# Get all protein isoforms +gget seq -t -iso ENSG00000034713 +``` + +```python +# Python +gget.seq(["ENSG00000034713"], translate=True, isoforms=True) +``` + +### 2. Sequence Analysis & Alignment + +#### gget blast - BLAST Searches + +BLAST nucleotide or amino acid sequences against standard databases. + +**Parameters**: +- `sequence`: Sequence string or path to FASTA/.txt file +- `-p/--program`: blastn, blastp, blastx, tblastn, tblastx (auto-detected) +- `-db/--database`: + - Nucleotide: nt, refseq_rna, pdbnt + - Protein: nr, swissprot, pdbaa, refseq_protein +- `-l/--limit`: Max hits (default: 50) +- `-e/--expect`: E-value cutoff (default: 10.0) +- `-lcf/--low_comp_filt`: Enable low complexity filtering +- `-mbo/--megablast_off`: Disable MegaBLAST (blastn only) + +**Examples**: +```bash +# BLAST protein sequence +gget blast MKWMFKEDHSLEHRCVESAKIRAKYPDRVPVIVEKVSGSQIVDIDKRKYLVPSDITVAQFMWIIRKRIQLPSEKAIFLFVDKTVPQSR + +# BLAST from file with specific database +gget blast sequence.fasta -db swissprot -l 10 +``` + +```python +# Python +gget.blast("MKWMFK...", database="swissprot", limit=10) +``` + +#### gget blat - BLAT Searches + +Locate genomic positions of sequences using UCSC BLAT. + +**Parameters**: +- `sequence`: Sequence string or path to FASTA/.txt file +- `-st/--seqtype`: 'DNA', 'protein', 'translated%20RNA', 'translated%20DNA' (auto-detected) +- `-a/--assembly`: Target assembly (default: 'human'/hg38; options: 'mouse'/mm39, 'zebrafinch'/taeGut2, etc.) 
+ +**Returns**: genome, query size, alignment positions, matches, mismatches, alignment percentage + +**Examples**: +```bash +# Find genomic location in human +gget blat ATCGATCGATCGATCG + +# Search in different assembly +gget blat -a mm39 ATCGATCGATCGATCG +``` + +```python +# Python +gget.blat("ATCGATCGATCGATCG", assembly="mouse") +``` + +#### gget muscle - Multiple Sequence Alignment + +Align multiple nucleotide or amino acid sequences using Muscle5. + +**Parameters**: +- `fasta`: Sequences or path to FASTA/.txt file +- `-s5/--super5`: Use Super5 algorithm for faster processing (large datasets) + +**Returns**: Aligned sequences in ClustalW format or aligned FASTA (.afa) + +**Examples**: +```bash +# Align sequences from file +gget muscle sequences.fasta -o aligned.afa + +# Use Super5 for large dataset +gget muscle large_dataset.fasta -s5 +``` + +```python +# Python +gget.muscle("sequences.fasta", save=True) +``` + +#### gget diamond - Local Sequence Alignment + +Perform fast local protein or translated DNA alignment using DIAMOND. + +**Parameters**: +- Query: Sequences (string/list) or FASTA file path +- `--reference`: Reference sequences (string/list) or FASTA file path (required) +- `--sensitivity`: fast, mid-sensitive, sensitive, more-sensitive, very-sensitive (default), ultra-sensitive +- `--threads`: CPU threads (default: 1) +- `--diamond_db`: Save database for reuse +- `--translated`: Enable nucleotide-to-amino acid alignment + +**Returns**: Identity percentage, sequence lengths, match positions, gap openings, E-values, bit scores + +**Examples**: +```bash +# Align against reference +gget diamond GGETISAWESQME -ref reference.fasta --threads 4 + +# Save database for reuse +gget diamond query.fasta -ref ref.fasta --diamond_db my_db.dmnd +``` + +```python +# Python +gget.diamond("GGETISAWESQME", reference="reference.fasta", threads=4) +``` + +### 3. Structural & Protein Analysis + +#### gget pdb - Protein Structures + +Query RCSB Protein Data Bank for structure and metadata. + +**Parameters**: +- `pdb_id`: PDB identifier (e.g., '7S7U') +- `-r/--resource`: Data type (pdb, entry, pubmed, assembly, entity types) +- `-i/--identifier`: Assembly, entity, or chain ID + +**Returns**: PDB format (structures) or JSON (metadata) + +**Examples**: +```bash +# Download PDB structure +gget pdb 7S7U -o 7S7U.pdb + +# Get metadata +gget pdb 7S7U -r entry +``` + +```python +# Python +gget.pdb("7S7U", save=True) +``` + +#### gget alphafold - Protein Structure Prediction + +Predict 3D protein structures using simplified AlphaFold2. + +**Setup Required**: +```bash +# Install OpenMM first (version depends on Python version) +# Python < 3.10: +conda install -qy conda==4.13.0 && conda install -qy -c conda-forge openmm=7.5.1 +# Python 3.10: +conda install -qy conda==24.1.2 && conda install -qy -c conda-forge openmm=7.7.0 +# Python 3.11: +conda install -qy conda==24.11.1 && conda install -qy -c conda-forge openmm=8.0.0 + +# Then setup AlphaFold +gget setup alphafold +``` + +**Parameters**: +- `sequence`: Amino acid sequence (string), multiple sequences (list), or FASTA file. 
Multiple sequences trigger multimer modeling +- `-mr/--multimer_recycles`: Recycling iterations (default: 3; recommend 20 for accuracy) +- `-mfm/--multimer_for_monomer`: Apply multimer model to single proteins +- `-r/--relax`: AMBER relaxation for top-ranked model +- `plot`: Python-only; generate interactive 3D visualization (default: True) +- `show_sidechains`: Python-only; include side chains (default: True) + +**Returns**: PDB structure file, JSON alignment error data, optional 3D visualization + +**Examples**: +```bash +# Predict single protein structure +gget alphafold MKWMFKEDHSLEHRCVESAKIRAKYPDRVPVIVEKVSGSQIVDIDKRKYLVPSDITVAQFMWIIRKRIQLPSEKAIFLFVDKTVPQSR + +# Predict multimer with higher accuracy +gget alphafold sequence1.fasta -mr 20 -r +``` + +```python +# Python with visualization +gget.alphafold("MKWMFK...", plot=True, show_sidechains=True) + +# Multimer prediction +gget.alphafold(["sequence1", "sequence2"], multimer_recycles=20) +``` + +#### gget elm - Eukaryotic Linear Motifs + +Predict Eukaryotic Linear Motifs in protein sequences. + +**Setup Required**: +```bash +gget setup elm +``` + +**Parameters**: +- `sequence`: Amino acid sequence or UniProt Acc +- `-u/--uniprot`: Indicates sequence is UniProt Acc +- `-e/--expand`: Include protein names, organisms, references +- `-s/--sensitivity`: DIAMOND alignment sensitivity (default: "very-sensitive") +- `-t/--threads`: Number of threads (default: 1) + +**Returns**: Two outputs: +1. **ortholog_df**: Linear motifs from orthologous proteins +2. **regex_df**: Motifs directly matched in input sequence + +**Examples**: +```bash +# Predict motifs from sequence +gget elm LIAQSIGQASFV -o results + +# Use UniProt accession with expanded info +gget elm --uniprot Q02410 -e +``` + +```python +# Python +ortholog_df, regex_df = gget.elm("LIAQSIGQASFV") +``` + +### 4. Expression & Disease Data + +#### gget archs4 - Gene Correlation & Tissue Expression + +Query ARCHS4 database for correlated genes or tissue expression data. + +**Parameters**: +- `gene`: Gene symbol or Ensembl ID (with `--ensembl` flag) +- `-w/--which`: 'correlation' (default, returns 100 most correlated genes) or 'tissue' (expression atlas) +- `-s/--species`: 'human' (default) or 'mouse' (tissue data only) +- `-e/--ensembl`: Input is Ensembl ID + +**Returns**: +- **Correlation mode**: Gene symbols, Pearson correlation coefficients +- **Tissue mode**: Tissue identifiers, min/Q1/median/Q3/max expression values + +**Examples**: +```bash +# Get correlated genes +gget archs4 ACE2 + +# Get tissue expression +gget archs4 -w tissue ACE2 +``` + +```python +# Python +gget.archs4("ACE2", which="tissue") +``` + +#### gget cellxgene - Single-Cell RNA-seq Data + +Query CZ CELLxGENE Discover Census for single-cell data. + +**Setup Required**: +```bash +gget setup cellxgene +``` + +**Parameters**: +- `--gene` (-g): Gene names or Ensembl IDs (case-sensitive! 
'PAX7' for human, 'Pax7' for mouse) +- `--tissue`: Tissue type(s) +- `--cell_type`: Specific cell type(s) +- `--species` (-s): 'homo_sapiens' (default) or 'mus_musculus' +- `--census_version` (-cv): Version ("stable", "latest", or dated) +- `--ensembl` (-e): Use Ensembl IDs +- `--meta_only` (-mo): Return metadata only +- Additional filters: disease, development_stage, sex, assay, dataset_id, donor_id, ethnicity, suspension_type + +**Returns**: AnnData object with count matrices and metadata (or metadata-only dataframes) + +**Examples**: +```bash +# Get single-cell data for specific genes and cell types +gget cellxgene --gene ACE2 ABCA1 --tissue lung --cell_type "mucus secreting cell" -o lung_data.h5ad + +# Metadata only +gget cellxgene --gene PAX7 --tissue muscle --meta_only -o metadata.csv +``` + +```python +# Python +adata = gget.cellxgene(gene=["ACE2", "ABCA1"], tissue="lung", cell_type="mucus secreting cell") +``` + +#### gget enrichr - Enrichment Analysis + +Perform ontology enrichment analysis on gene lists using Enrichr. + +**Parameters**: +- `genes`: Gene symbols or Ensembl IDs +- `-db/--database`: Reference database (supports shortcuts: 'pathway', 'transcription', 'ontology', 'diseases_drugs', 'celltypes') +- `-s/--species`: human (default), mouse, fly, yeast, worm, fish +- `-bkg_l/--background_list`: Background genes for comparison +- `-ko/--kegg_out`: Save KEGG pathway images with highlighted genes +- `plot`: Python-only; generate graphical results + +**Database Shortcuts**: +- 'pathway' → KEGG_2021_Human +- 'transcription' → ChEA_2016 +- 'ontology' → GO_Biological_Process_2021 +- 'diseases_drugs' → GWAS_Catalog_2019 +- 'celltypes' → PanglaoDB_Augmented_2021 + +**Examples**: +```bash +# Enrichment analysis for ontology +gget enrichr -db ontology ACE2 AGT AGTR1 + +# Save KEGG pathways +gget enrichr -db pathway ACE2 AGT AGTR1 -ko ./kegg_images/ +``` + +```python +# Python with plot +gget.enrichr(["ACE2", "AGT", "AGTR1"], database="ontology", plot=True) +``` + +#### gget bgee - Orthology & Expression + +Retrieve orthology and gene expression data from Bgee database. + +**Parameters**: +- `ens_id`: Ensembl gene ID or NCBI gene ID (for non-Ensembl species). Multiple IDs supported when `type=expression` +- `-t/--type`: 'orthologs' (default) or 'expression' + +**Returns**: +- **Orthologs mode**: Matching genes across species with IDs, names, taxonomic info +- **Expression mode**: Anatomical entities, confidence scores, expression status + +**Examples**: +```bash +# Get orthologs +gget bgee ENSG00000169194 + +# Get expression data +gget bgee ENSG00000169194 -t expression + +# Multiple genes +gget bgee ENSBTAG00000047356 ENSBTAG00000018317 -t expression +``` + +```python +# Python +gget.bgee("ENSG00000169194", type="orthologs") +``` + +#### gget opentargets - Disease & Drug Associations + +Retrieve disease and drug associations from OpenTargets. 
+ +**Parameters**: +- Ensembl gene ID (required) +- `-r/--resource`: diseases (default), drugs, tractability, pharmacogenetics, expression, depmap, interactions +- `-l/--limit`: Cap results count +- Filter arguments (vary by resource): + - drugs: `--filter_disease` + - pharmacogenetics: `--filter_drug` + - expression/depmap: `--filter_tissue`, `--filter_anat_sys`, `--filter_organ` + - interactions: `--filter_protein_a`, `--filter_protein_b`, `--filter_gene_b` + +**Examples**: +```bash +# Get associated diseases +gget opentargets ENSG00000169194 -r diseases -l 5 + +# Get associated drugs +gget opentargets ENSG00000169194 -r drugs -l 10 + +# Get tissue expression +gget opentargets ENSG00000169194 -r expression --filter_tissue brain +``` + +```python +# Python +gget.opentargets("ENSG00000169194", resource="diseases", limit=5) +``` + +#### gget cbio - cBioPortal Cancer Genomics + +Plot cancer genomics heatmaps using cBioPortal data. + +**Two subcommands**: + +**search** - Find study IDs: +```bash +gget cbio search breast lung +``` + +**plot** - Generate heatmaps: + +**Parameters**: +- `-s/--study_ids`: Space-separated cBioPortal study IDs (required) +- `-g/--genes`: Space-separated gene names or Ensembl IDs (required) +- `-st/--stratification`: Column to organize data (tissue, cancer_type, cancer_type_detailed, study_id, sample) +- `-vt/--variation_type`: Data type (mutation_occurrences, cna_nonbinary, sv_occurrences, cna_occurrences, Consequence) +- `-f/--filter`: Filter by column value (e.g., 'study_id:msk_impact_2017') +- `-dd/--data_dir`: Cache directory (default: ./gget_cbio_cache) +- `-fd/--figure_dir`: Output directory (default: ./gget_cbio_figures) +- `-dpi`: Resolution (default: 100) +- `-sh/--show`: Display plot in window +- `-nc/--no_confirm`: Skip download confirmations + +**Examples**: +```bash +# Search for studies +gget cbio search esophag ovary + +# Create heatmap +gget cbio plot -s msk_impact_2017 -g AKT1 ALK BRAF -st tissue -vt mutation_occurrences +``` + +```python +# Python +gget.cbio_search(["esophag", "ovary"]) +gget.cbio_plot(["msk_impact_2017"], ["AKT1", "ALK"], stratification="tissue") +``` + +#### gget cosmic - COSMIC Database + +Search COSMIC (Catalogue Of Somatic Mutations In Cancer) database. + +**Important**: License fees apply for commercial use. Requires COSMIC account credentials. + +**Parameters**: +- `searchterm`: Gene name, Ensembl ID, mutation notation, or sample ID +- `-ctp/--cosmic_tsv_path`: Path to downloaded COSMIC TSV file (required for querying) +- `-l/--limit`: Maximum results (default: 100) + +**Database download flags**: +- `-d/--download_cosmic`: Activate download mode +- `-gm/--gget_mutate`: Create version for gget mutate +- `-cp/--cosmic_project`: Database type (cancer, census, cell_line, resistance, genome_screen, targeted_screen) +- `-cv/--cosmic_version`: COSMIC version +- `-gv/--grch_version`: Human reference genome (37 or 38) +- `--email`, `--password`: COSMIC credentials + +**Examples**: +```bash +# First download database +gget cosmic -d --email user@example.com --password xxx -cp cancer + +# Then query +gget cosmic EGFR -ctp cosmic_data.tsv -l 10 +``` + +```python +# Python +gget.cosmic("EGFR", cosmic_tsv_path="cosmic_data.tsv", limit=10) +``` + +### 5. Additional Tools + +#### gget mutate - Generate Mutated Sequences + +Generate mutated nucleotide sequences from mutation annotations. 
+ +**Parameters**: +- `sequences`: FASTA file path or direct sequence input (string/list) +- `-m/--mutations`: CSV/TSV file or DataFrame with mutation data (required) +- `-mc/--mut_column`: Mutation column name (default: 'mutation') +- `-sic/--seq_id_column`: Sequence ID column (default: 'seq_ID') +- `-mic/--mut_id_column`: Mutation ID column +- `-k/--k`: Length of flanking sequences (default: 30 nucleotides) + +**Returns**: Mutated sequences in FASTA format + +**Examples**: +```bash +# Single mutation +gget mutate ATCGCTAAGCT -m "c.4G>T" + +# Multiple sequences with mutations from file +gget mutate sequences.fasta -m mutations.csv -o mutated.fasta +``` + +```python +# Python +import pandas as pd +mutations_df = pd.DataFrame({"seq_ID": ["seq1"], "mutation": ["c.4G>T"]}) +gget.mutate(["ATCGCTAAGCT"], mutations=mutations_df) +``` + +#### gget gpt - OpenAI Text Generation + +Generate natural language text using OpenAI's API. + +**Setup Required**: +```bash +gget setup gpt +``` + +**Important**: Free tier limited to 3 months after account creation. Set monthly billing limits. + +**Parameters**: +- `prompt`: Text input for generation (required) +- `api_key`: OpenAI authentication (required) +- Model configuration: temperature, top_p, max_tokens, frequency_penalty, presence_penalty +- Default model: gpt-3.5-turbo (configurable) + +**Examples**: +```bash +gget gpt "Explain CRISPR" --api_key your_key_here +``` + +```python +# Python +gget.gpt("Explain CRISPR", api_key="your_key_here") +``` + +#### gget setup - Install Dependencies + +Install/download third-party dependencies for specific modules. + +**Parameters**: +- `module`: Module name requiring dependency installation +- `-o/--out`: Output folder path (elm module only) + +**Modules requiring setup**: +- `alphafold` - Downloads ~4GB of model parameters +- `cellxgene` - Installs cellxgene-census (may not support latest Python) +- `elm` - Downloads local ELM database +- `gpt` - Configures OpenAI integration + +**Examples**: +```bash +# Setup AlphaFold +gget setup alphafold + +# Setup ELM with custom directory +gget setup elm -o /path/to/elm_data +``` + +```python +# Python +gget.setup("alphafold") +``` + +## Common Workflows + +### Workflow 1: Gene Discovery to Sequence Analysis + +Find and analyze genes of interest: + +```python +# 1. Search for genes +results = gget.search(["GABA", "receptor"], species="homo_sapiens") + +# 2. Get detailed information +gene_ids = results["ensembl_id"].tolist() +info = gget.info(gene_ids[:5]) + +# 3. Retrieve sequences +sequences = gget.seq(gene_ids[:5], translate=True) +``` + +### Workflow 2: Sequence Alignment and Structure + +Align sequences and predict structures: + +```python +# 1. Align multiple sequences +alignment = gget.muscle("sequences.fasta") + +# 2. Find similar sequences +blast_results = gget.blast(my_sequence, database="swissprot", limit=10) + +# 3. Predict structure +structure = gget.alphafold(my_sequence, plot=True) + +# 4. Find linear motifs +ortholog_df, regex_df = gget.elm(my_sequence) +``` + +### Workflow 3: Gene Expression and Enrichment + +Analyze expression patterns and functional enrichment: + +```python +# 1. Get tissue expression +tissue_expr = gget.archs4("ACE2", which="tissue") + +# 2. Find correlated genes +correlated = gget.archs4("ACE2", which="correlation") + +# 3. Get single-cell data +adata = gget.cellxgene(gene=["ACE2"], tissue="lung", cell_type="epithelial cell") + +# 4. 
Perform enrichment analysis +gene_list = correlated["gene_symbol"].tolist()[:50] +enrichment = gget.enrichr(gene_list, database="ontology", plot=True) +``` + +### Workflow 4: Disease and Drug Analysis + +Investigate disease associations and therapeutic targets: + +```python +# 1. Search for genes +genes = gget.search(["breast cancer"], species="homo_sapiens") + +# 2. Get disease associations +diseases = gget.opentargets("ENSG00000169194", resource="diseases") + +# 3. Get drug associations +drugs = gget.opentargets("ENSG00000169194", resource="drugs") + +# 4. Query cancer genomics data +study_ids = gget.cbio_search(["breast"]) +gget.cbio_plot(study_ids[:2], ["BRCA1", "BRCA2"], stratification="cancer_type") + +# 5. Search COSMIC for mutations +cosmic_results = gget.cosmic("BRCA1", cosmic_tsv_path="cosmic.tsv") +``` + +### Workflow 5: Comparative Genomics + +Compare proteins across species: + +```python +# 1. Get orthologs +orthologs = gget.bgee("ENSG00000169194", type="orthologs") + +# 2. Get sequences for comparison +human_seq = gget.seq("ENSG00000169194", translate=True) +mouse_seq = gget.seq("ENSMUSG00000026091", translate=True) + +# 3. Align sequences +alignment = gget.muscle([human_seq, mouse_seq]) + +# 4. Compare structures +human_structure = gget.pdb("7S7U") +mouse_structure = gget.alphafold(mouse_seq) +``` + +### Workflow 6: Building Reference Indices + +Prepare reference data for downstream analysis (e.g., kallisto|bustools): + +```bash +# 1. List available species +gget ref --list_species + +# 2. Download reference files +gget ref -w gtf -w cdna -d homo_sapiens + +# 3. Build kallisto index +kallisto index -i transcriptome.idx transcriptome.fasta + +# 4. Download genome for alignment +gget ref -w dna -d homo_sapiens +``` + +## Best Practices + +### Data Retrieval +- Use `--limit` to control result sizes for large queries +- Save results with `-o/--out` for reproducibility +- Check database versions/releases for consistency across analyses +- Use `--quiet` in production scripts to reduce output + +### Sequence Analysis +- For BLAST/BLAT, start with default parameters, then adjust sensitivity +- Use `gget diamond` with `--threads` for faster local alignment +- Save DIAMOND databases with `--diamond_db` for repeated queries +- For multiple sequence alignment, use `-s5/--super5` for large datasets + +### Expression and Disease Data +- Gene symbols are case-sensitive in cellxgene (e.g., 'PAX7' vs 'Pax7') +- Run `gget setup` before first use of alphafold, cellxgene, elm, gpt +- For enrichment analysis, use database shortcuts for convenience +- Cache cBioPortal data with `-dd` to avoid repeated downloads + +### Structure Prediction +- AlphaFold multimer predictions: use `-mr 20` for higher accuracy +- Use `-r` flag for AMBER relaxation of final structures +- Visualize results in Python with `plot=True` +- Check PDB database first before running AlphaFold predictions + +### Error Handling +- Database structures change; update gget regularly: `pip install --upgrade gget` +- Process max ~1000 Ensembl IDs at once with gget info +- For large-scale analyses, implement rate limiting for API queries +- Use virtual environments to avoid dependency conflicts + +## Output Formats + +### Command-line +- Default: JSON +- CSV: Add `-csv` flag +- FASTA: gget seq, gget mutate +- PDB: gget pdb, gget alphafold +- PNG: gget cbio plot + +### Python +- Default: DataFrame or dictionary +- JSON: Add `json=True` parameter +- Save to file: Add `save=True` or specify `out="filename"` +- AnnData: gget cellxgene + 
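+
+The Python-side options above can be combined per call. Below is a minimal sketch using `gget.search` as the example module; the `json`, `out`, and `save` parameter names are taken from the lists above, and exact support can vary by module, so treat this as illustrative rather than definitive:
+
+```python
+import gget
+
+# Default Python return type: a pandas DataFrame
+df = gget.search(["ACE2"], species="homo_sapiens")
+
+# JSON-style output (dict/list) instead of a DataFrame
+results_json = gget.search(["ACE2"], species="homo_sapiens", json=True)
+
+# Write results to a file by giving an output path (filename is illustrative)
+gget.search(["ACE2"], species="homo_sapiens", out="ace2_search.json")
+```
+
+Modules that return sequences or structures (e.g., `gget seq`, `gget pdb`) write FASTA or PDB output instead, as listed above.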
+## Resources + +This skill includes reference documentation for detailed module information: + +### references/ +- `module_reference.md` - Comprehensive parameter reference for all modules +- `database_info.md` - Information about queried databases and their update frequencies +- `workflows.md` - Extended workflow examples and use cases + +For additional help: +- Official documentation: https://pachterlab.github.io/gget/ +- GitHub issues: https://github.com/pachterlab/gget/issues +- Citation: Luebbert, L. & Pachter, L. (2023). Efficient querying of genomic reference databases with gget. Bioinformatics. https://doi.org/10.1093/bioinformatics/btac836 diff --git a/scientific-packages/gget/references/database_info.md b/scientific-packages/gget/references/database_info.md new file mode 100644 index 0000000..54bc48a --- /dev/null +++ b/scientific-packages/gget/references/database_info.md @@ -0,0 +1,300 @@ +# gget Database Information + +Overview of databases queried by gget modules, including update frequencies and important considerations. + +## Important Note + +The databases queried by gget are continuously being updated, which sometimes changes their structure. gget modules are tested automatically on a biweekly basis and updated to match new database structures when necessary. Always keep gget updated: + +```bash +pip install --upgrade gget +``` + +## Database Directory + +### Genomic Reference Databases + +#### Ensembl +- **Used by:** gget ref, gget search, gget info, gget seq +- **Description:** Comprehensive genome database with annotations for vertebrate and invertebrate species +- **Update frequency:** Regular releases (numbered); new releases approximately every 3 months +- **Access:** FTP downloads, REST API +- **Website:** https://www.ensembl.org/ +- **Notes:** + - Supports both vertebrate and invertebrate genomes + - Can specify release number for reproducibility + - Shortcuts available for common species ('human', 'mouse') + +#### UCSC Genome Browser +- **Used by:** gget blat +- **Description:** Genome browser database with BLAT alignment tool +- **Update frequency:** Regular updates with new assemblies +- **Access:** Web service API +- **Website:** https://genome.ucsc.edu/ +- **Notes:** + - Multiple genome assemblies available (hg38, mm39, etc.) 
+ - BLAT optimized for vertebrate genomes + +### Protein & Structure Databases + +#### UniProt +- **Used by:** gget info, gget seq (amino acid sequences), gget elm +- **Description:** Universal Protein Resource, comprehensive protein sequence and functional information +- **Update frequency:** Regular releases (weekly for Swiss-Prot, monthly for TrEMBL) +- **Access:** REST API +- **Website:** https://www.uniprot.org/ +- **Notes:** + - Swiss-Prot: manually annotated and reviewed + - TrEMBL: automatically annotated + +#### NCBI (National Center for Biotechnology Information) +- **Used by:** gget info, gget bgee (for non-Ensembl species) +- **Description:** Gene and protein databases with extensive cross-references +- **Update frequency:** Continuous updates +- **Access:** E-utilities API +- **Website:** https://www.ncbi.nlm.nih.gov/ +- **Databases:** Gene, Protein, RefSeq + +#### RCSB PDB (Protein Data Bank) +- **Used by:** gget pdb +- **Description:** Repository of 3D structural data for proteins and nucleic acids +- **Update frequency:** Weekly updates +- **Access:** REST API +- **Website:** https://www.rcsb.org/ +- **Notes:** + - Experimentally determined structures (X-ray, NMR, cryo-EM) + - Includes metadata about experiments and publications + +#### ELM (Eukaryotic Linear Motif) +- **Used by:** gget elm +- **Description:** Database of functional sites in eukaryotic proteins +- **Update frequency:** Periodic updates +- **Access:** Downloaded database (via gget setup elm) +- **Website:** http://elm.eu.org/ +- **Notes:** + - Requires local download before first use + - Contains validated motifs and patterns + +### Sequence Similarity Databases + +#### BLAST Databases (NCBI) +- **Used by:** gget blast +- **Description:** Pre-formatted databases for BLAST searches +- **Update frequency:** Regular updates +- **Access:** NCBI BLAST API +- **Databases:** + - **Nucleotide:** nt (all GenBank), refseq_rna, pdbnt + - **Protein:** nr (non-redundant), swissprot, pdbaa, refseq_protein +- **Notes:** + - nt and nr are very large databases + - Consider specialized databases for faster, more focused searches + +### Expression & Correlation Databases + +#### ARCHS4 +- **Used by:** gget archs4 +- **Description:** Massive mining of publicly available RNA-seq data +- **Update frequency:** Periodic updates with new samples +- **Access:** HTTP API +- **Website:** https://maayanlab.cloud/archs4/ +- **Data:** + - Human and mouse RNA-seq data + - Correlation matrices + - Tissue expression atlases +- **Citation:** Lachmann et al., Nature Communications, 2018 + +#### CZ CELLxGENE Discover +- **Used by:** gget cellxgene +- **Description:** Single-cell RNA-seq data from multiple studies +- **Update frequency:** Continuous additions of new datasets +- **Access:** Census API (via cellxgene-census package) +- **Website:** https://cellxgene.cziscience.com/ +- **Data:** + - Single-cell RNA-seq count matrices + - Cell type annotations + - Tissue and disease metadata +- **Notes:** + - Requires gget setup cellxgene + - Gene symbols are case-sensitive + - May not support latest Python versions + +#### Bgee +- **Used by:** gget bgee +- **Description:** Gene expression and orthology database +- **Update frequency:** Regular releases +- **Access:** REST API +- **Website:** https://www.bgee.org/ +- **Data:** + - Gene expression across tissues and developmental stages + - Orthology relationships across species +- **Citation:** Bastian et al., 2021 + +### Functional & Pathway Databases + +#### Enrichr / modEnrichr +- **Used by:** 
gget enrichr +- **Description:** Gene set enrichment analysis web service +- **Update frequency:** Regular updates to underlying databases +- **Access:** REST API +- **Website:** https://maayanlab.cloud/Enrichr/ +- **Databases included:** + - KEGG pathways + - Gene Ontology (GO) + - Transcription factor targets (ChEA) + - Disease associations (GWAS Catalog) + - Cell type markers (PanglaoDB) +- **Notes:** + - Supports multiple model organisms + - Background gene lists can be provided for custom enrichment + +### Disease & Drug Databases + +#### Open Targets +- **Used by:** gget opentargets +- **Description:** Integrative platform for disease-target associations +- **Update frequency:** Regular releases (quarterly) +- **Access:** GraphQL API +- **Website:** https://www.opentargets.org/ +- **Data:** + - Disease associations + - Drug information and clinical trials + - Target tractability + - Pharmacogenetics + - Gene expression + - DepMap gene-disease effects + - Protein-protein interactions + +#### cBioPortal +- **Used by:** gget cbio +- **Description:** Cancer genomics data portal +- **Update frequency:** Continuous addition of new studies +- **Access:** Web API, downloadable datasets +- **Website:** https://www.cbioportal.org/ +- **Data:** + - Mutations, copy number alterations, structural variants + - Gene expression + - Clinical data +- **Notes:** + - Large datasets; caching recommended + - Multiple cancer types and studies available + +#### COSMIC (Catalogue Of Somatic Mutations In Cancer) +- **Used by:** gget cosmic +- **Description:** Comprehensive cancer mutation database +- **Update frequency:** Regular releases +- **Access:** Download (requires account and license for commercial use) +- **Website:** https://cancer.sanger.ac.uk/cosmic +- **Data:** + - Somatic mutations in cancer + - Gene census + - Cell line data + - Drug resistance mutations +- **Important:** + - Free for academic use + - License fees apply for commercial use + - Requires COSMIC account credentials + - Must download database before querying + +### AI & Prediction Services + +#### AlphaFold2 (DeepMind) +- **Used by:** gget alphafold +- **Description:** Deep learning model for protein structure prediction +- **Model version:** Simplified version for local execution +- **Access:** Local computation (requires model download via gget setup) +- **Website:** https://alphafold.ebi.ac.uk/ +- **Notes:** + - Requires ~4GB model parameters download + - Requires OpenMM installation + - Computationally intensive + - Python version-specific requirements + +#### OpenAI API +- **Used by:** gget gpt +- **Description:** Large language model API +- **Update frequency:** New models released periodically +- **Access:** REST API (requires API key) +- **Website:** https://openai.com/ +- **Notes:** + - Default model: gpt-3.5-turbo + - Free tier limited to 3 months after account creation + - Set billing limits to control costs + +## Data Consistency & Reproducibility + +### Version Control +To ensure reproducibility in analyses: + +1. **Specify database versions/releases:** + ```python + # Use specific Ensembl release + gget.ref("homo_sapiens", release=110) + + # Use specific Census version + gget.cellxgene(gene=["PAX7"], census_version="2023-07-25") + ``` + +2. **Document gget version:** + ```python + import gget + print(gget.__version__) + ``` + +3. 
**Save raw data:** + ```python + # Always save results for reproducibility + results = gget.search(["ACE2"], species="homo_sapiens") + results.to_csv("search_results_2025-01-15.csv", index=False) + ``` + +### Handling Database Updates + +1. **Regular gget updates:** + - Update gget biweekly to match database structure changes + - Check release notes for breaking changes + +2. **Error handling:** + - Database structure changes may cause temporary failures + - Check GitHub issues: https://github.com/pachterlab/gget/issues + - Update gget if errors occur + +3. **API rate limiting:** + - Implement delays for large-scale queries + - Use local databases (DIAMOND, COSMIC) when possible + - Cache results to avoid repeated queries + +## Database-Specific Best Practices + +### Ensembl +- Use species shortcuts ('human', 'mouse') for convenience +- Specify release numbers for reproducibility +- Check available species with `gget ref --list_species` + +### UniProt +- UniProt IDs are more stable than gene names +- Swiss-Prot annotations are manually curated and more reliable +- Use PDB flag in gget info only when needed (increases runtime) + +### BLAST/BLAT +- Start with default parameters, then optimize +- Use specialized databases (swissprot, refseq_protein) for focused searches +- Consider E-value cutoffs based on query length + +### Expression Databases +- Gene symbols are case-sensitive in CELLxGENE +- ARCHS4 correlation data is based on co-expression patterns +- Consider tissue-specificity when interpreting results + +### Cancer Databases +- cBioPortal: cache data locally for repeated analyses +- COSMIC: download appropriate database subset for your needs +- Respect license agreements for commercial use + +## Citations + +When using gget, cite both the gget publication and the underlying databases: + +**gget:** +Luebbert, L. & Pachter, L. (2023). Efficient querying of genomic reference databases with gget. Bioinformatics. https://doi.org/10.1093/bioinformatics/btac836 + +**Database-specific citations:** Check references/ directory or database websites for appropriate citations. diff --git a/scientific-packages/gget/references/module_reference.md b/scientific-packages/gget/references/module_reference.md new file mode 100644 index 0000000..9f466d3 --- /dev/null +++ b/scientific-packages/gget/references/module_reference.md @@ -0,0 +1,467 @@ +# gget Module Reference + +Comprehensive parameter reference for all gget modules. + +## Reference & Gene Information Modules + +### gget ref +Retrieve Ensembl reference genome FTPs and metadata. 
+ +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `species` | str | Species in Genus_species format or shortcuts ('human', 'mouse') | Required | +| `-w/--which` | str | File types to return: gtf, cdna, dna, cds, cdrna, pep | All | +| `-r/--release` | int | Ensembl release number | Latest | +| `-od/--out_dir` | str | Output directory path | None | +| `-o/--out` | str | JSON file path for results | None | +| `-l/--list_species` | flag | List available vertebrate species | False | +| `-liv/--list_iv_species` | flag | List available invertebrate species | False | +| `-ftp` | flag | Return only FTP links | False | +| `-d/--download` | flag | Download files (requires curl) | False | +| `-q/--quiet` | flag | Suppress progress information | False | + +**Returns:** JSON containing FTP links, Ensembl release numbers, release dates, file sizes + +--- + +### gget search +Search for genes by name or description in Ensembl. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `searchwords` | str/list | Search terms (case-insensitive) | Required | +| `-s/--species` | str | Target species or core database name | Required | +| `-r/--release` | int | Ensembl release number | Latest | +| `-t/--id_type` | str | Return 'gene' or 'transcript' | 'gene' | +| `-ao/--andor` | str | 'or' (ANY term) or 'and' (ALL terms) | 'or' | +| `-l/--limit` | int | Maximum results to return | None | +| `-o/--out` | str | Output file path (CSV/JSON) | None | + +**Returns:** ensembl_id, gene_name, ensembl_description, ext_ref_description, biotype, URL + +--- + +### gget info +Get comprehensive gene/transcript metadata from Ensembl, UniProt, and NCBI. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `ens_ids` | str/list | Ensembl IDs (WormBase, Flybase also supported) | Required | +| `-o/--out` | str | Output file path (CSV/JSON) | None | +| `-n/--ncbi` | bool | Disable NCBI data retrieval | False | +| `-u/--uniprot` | bool | Disable UniProt data retrieval | False | +| `-pdb` | bool | Include PDB identifiers | False | +| `-csv` | flag | Return CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress display | False | + +**Python-specific:** +- `save=True`: Save output to current directory +- `wrap_text=True`: Format dataframe with wrapped text + +**Note:** Processing >1000 IDs simultaneously may cause server errors. + +**Returns:** UniProt ID, NCBI gene ID, gene name, synonyms, protein names, descriptions, biotype, canonical transcript + +--- + +### gget seq +Retrieve nucleotide or amino acid sequences in FASTA format. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `ens_ids` | str/list | Ensembl identifiers | Required | +| `-o/--out` | str | Output file path | stdout | +| `-t/--translate` | flag | Fetch amino acid sequences | False | +| `-iso/--isoforms` | flag | Return all transcript variants | False | +| `-q/--quiet` | flag | Suppress progress information | False | + +**Data sources:** Ensembl (nucleotide), UniProt (amino acid) + +**Returns:** FASTA format sequences + +--- + +## Sequence Analysis & Alignment Modules + +### gget blast +BLAST sequences against standard databases. 
+ +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `sequence` | str | Sequence or path to FASTA/.txt | Required | +| `-p/--program` | str | blastn, blastp, blastx, tblastn, tblastx | Auto-detect | +| `-db/--database` | str | nt, refseq_rna, pdbnt, nr, swissprot, pdbaa, refseq_protein | nt or nr | +| `-l/--limit` | int | Max hits returned | 50 | +| `-e/--expect` | float | E-value cutoff | 10.0 | +| `-lcf/--low_comp_filt` | flag | Enable low complexity filtering | False | +| `-mbo/--megablast_off` | flag | Disable MegaBLAST (blastn only) | False | +| `-o/--out` | str | Output file path | None | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** Description, Scientific Name, Common Name, Taxid, Max Score, Total Score, Query Coverage + +--- + +### gget blat +Find genomic positions using UCSC BLAT. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `sequence` | str | Sequence or path to FASTA/.txt | Required | +| `-st/--seqtype` | str | 'DNA', 'protein', 'translated%20RNA', 'translated%20DNA' | Auto-detect | +| `-a/--assembly` | str | Target assembly (hg38, mm39, taeGut2, etc.) | 'human'/hg38 | +| `-o/--out` | str | Output file path | None | +| `-csv` | flag | Return CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** genome, query size, alignment start/end, matches, mismatches, alignment percentage + +--- + +### gget muscle +Align multiple sequences using Muscle5. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `fasta` | str/list | Sequences or FASTA file path | Required | +| `-o/--out` | str | Output file path | stdout | +| `-s5/--super5` | flag | Use Super5 algorithm (faster, large datasets) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** ClustalW format alignment or aligned FASTA (.afa) + +--- + +### gget diamond +Fast local protein/translated DNA alignment. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `query` | str/list | Query sequences or FASTA file | Required | +| `--reference` | str/list | Reference sequences or FASTA file | Required | +| `--sensitivity` | str | fast, mid-sensitive, sensitive, more-sensitive, very-sensitive, ultra-sensitive | very-sensitive | +| `--threads` | int | CPU threads | 1 | +| `--diamond_binary` | str | Path to DIAMOND installation | Auto-detect | +| `--diamond_db` | str | Save database for reuse | None | +| `--translated` | flag | Enable nucleotide-to-amino acid alignment | False | +| `-o/--out` | str | Output file path | None | +| `-csv` | flag | CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** Identity %, sequence lengths, match positions, gap openings, E-values, bit scores + +--- + +## Structural & Protein Analysis Modules + +### gget pdb +Query RCSB Protein Data Bank. 
+ +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `pdb_id` | str | PDB identifier (e.g., '7S7U') | Required | +| `-r/--resource` | str | pdb, entry, pubmed, assembly, entity types | 'pdb' | +| `-i/--identifier` | str | Assembly, entity, or chain ID | None | +| `-o/--out` | str | Output file path | stdout | + +**Returns:** PDB format (structures) or JSON (metadata) + +--- + +### gget alphafold +Predict 3D protein structures using AlphaFold2. + +**Setup:** Requires OpenMM and `gget setup alphafold` (~4GB download) + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `sequence` | str/list | Amino acid sequence(s) or FASTA file | Required | +| `-mr/--multimer_recycles` | int | Recycling iterations for multimers | 3 | +| `-o/--out` | str | Output folder path | timestamped | +| `-mfm/--multimer_for_monomer` | flag | Apply multimer model to monomers | False | +| `-r/--relax` | flag | AMBER relaxation for top model | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Python-only:** +- `plot` (bool): Generate 3D visualization (default: True) +- `show_sidechains` (bool): Include side chains (default: True) + +**Note:** Multiple sequences automatically trigger multimer modeling + +**Returns:** PDB structure file, JSON alignment error data, optional 3D plot + +--- + +### gget elm +Predict Eukaryotic Linear Motifs. + +**Setup:** Requires `gget setup elm` + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `sequence` | str | Amino acid sequence or UniProt Acc | Required | +| `-s/--sensitivity` | str | DIAMOND alignment sensitivity | very-sensitive | +| `-t/--threads` | int | Number of threads | 1 | +| `-bin/--diamond_binary` | str | Path to DIAMOND binary | Auto-detect | +| `-o/--out` | str | Output directory path | None | +| `-u/--uniprot` | flag | Input is UniProt Acc | False | +| `-e/--expand` | flag | Include protein names, organisms, references | False | +| `-csv` | flag | CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** Two outputs: +1. **ortholog_df**: Motifs from orthologous proteins +2. **regex_df**: Motifs matched in input sequence + +--- + +## Expression & Disease Data Modules + +### gget archs4 +Query ARCHS4 for gene correlation or tissue expression. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `gene` | str | Gene symbol or Ensembl ID | Required | +| `-w/--which` | str | 'correlation' or 'tissue' | 'correlation' | +| `-s/--species` | str | 'human' or 'mouse' (tissue only) | 'human' | +| `-o/--out` | str | Output file path | None | +| `-e/--ensembl` | flag | Input is Ensembl ID | False | +| `-csv` | flag | CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** +- **correlation**: Gene symbols, Pearson correlation coefficients (top 100) +- **tissue**: Tissue IDs, min/Q1/median/Q3/max expression + +--- + +### gget cellxgene +Query CZ CELLxGENE Discover Census for single-cell data. + +**Setup:** Requires `gget setup cellxgene` + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `--gene` (-g) | list | Gene names or Ensembl IDs (case-sensitive!) 
| Required | +| `--tissue` | list | Tissue type(s) | None | +| `--cell_type` | list | Cell type(s) | None | +| `--species` (-s) | str | 'homo_sapiens' or 'mus_musculus' | 'homo_sapiens' | +| `--census_version` (-cv) | str | "stable", "latest", or dated version | "stable" | +| `-o/--out` | str | Output file path (required for CLI) | Required | +| `--ensembl` (-e) | flag | Use Ensembl IDs | False | +| `--meta_only` (-mo) | flag | Return metadata only | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Additional filters:** disease, development_stage, sex, assay, dataset_id, donor_id, ethnicity, suspension_type + +**Important:** Gene symbols are case-sensitive ('PAX7' for human, 'Pax7' for mouse) + +**Returns:** AnnData object with count matrices and metadata + +--- + +### gget enrichr +Perform enrichment analysis using Enrichr/modEnrichr. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `genes` | list | Gene symbols or Ensembl IDs | Required | +| `-db/--database` | str | Reference database or shortcut | Required | +| `-s/--species` | str | human, mouse, fly, yeast, worm, fish | 'human' | +| `-bkg_l/--background_list` | list | Background genes | None | +| `-o/--out` | str | Output file path | None | +| `-ko/--kegg_out` | str | KEGG pathway images directory | None | + +**Python-only:** +- `plot` (bool): Generate graphical results + +**Database shortcuts:** +- 'pathway' → KEGG_2021_Human +- 'transcription' → ChEA_2016 +- 'ontology' → GO_Biological_Process_2021 +- 'diseases_drugs' → GWAS_Catalog_2019 +- 'celltypes' → PanglaoDB_Augmented_2021 + +**Returns:** Pathway/function associations with adjusted p-values, overlapping gene counts + +--- + +### gget bgee +Retrieve orthology and expression from Bgee. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `ens_id` | str/list | Ensembl or NCBI gene ID | Required | +| `-t/--type` | str | 'orthologs' or 'expression' | 'orthologs' | +| `-o/--out` | str | Output file path | None | +| `-csv` | flag | CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Note:** Multiple IDs supported when `type='expression'` + +**Returns:** +- **orthologs**: Genes across species with IDs, names, taxonomic info +- **expression**: Anatomical entities, confidence scores, expression status + +--- + +### gget opentargets +Retrieve disease/drug associations from OpenTargets. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `ens_id` | str | Ensembl gene ID | Required | +| `-r/--resource` | str | diseases, drugs, tractability, pharmacogenetics, expression, depmap, interactions | 'diseases' | +| `-l/--limit` | int | Maximum results | None | +| `-o/--out` | str | Output file path | None | +| `-csv` | flag | CSV format (CLI) | False | +| `-q/--quiet` | flag | Suppress progress | False | + +**Resource-specific filters:** +- drugs: `--filter_disease` +- pharmacogenetics: `--filter_drug` +- expression/depmap: `--filter_tissue`, `--filter_anat_sys`, `--filter_organ` +- interactions: `--filter_protein_a`, `--filter_protein_b`, `--filter_gene_b` + +**Returns:** Disease/drug associations, tractability, pharmacogenetics, expression, DepMap, interactions + +--- + +### gget cbio +Plot cancer genomics heatmaps from cBioPortal. 
+ +**Subcommands:** search, plot + +**search parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `keywords` | list | Search terms | Required | + +**plot parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `-s/--study_ids` | list | cBioPortal study IDs | Required | +| `-g/--genes` | list | Gene names or Ensembl IDs | Required | +| `-st/--stratification` | str | tissue, cancer_type, cancer_type_detailed, study_id, sample | None | +| `-vt/--variation_type` | str | mutation_occurrences, cna_nonbinary, sv_occurrences, cna_occurrences, Consequence | None | +| `-f/--filter` | str | Filter by column value (e.g., 'study_id:msk_impact_2017') | None | +| `-dd/--data_dir` | str | Cache directory | ./gget_cbio_cache | +| `-fd/--figure_dir` | str | Output directory | ./gget_cbio_figures | +| `-t/--title` | str | Custom figure title | None | +| `-dpi` | int | Resolution | 100 | +| `-q/--quiet` | flag | Suppress progress | False | +| `-nc/--no_confirm` | flag | Skip download confirmations | False | +| `-sh/--show` | flag | Display plot in window | False | + +**Returns:** PNG heatmap figure + +--- + +### gget cosmic +Search COSMIC database for cancer mutations. + +**Important:** License fees for commercial use. Requires COSMIC account. + +**Query parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `searchterm` | str | Gene name, Ensembl ID, mutation, sample ID | Required | +| `-ctp/--cosmic_tsv_path` | str | Path to COSMIC TSV file | Required | +| `-l/--limit` | int | Maximum results | 100 | +| `-csv` | flag | CSV format (CLI) | False | + +**Download parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `-d/--download_cosmic` | flag | Activate download mode | False | +| `-gm/--gget_mutate` | flag | Create version for gget mutate | False | +| `-cp/--cosmic_project` | str | cancer, census, cell_line, resistance, genome_screen, targeted_screen | None | +| `-cv/--cosmic_version` | str | COSMIC version | Latest | +| `-gv/--grch_version` | int | Human reference genome (37 or 38) | None | +| `--email` | str | COSMIC account email | Required | +| `--password` | str | COSMIC account password | Required | + +**Note:** First-time users must download database + +**Returns:** Mutation data from COSMIC + +--- + +## Additional Tools + +### gget mutate +Generate mutated nucleotide sequences. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `sequences` | str/list | FASTA file or sequences | Required | +| `-m/--mutations` | str/df | CSV/TSV file or DataFrame | Required | +| `-mc/--mut_column` | str | Mutation column name | 'mutation' | +| `-sic/--seq_id_column` | str | Sequence ID column | 'seq_ID' | +| `-mic/--mut_id_column` | str | Mutation ID column | None | +| `-k/--k` | int | Length of flanking sequences | 30 | +| `-o/--out` | str | Output FASTA file path | stdout | +| `-q/--quiet` | flag | Suppress progress | False | + +**Returns:** Mutated sequences in FASTA format + +--- + +### gget gpt +Generate text using OpenAI's API. 
+ +**Setup:** Requires `gget setup gpt` and OpenAI API key + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `prompt` | str | Text input for generation | Required | +| `api_key` | str | OpenAI API key | Required | +| `model` | str | OpenAI model name | gpt-3.5-turbo | +| `temperature` | float | Sampling temperature (0-2) | 1.0 | +| `top_p` | float | Nucleus sampling | 1.0 | +| `max_tokens` | int | Maximum tokens to generate | None | +| `frequency_penalty` | float | Frequency penalty (0-2) | 0 | +| `presence_penalty` | float | Presence penalty (0-2) | 0 | + +**Important:** Free tier limited to 3 months. Set billing limits. + +**Returns:** Generated text string + +--- + +### gget setup +Install/download dependencies for modules. + +**Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `module` | str | Module name | Required | +| `-o/--out` | str | Output folder (elm only) | Package install folder | +| `-q/--quiet` | flag | Suppress progress | False | + +**Modules requiring setup:** +- `alphafold` - Downloads ~4GB model parameters +- `cellxgene` - Installs cellxgene-census +- `elm` - Downloads local ELM database +- `gpt` - Configures OpenAI integration + +**Returns:** None (installs dependencies) diff --git a/scientific-packages/gget/references/workflows.md b/scientific-packages/gget/references/workflows.md new file mode 100644 index 0000000..487fc62 --- /dev/null +++ b/scientific-packages/gget/references/workflows.md @@ -0,0 +1,814 @@ +# gget Workflow Examples + +Extended workflow examples demonstrating how to combine multiple gget modules for common bioinformatics tasks. + +## Table of Contents +1. [Complete Gene Analysis Pipeline](#complete-gene-analysis-pipeline) +2. [Comparative Structural Biology](#comparative-structural-biology) +3. [Cancer Genomics Analysis](#cancer-genomics-analysis) +4. [Single-Cell Expression Analysis](#single-cell-expression-analysis) +5. [Building Reference Transcriptomes](#building-reference-transcriptomes) +6. [Mutation Impact Assessment](#mutation-impact-assessment) +7. [Drug Target Discovery](#drug-target-discovery) + +--- + +## Complete Gene Analysis Pipeline + +Comprehensive analysis of a gene from discovery to functional annotation. 
+ +```python +import gget +import pandas as pd + +# Step 1: Search for genes of interest +print("Step 1: Searching for GABA receptor genes...") +search_results = gget.search(["GABA", "receptor", "alpha"], + species="homo_sapiens", + andor="and") +print(f"Found {len(search_results)} genes") + +# Step 2: Get detailed information +print("\nStep 2: Getting detailed information...") +gene_ids = search_results["ensembl_id"].tolist()[:5] # Top 5 genes +gene_info = gget.info(gene_ids, pdb=True) +print(gene_info[["ensembl_id", "gene_name", "uniprot_id", "description"]]) + +# Step 3: Retrieve sequences +print("\nStep 3: Retrieving sequences...") +nucleotide_seqs = gget.seq(gene_ids) +protein_seqs = gget.seq(gene_ids, translate=True) + +# Save sequences +with open("gaba_receptors_nt.fasta", "w") as f: + f.write(nucleotide_seqs) +with open("gaba_receptors_aa.fasta", "w") as f: + f.write(protein_seqs) + +# Step 4: Get expression data +print("\nStep 4: Getting tissue expression...") +for gene_id, gene_name in zip(gene_ids, gene_info["gene_name"]): + expr_data = gget.archs4(gene_name, which="tissue") + print(f"\n{gene_name} expression:") + print(expr_data.head()) + +# Step 5: Find correlated genes +print("\nStep 5: Finding correlated genes...") +correlated = gget.archs4(gene_info["gene_name"].iloc[0], which="correlation") +correlated_top = correlated.head(20) +print(correlated_top) + +# Step 6: Enrichment analysis on correlated genes +print("\nStep 6: Performing enrichment analysis...") +gene_list = correlated_top["gene_symbol"].tolist() +enrichment = gget.enrichr(gene_list, database="ontology", plot=True) +print(enrichment.head(10)) + +# Step 7: Get disease associations +print("\nStep 7: Getting disease associations...") +for gene_id, gene_name in zip(gene_ids[:3], gene_info["gene_name"][:3]): + diseases = gget.opentargets(gene_id, resource="diseases", limit=5) + print(f"\n{gene_name} disease associations:") + print(diseases) + +# Step 8: Check for orthologs +print("\nStep 8: Finding orthologs...") +orthologs = gget.bgee(gene_ids[0], type="orthologs") +print(orthologs) + +print("\nComplete gene analysis pipeline finished!") +``` + +--- + +## Comparative Structural Biology + +Compare protein structures across species and analyze functional motifs. + +```python +import gget + +# Define genes for comparison +human_gene = "ENSG00000169174" # PCSK9 +mouse_gene = "ENSMUSG00000044254" # Pcsk9 + +print("Comparative Structural Biology Workflow") +print("=" * 50) + +# Step 1: Get gene information +print("\n1. Getting gene information...") +human_info = gget.info([human_gene]) +mouse_info = gget.info([mouse_gene]) + +print(f"Human: {human_info['gene_name'].iloc[0]}") +print(f"Mouse: {mouse_info['gene_name'].iloc[0]}") + +# Step 2: Retrieve protein sequences +print("\n2. Retrieving protein sequences...") +human_seq = gget.seq(human_gene, translate=True) +mouse_seq = gget.seq(mouse_gene, translate=True) + +# Save to file for alignment +with open("pcsk9_sequences.fasta", "w") as f: + f.write(human_seq) + f.write("\n") + f.write(mouse_seq) + +# Step 3: Align sequences +print("\n3. Aligning sequences...") +alignment = gget.muscle("pcsk9_sequences.fasta") +print("Alignment completed. Visualizing in ClustalW format:") +print(alignment) + +# Step 4: Get existing structures from PDB +print("\n4. 
Searching PDB for existing structures...") +# Search by sequence using BLAST +pdb_results = gget.blast(human_seq, database="pdbaa", limit=5) +print("Top PDB matches:") +print(pdb_results[["Description", "Max Score", "Query Coverage"]]) + +# Download top structure +if len(pdb_results) > 0: + # Extract PDB ID from description (usually format: "PDB|XXXX|...") + pdb_id = pdb_results.iloc[0]["Description"].split("|")[1] + print(f"\nDownloading PDB structure: {pdb_id}") + gget.pdb(pdb_id, save=True) + +# Step 5: Predict AlphaFold structures +print("\n5. Predicting structures with AlphaFold...") +# Note: This requires gget setup alphafold and is computationally intensive +# Uncomment to run: +# human_structure = gget.alphafold(human_seq, plot=True) +# mouse_structure = gget.alphafold(mouse_seq, plot=True) +print("(AlphaFold prediction skipped - uncomment to run)") + +# Step 6: Identify functional motifs +print("\n6. Identifying functional motifs with ELM...") +# Note: Requires gget setup elm +# Uncomment to run: +# human_ortholog_df, human_regex_df = gget.elm(human_seq) +# print("Human PCSK9 functional motifs:") +# print(human_regex_df) +print("(ELM analysis skipped - uncomment to run)") + +# Step 7: Get orthology information +print("\n7. Getting orthology information from Bgee...") +orthologs = gget.bgee(human_gene, type="orthologs") +print("PCSK9 orthologs:") +print(orthologs) + +print("\nComparative structural biology workflow completed!") +``` + +--- + +## Cancer Genomics Analysis + +Analyze cancer-associated genes and their mutations. + +```python +import gget +import matplotlib.pyplot as plt + +print("Cancer Genomics Analysis Workflow") +print("=" * 50) + +# Step 1: Search for cancer-related genes +print("\n1. Searching for breast cancer genes...") +genes = gget.search(["breast", "cancer", "BRCA"], + species="homo_sapiens", + andor="or", + limit=20) +print(f"Found {len(genes)} genes") + +# Focus on specific genes +target_genes = ["BRCA1", "BRCA2", "TP53", "PIK3CA", "ESR1"] +print(f"\nAnalyzing: {', '.join(target_genes)}") + +# Step 2: Get gene information +print("\n2. Getting gene information...") +gene_search = [] +for gene in target_genes: + result = gget.search([gene], species="homo_sapiens", limit=1) + if len(result) > 0: + gene_search.append(result.iloc[0]) + +gene_df = pd.DataFrame(gene_search) +gene_ids = gene_df["ensembl_id"].tolist() + +# Step 3: Get disease associations +print("\n3. Getting disease associations from OpenTargets...") +for gene_id, gene_name in zip(gene_ids, target_genes): + print(f"\n{gene_name} disease associations:") + diseases = gget.opentargets(gene_id, resource="diseases", limit=3) + print(diseases[["disease_name", "overall_score"]]) + +# Step 4: Get drug associations +print("\n4. Getting drug associations...") +for gene_id, gene_name in zip(gene_ids[:3], target_genes[:3]): + print(f"\n{gene_name} drug associations:") + drugs = gget.opentargets(gene_id, resource="drugs", limit=3) + if len(drugs) > 0: + print(drugs[["drug_name", "drug_type", "max_phase_for_all_diseases"]]) + +# Step 5: Search cBioPortal for studies +print("\n5. Searching cBioPortal for breast cancer studies...") +studies = gget.cbio_search(["breast", "cancer"]) +print(f"Found {len(studies)} studies") +print(studies[:5]) + +# Step 6: Create cancer genomics heatmap +print("\n6. 
Creating cancer genomics heatmap...") +if len(studies) > 0: + # Select relevant studies + selected_studies = studies[:2] # Top 2 studies + + gget.cbio_plot( + selected_studies, + target_genes, + stratification="cancer_type", + variation_type="mutation_occurrences", + show=False + ) + print("Heatmap saved to ./gget_cbio_figures/") + +# Step 7: Query COSMIC database (requires setup) +print("\n7. Querying COSMIC database...") +# Note: Requires COSMIC account and database download +# Uncomment to run: +# for gene in target_genes[:2]: +# cosmic_results = gget.cosmic( +# gene, +# cosmic_tsv_path="cosmic_cancer.tsv", +# limit=10 +# ) +# print(f"\n{gene} mutations in COSMIC:") +# print(cosmic_results) +print("(COSMIC query skipped - requires database download)") + +# Step 8: Enrichment analysis +print("\n8. Performing pathway enrichment...") +enrichment = gget.enrichr(target_genes, database="pathway", plot=True) +print("\nTop enriched pathways:") +print(enrichment.head(10)) + +print("\nCancer genomics analysis completed!") +``` + +--- + +## Single-Cell Expression Analysis + +Analyze single-cell RNA-seq data for specific cell types and tissues. + +```python +import gget +import scanpy as sc + +print("Single-Cell Expression Analysis Workflow") +print("=" * 50) + +# Note: Requires gget setup cellxgene + +# Step 1: Define genes and cell types of interest +genes_of_interest = ["ACE2", "TMPRSS2", "CD4", "CD8A"] +tissue = "lung" +cell_types = ["type ii pneumocyte", "macrophage", "t cell"] + +print(f"\nAnalyzing genes: {', '.join(genes_of_interest)}") +print(f"Tissue: {tissue}") +print(f"Cell types: {', '.join(cell_types)}") + +# Step 2: Get metadata first +print("\n1. Retrieving metadata...") +metadata = gget.cellxgene( + gene=genes_of_interest, + tissue=tissue, + species="homo_sapiens", + meta_only=True +) +print(f"Found {len(metadata)} datasets") +print(metadata.head()) + +# Step 3: Download count matrices +print("\n2. Downloading single-cell data...") +# Note: This can be a large download +adata = gget.cellxgene( + gene=genes_of_interest, + tissue=tissue, + species="homo_sapiens", + census_version="stable" +) +print(f"AnnData shape: {adata.shape}") +print(f"Genes: {adata.n_vars}") +print(f"Cells: {adata.n_obs}") + +# Step 4: Basic QC and filtering with scanpy +print("\n3. Performing quality control...") +sc.pp.filter_cells(adata, min_genes=200) +sc.pp.filter_genes(adata, min_cells=3) +print(f"After QC - Cells: {adata.n_obs}, Genes: {adata.n_vars}") + +# Step 5: Normalize and log-transform +print("\n4. Normalizing data...") +sc.pp.normalize_total(adata, target_sum=1e4) +sc.pp.log1p(adata) + +# Step 6: Calculate gene expression statistics +print("\n5. Calculating expression statistics...") +for gene in genes_of_interest: + if gene in adata.var_names: + expr = adata[:, gene].X.toarray().flatten() + print(f"\n{gene} expression:") + print(f" Mean: {expr.mean():.3f}") + print(f" Median: {np.median(expr):.3f}") + print(f" % expressing: {(expr > 0).sum() / len(expr) * 100:.1f}%") + +# Step 7: Get tissue expression from ARCHS4 for comparison +print("\n6. Getting bulk tissue expression from ARCHS4...") +for gene in genes_of_interest: + tissue_expr = gget.archs4(gene, which="tissue") + lung_expr = tissue_expr[tissue_expr["tissue"] == "lung"] + if len(lung_expr) > 0: + print(f"\n{gene} in lung (ARCHS4):") + print(f" Median: {lung_expr['median'].iloc[0]:.3f}") + +# Step 8: Enrichment analysis +print("\n7. 
Performing enrichment analysis...") +enrichment = gget.enrichr(genes_of_interest, database="celltypes", plot=True) +print("\nTop cell type associations:") +print(enrichment.head(10)) + +# Step 9: Get disease associations +print("\n8. Getting disease associations...") +for gene in genes_of_interest: + gene_search = gget.search([gene], species="homo_sapiens", limit=1) + if len(gene_search) > 0: + gene_id = gene_search["ensembl_id"].iloc[0] + diseases = gget.opentargets(gene_id, resource="diseases", limit=3) + print(f"\n{gene} disease associations:") + print(diseases[["disease_name", "overall_score"]]) + +print("\nSingle-cell expression analysis completed!") +``` + +--- + +## Building Reference Transcriptomes + +Prepare reference data for RNA-seq analysis pipelines. + +```bash +#!/bin/bash +# Reference transcriptome building workflow + +echo "Reference Transcriptome Building Workflow" +echo "==========================================" + +# Step 1: List available species +echo -e "\n1. Listing available species..." +gget ref --list_species > available_species.txt +echo "Available species saved to available_species.txt" + +# Step 2: Download reference files for human +echo -e "\n2. Downloading human reference files..." +SPECIES="homo_sapiens" +RELEASE=110 # Specify release for reproducibility + +# Download GTF annotation +echo "Downloading GTF annotation..." +gget ref -w gtf -r $RELEASE -d $SPECIES -o human_ref_gtf.json + +# Download cDNA sequences +echo "Downloading cDNA sequences..." +gget ref -w cdna -r $RELEASE -d $SPECIES -o human_ref_cdna.json + +# Download protein sequences +echo "Downloading protein sequences..." +gget ref -w pep -r $RELEASE -d $SPECIES -o human_ref_pep.json + +# Step 3: Build kallisto index (if kallisto is installed) +echo -e "\n3. Building kallisto index..." +if command -v kallisto &> /dev/null; then + # Get cDNA FASTA file from download + CDNA_FILE=$(ls *.cdna.all.fa.gz) + if [ -f "$CDNA_FILE" ]; then + kallisto index -i transcriptome.idx $CDNA_FILE + echo "Kallisto index created: transcriptome.idx" + else + echo "cDNA FASTA file not found" + fi +else + echo "kallisto not installed, skipping index building" +fi + +# Step 4: Download genome for alignment-based methods +echo -e "\n4. Downloading genome sequence..." +gget ref -w dna -r $RELEASE -d $SPECIES -o human_ref_dna.json + +# Step 5: Get gene information for genes of interest +echo -e "\n5. Getting information for specific genes..." +gget search -s $SPECIES "TP53 BRCA1 BRCA2" -o key_genes.csv + +echo -e "\nReference transcriptome building completed!" +``` + +```python +# Python version +import gget +import json + +print("Reference Transcriptome Building Workflow") +print("=" * 50) + +# Configuration +species = "homo_sapiens" +release = 110 +genes_of_interest = ["TP53", "BRCA1", "BRCA2", "MYC", "EGFR"] + +# Step 1: Get reference information +print("\n1. Getting reference information...") +ref_info = gget.ref(species, release=release) + +# Save reference information +with open("reference_info.json", "w") as f: + json.dump(ref_info, f, indent=2) +print("Reference information saved to reference_info.json") + +# Step 2: Download specific files +print("\n2. Downloading reference files...") +# GTF annotation +gget.ref(species, which="gtf", release=release, download=True) +# cDNA sequences +gget.ref(species, which="cdna", release=release, download=True) + +# Step 3: Get information for genes of interest +print(f"\n3. 
Getting information for {len(genes_of_interest)} genes...") +gene_data = [] +for gene in genes_of_interest: + result = gget.search([gene], species=species, limit=1) + if len(result) > 0: + gene_data.append(result.iloc[0]) + +# Get detailed info +if gene_data: + gene_ids = [g["ensembl_id"] for g in gene_data] + detailed_info = gget.info(gene_ids) + detailed_info.to_csv("genes_of_interest_info.csv", index=False) + print("Gene information saved to genes_of_interest_info.csv") + +# Step 4: Get sequences +print("\n4. Retrieving sequences...") +sequences_nt = gget.seq(gene_ids) +sequences_aa = gget.seq(gene_ids, translate=True) + +with open("key_genes_nucleotide.fasta", "w") as f: + f.write(sequences_nt) +with open("key_genes_protein.fasta", "w") as f: + f.write(sequences_aa) + +print("\nReference transcriptome building completed!") +print(f"Files created:") +print(" - reference_info.json") +print(" - genes_of_interest_info.csv") +print(" - key_genes_nucleotide.fasta") +print(" - key_genes_protein.fasta") +``` + +--- + +## Mutation Impact Assessment + +Analyze the impact of genetic mutations on protein structure and function. + +```python +import gget +import pandas as pd + +print("Mutation Impact Assessment Workflow") +print("=" * 50) + +# Define mutations to analyze +mutations = [ + {"gene": "TP53", "mutation": "c.818G>A", "description": "R273H hotspot"}, + {"gene": "EGFR", "mutation": "c.2573T>G", "description": "L858R activating"}, +] + +# Step 1: Get gene information +print("\n1. Getting gene information...") +for mut in mutations: + results = gget.search([mut["gene"]], species="homo_sapiens", limit=1) + if len(results) > 0: + mut["ensembl_id"] = results["ensembl_id"].iloc[0] + print(f"{mut['gene']}: {mut['ensembl_id']}") + +# Step 2: Get sequences +print("\n2. Retrieving wild-type sequences...") +for mut in mutations: + # Get nucleotide sequence + nt_seq = gget.seq(mut["ensembl_id"]) + mut["wt_sequence"] = nt_seq + + # Get protein sequence + aa_seq = gget.seq(mut["ensembl_id"], translate=True) + mut["wt_protein"] = aa_seq + +# Step 3: Generate mutated sequences +print("\n3. Generating mutated sequences...") +# Create mutation dataframe for gget mutate +mut_df = pd.DataFrame({ + "seq_ID": [m["gene"] for m in mutations], + "mutation": [m["mutation"] for m in mutations] +}) + +# For each mutation +for mut in mutations: + # Extract sequence from FASTA + lines = mut["wt_sequence"].split("\n") + seq = "".join(lines[1:]) + + # Create single mutation df + single_mut = pd.DataFrame({ + "seq_ID": [mut["gene"]], + "mutation": [mut["mutation"]] + }) + + # Generate mutated sequence + mutated = gget.mutate([seq], mutations=single_mut) + mut["mutated_sequence"] = mutated + +print("Mutated sequences generated") + +# Step 4: Get existing structure information +print("\n4. Getting structure information...") +for mut in mutations: + # Get info with PDB IDs + info = gget.info([mut["ensembl_id"]], pdb=True) + + if "pdb_id" in info.columns and pd.notna(info["pdb_id"].iloc[0]): + pdb_ids = info["pdb_id"].iloc[0].split(";") + print(f"\n{mut['gene']} PDB structures: {', '.join(pdb_ids[:3])}") + + # Download first structure + if len(pdb_ids) > 0: + pdb_id = pdb_ids[0].strip() + mut["pdb_id"] = pdb_id + gget.pdb(pdb_id, save=True) + else: + print(f"\n{mut['gene']}: No PDB structure available") + mut["pdb_id"] = None + +# Step 5: Predict structures with AlphaFold (optional) +print("\n5. 
Predicting structures with AlphaFold...") +# Note: Requires gget setup alphafold and is computationally intensive +# Uncomment to run: +# for mut in mutations: +# print(f"Predicting {mut['gene']} wild-type structure...") +# wt_structure = gget.alphafold(mut["wt_protein"]) +# +# print(f"Predicting {mut['gene']} mutant structure...") +# # Would need to translate mutated sequence first +# # mutant_structure = gget.alphafold(mutated_protein) +print("(AlphaFold prediction skipped - uncomment to run)") + +# Step 6: Find functional motifs +print("\n6. Identifying functional motifs...") +# Note: Requires gget setup elm +# Uncomment to run: +# for mut in mutations: +# ortholog_df, regex_df = gget.elm(mut["wt_protein"]) +# print(f"\n{mut['gene']} functional motifs:") +# print(regex_df) +print("(ELM analysis skipped - uncomment to run)") + +# Step 7: Get disease associations +print("\n7. Getting disease associations...") +for mut in mutations: + diseases = gget.opentargets( + mut["ensembl_id"], + resource="diseases", + limit=5 + ) + print(f"\n{mut['gene']} ({mut['description']}) disease associations:") + print(diseases[["disease_name", "overall_score"]]) + +# Step 8: Query COSMIC for mutation frequency +print("\n8. Querying COSMIC database...") +# Note: Requires COSMIC database download +# Uncomment to run: +# for mut in mutations: +# cosmic_results = gget.cosmic( +# mut["mutation"], +# cosmic_tsv_path="cosmic_cancer.tsv", +# limit=10 +# ) +# print(f"\n{mut['gene']} {mut['mutation']} in COSMIC:") +# print(cosmic_results) +print("(COSMIC query skipped - requires database download)") + +print("\nMutation impact assessment completed!") +``` + +--- + +## Drug Target Discovery + +Identify and validate potential drug targets for specific diseases. + +```python +import gget +import pandas as pd + +print("Drug Target Discovery Workflow") +print("=" * 50) + +# Step 1: Search for disease-related genes +disease = "alzheimer" +print(f"\n1. Searching for {disease} disease genes...") +genes = gget.search([disease], species="homo_sapiens", limit=50) +print(f"Found {len(genes)} potential genes") + +# Step 2: Get detailed information +print("\n2. Getting detailed gene information...") +gene_ids = genes["ensembl_id"].tolist()[:20] # Top 20 +gene_info = gget.info(gene_ids[:10]) # Limit to avoid timeout + +# Step 3: Get disease associations from OpenTargets +print("\n3. Getting disease associations...") +disease_scores = [] +for gene_id, gene_name in zip(gene_info["ensembl_id"], gene_info["gene_name"]): + diseases = gget.opentargets(gene_id, resource="diseases", limit=10) + + # Filter for Alzheimer's disease + alzheimer = diseases[diseases["disease_name"].str.contains("Alzheimer", case=False, na=False)] + + if len(alzheimer) > 0: + disease_scores.append({ + "ensembl_id": gene_id, + "gene_name": gene_name, + "disease_score": alzheimer["overall_score"].max() + }) + +disease_df = pd.DataFrame(disease_scores).sort_values("disease_score", ascending=False) +print("\nTop disease-associated genes:") +print(disease_df.head(10)) + +# Step 4: Get tractability information +print("\n4. Assessing target tractability...") +top_targets = disease_df.head(5) +for _, row in top_targets.iterrows(): + tractability = gget.opentargets( + row["ensembl_id"], + resource="tractability" + ) + print(f"\n{row['gene_name']} tractability:") + print(tractability) + +# Step 5: Get expression data +print("\n5. 
Getting tissue expression data...") +for _, row in top_targets.iterrows(): + # Brain expression from OpenTargets + expression = gget.opentargets( + row["ensembl_id"], + resource="expression", + filter_tissue="brain" + ) + print(f"\n{row['gene_name']} brain expression:") + print(expression) + + # Tissue expression from ARCHS4 + tissue_expr = gget.archs4(row["gene_name"], which="tissue") + brain_expr = tissue_expr[tissue_expr["tissue"].str.contains("brain", case=False, na=False)] + print(f"ARCHS4 brain expression:") + print(brain_expr) + +# Step 6: Check for existing drugs +print("\n6. Checking for existing drugs...") +for _, row in top_targets.iterrows(): + drugs = gget.opentargets(row["ensembl_id"], resource="drugs", limit=5) + print(f"\n{row['gene_name']} drug associations:") + if len(drugs) > 0: + print(drugs[["drug_name", "drug_type", "max_phase_for_all_diseases"]]) + else: + print("No drugs found") + +# Step 7: Get protein-protein interactions +print("\n7. Getting protein-protein interactions...") +for _, row in top_targets.iterrows(): + interactions = gget.opentargets( + row["ensembl_id"], + resource="interactions", + limit=10 + ) + print(f"\n{row['gene_name']} interacts with:") + if len(interactions) > 0: + print(interactions[["gene_b_symbol", "interaction_score"]]) + +# Step 8: Enrichment analysis +print("\n8. Performing pathway enrichment...") +gene_list = top_targets["gene_name"].tolist() +enrichment = gget.enrichr(gene_list, database="pathway", plot=True) +print("\nTop enriched pathways:") +print(enrichment.head(10)) + +# Step 9: Get structure information +print("\n9. Getting structure information...") +for _, row in top_targets.iterrows(): + info = gget.info([row["ensembl_id"]], pdb=True) + + if "pdb_id" in info.columns and pd.notna(info["pdb_id"].iloc[0]): + pdb_ids = info["pdb_id"].iloc[0].split(";") + print(f"\n{row['gene_name']} PDB structures: {', '.join(pdb_ids[:3])}") + else: + print(f"\n{row['gene_name']}: No PDB structure available") + # Could predict with AlphaFold + print(f" Consider AlphaFold prediction") + +# Step 10: Generate target summary report +print("\n10. 
Generating target summary report...") +report = [] +for _, row in top_targets.iterrows(): + report.append({ + "Gene": row["gene_name"], + "Ensembl ID": row["ensembl_id"], + "Disease Score": row["disease_score"], + "Target Status": "High Priority" + }) + +report_df = pd.DataFrame(report) +report_df.to_csv("drug_targets_report.csv", index=False) +print("\nTarget report saved to drug_targets_report.csv") + +print("\nDrug target discovery workflow completed!") +``` + +--- + +## Tips for Workflow Development + +### Error Handling +```python +import gget + +def safe_gget_call(func, *args, **kwargs): + """Wrapper for gget calls with error handling""" + try: + result = func(*args, **kwargs) + return result + except Exception as e: + print(f"Error in {func.__name__}: {str(e)}") + return None + +# Usage +result = safe_gget_call(gget.search, ["ACE2"], species="homo_sapiens") +if result is not None: + print(result) +``` + +### Rate Limiting +```python +import time +import gget + +def rate_limited_queries(gene_ids, delay=1): + """Query multiple genes with rate limiting""" + results = [] + for i, gene_id in enumerate(gene_ids): + print(f"Querying {i+1}/{len(gene_ids)}: {gene_id}") + result = gget.info([gene_id]) + results.append(result) + + if i < len(gene_ids) - 1: # Don't sleep after last query + time.sleep(delay) + + return pd.concat(results, ignore_index=True) +``` + +### Caching Results +```python +import os +import pickle +import gget + +def cached_gget(cache_file, func, *args, **kwargs): + """Cache gget results to avoid repeated queries""" + if os.path.exists(cache_file): + print(f"Loading from cache: {cache_file}") + with open(cache_file, "rb") as f: + return pickle.load(f) + + result = func(*args, **kwargs) + + with open(cache_file, "wb") as f: + pickle.dump(result, f) + print(f"Saved to cache: {cache_file}") + + return result + +# Usage +result = cached_gget("ace2_info.pkl", gget.info, ["ENSG00000130234"]) +``` + +--- + +These workflows demonstrate how to combine multiple gget modules for comprehensive bioinformatics analyses. Adapt them to your specific research questions and data types. diff --git a/scientific-packages/gget/scripts/batch_sequence_analysis.py b/scientific-packages/gget/scripts/batch_sequence_analysis.py new file mode 100755 index 0000000..a64a085 --- /dev/null +++ b/scientific-packages/gget/scripts/batch_sequence_analysis.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +""" +Batch Sequence Analysis Script +Analyze multiple sequences: BLAST, alignment, and structure prediction +""" + +import argparse +import sys +from pathlib import Path +import gget + + +def read_fasta(fasta_file): + """Read sequences from FASTA file.""" + sequences = [] + current_id = None + current_seq = [] + + with open(fasta_file, "r") as f: + for line in f: + line = line.strip() + if line.startswith(">"): + if current_id: + sequences.append({"id": current_id, "seq": "".join(current_seq)}) + current_id = line[1:] + current_seq = [] + else: + current_seq.append(line) + + if current_id: + sequences.append({"id": current_id, "seq": "".join(current_seq)}) + + return sequences + + +def analyze_sequences( + fasta_file, + blast_db="nr", + align=True, + predict_structure=False, + output_dir="output", +): + """ + Perform batch sequence analysis. 
+ + Args: + fasta_file: Path to FASTA file with sequences + blast_db: BLAST database to search (default: nr) + align: Whether to perform multiple sequence alignment + predict_structure: Whether to predict structures with AlphaFold + output_dir: Output directory for results + """ + output_path = Path(output_dir) + output_path.mkdir(exist_ok=True) + + print(f"Batch Sequence Analysis") + print("=" * 60) + print(f"Input file: {fasta_file}") + print(f"Output directory: {output_dir}") + print("") + + # Read sequences + print("Reading sequences...") + sequences = read_fasta(fasta_file) + print(f"Found {len(sequences)} sequences\n") + + # Step 1: BLAST each sequence + print("Step 1: Running BLAST searches...") + print("-" * 60) + for i, seq_data in enumerate(sequences): + print(f"\n{i+1}. BLASTing {seq_data['id']}...") + try: + blast_results = gget.blast( + seq_data["seq"], database=blast_db, limit=10, save=False + ) + + output_file = output_path / f"{seq_data['id']}_blast.csv" + blast_results.to_csv(output_file, index=False) + print(f" Results saved to: {output_file}") + + if len(blast_results) > 0: + print(f" Top hit: {blast_results.iloc[0]['Description']}") + print( + f" Max Score: {blast_results.iloc[0]['Max Score']}, " + f"Query Coverage: {blast_results.iloc[0]['Query Coverage']}" + ) + except Exception as e: + print(f" Error: {e}") + + # Step 2: Multiple sequence alignment + if align and len(sequences) > 1: + print("\n\nStep 2: Multiple sequence alignment...") + print("-" * 60) + try: + alignment = gget.muscle(fasta_file) + alignment_file = output_path / "alignment.afa" + with open(alignment_file, "w") as f: + f.write(alignment) + print(f"Alignment saved to: {alignment_file}") + except Exception as e: + print(f"Error in alignment: {e}") + else: + print("\n\nStep 2: Skipping alignment (only 1 sequence or disabled)") + + # Step 3: Structure prediction (optional) + if predict_structure: + print("\n\nStep 3: Predicting structures with AlphaFold...") + print("-" * 60) + print( + "Note: This requires 'gget setup alphafold' and is computationally intensive" + ) + + for i, seq_data in enumerate(sequences): + print(f"\n{i+1}. 
Predicting structure for {seq_data['id']}...") + try: + structure_dir = output_path / f"structure_{seq_data['id']}" + # Uncomment to run AlphaFold prediction: + # gget.alphafold(seq_data['seq'], out=str(structure_dir)) + # print(f" Structure saved to: {structure_dir}") + print( + " (Prediction skipped - uncomment code to run AlphaFold prediction)" + ) + except Exception as e: + print(f" Error: {e}") + else: + print("\n\nStep 3: Structure prediction disabled") + + # Summary + print("\n" + "=" * 60) + print("Batch analysis complete!") + print(f"\nResults saved to: {output_dir}/") + print(f" - BLAST results: *_blast.csv") + if align and len(sequences) > 1: + print(f" - Alignment: alignment.afa") + if predict_structure: + print(f" - Structures: structure_*/") + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Perform batch sequence analysis using gget" + ) + parser.add_argument("fasta", help="Input FASTA file with sequences") + parser.add_argument( + "-db", + "--database", + default="nr", + help="BLAST database (default: nr for proteins, nt for nucleotides)", + ) + parser.add_argument( + "--no-align", action="store_true", help="Skip multiple sequence alignment" + ) + parser.add_argument( + "--predict-structure", + action="store_true", + help="Predict structures with AlphaFold (requires setup)", + ) + parser.add_argument( + "-o", "--output", default="output", help="Output directory (default: output)" + ) + + args = parser.parse_args() + + if not Path(args.fasta).exists(): + print(f"Error: File not found: {args.fasta}") + sys.exit(1) + + try: + success = analyze_sequences( + args.fasta, + blast_db=args.database, + align=not args.no_align, + predict_structure=args.predict_structure, + output_dir=args.output, + ) + sys.exit(0 if success else 1) + except KeyboardInterrupt: + print("\n\nAnalysis interrupted by user") + sys.exit(1) + except Exception as e: + print(f"\n\nError: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/gget/scripts/enrichment_pipeline.py b/scientific-packages/gget/scripts/enrichment_pipeline.py new file mode 100755 index 0000000..c88a505 --- /dev/null +++ b/scientific-packages/gget/scripts/enrichment_pipeline.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +""" +Enrichment Analysis Pipeline +Perform comprehensive enrichment analysis on a gene list +""" + +import argparse +import sys +from pathlib import Path +import gget +import pandas as pd + + +def read_gene_list(file_path): + """Read gene list from file (one gene per line or CSV).""" + file_path = Path(file_path) + + if file_path.suffix == ".csv": + df = pd.read_csv(file_path) + # Assume first column contains gene names + genes = df.iloc[:, 0].tolist() + else: + # Plain text file + with open(file_path, "r") as f: + genes = [line.strip() for line in f if line.strip()] + + return genes + + +def enrichment_pipeline( + gene_list, + species="human", + background=None, + output_prefix="enrichment", + plot=True, +): + """ + Perform comprehensive enrichment analysis. 
+ + Args: + gene_list: List of gene symbols + species: Species for analysis + background: Background gene list (optional) + output_prefix: Prefix for output files + plot: Whether to generate plots + """ + print("Enrichment Analysis Pipeline") + print("=" * 60) + print(f"Analyzing {len(gene_list)} genes") + print(f"Species: {species}\n") + + # Database categories to analyze + databases = { + "pathway": "KEGG Pathways", + "ontology": "Gene Ontology (Biological Process)", + "transcription": "Transcription Factors (ChEA)", + "diseases_drugs": "Disease Associations (GWAS)", + "celltypes": "Cell Type Markers (PanglaoDB)", + } + + results = {} + + for db_key, db_name in databases.items(): + print(f"\nAnalyzing: {db_name}") + print("-" * 60) + + try: + enrichment = gget.enrichr( + gene_list, + database=db_key, + species=species, + background_list=background, + plot=plot, + ) + + if enrichment is not None and len(enrichment) > 0: + # Save results + output_file = f"{output_prefix}_{db_key}.csv" + enrichment.to_csv(output_file, index=False) + print(f"Results saved to: {output_file}") + + # Show top 5 results + print(f"\nTop 5 enriched terms:") + for i, row in enrichment.head(5).iterrows(): + term = row.get("name", row.get("term", "Unknown")) + p_val = row.get( + "adjusted_p_value", + row.get("p_value", row.get("Adjusted P-value", 1)), + ) + print(f" {i+1}. {term}") + print(f" P-value: {p_val:.2e}") + + results[db_key] = enrichment + else: + print("No significant results found") + + except Exception as e: + print(f"Error: {e}") + + # Generate summary report + print("\n" + "=" * 60) + print("Generating summary report...") + + summary = [] + for db_key, db_name in databases.items(): + if db_key in results and len(results[db_key]) > 0: + summary.append( + { + "Database": db_name, + "Total Terms": len(results[db_key]), + "Top Term": results[db_key].iloc[0].get( + "name", results[db_key].iloc[0].get("term", "N/A") + ), + } + ) + + if summary: + summary_df = pd.DataFrame(summary) + summary_file = f"{output_prefix}_summary.csv" + summary_df.to_csv(summary_file, index=False) + print(f"\nSummary saved to: {summary_file}") + print("\n" + summary_df.to_string(index=False)) + else: + print("\nNo enrichment results to summarize") + + # Get expression data for genes + print("\n" + "=" * 60) + print("Getting expression data for input genes...") + + try: + # Get tissue expression for first few genes + expr_data = [] + for gene in gene_list[:5]: # Limit to first 5 + print(f" Getting expression for {gene}...") + try: + tissue_expr = gget.archs4(gene, which="tissue") + top_tissue = tissue_expr.nlargest(1, "median").iloc[0] + expr_data.append( + { + "Gene": gene, + "Top Tissue": top_tissue["tissue"], + "Median Expression": top_tissue["median"], + } + ) + except Exception as e: + print(f" Warning: {e}") + + if expr_data: + expr_df = pd.DataFrame(expr_data) + expr_file = f"{output_prefix}_expression.csv" + expr_df.to_csv(expr_file, index=False) + print(f"\nExpression data saved to: {expr_file}") + + except Exception as e: + print(f"Error getting expression data: {e}") + + print("\n" + "=" * 60) + print("Enrichment analysis complete!") + print(f"\nOutput files (prefix: {output_prefix}):") + for db_key in databases.keys(): + if db_key in results: + print(f" - {output_prefix}_{db_key}.csv") + print(f" - {output_prefix}_summary.csv") + print(f" - {output_prefix}_expression.csv") + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Perform comprehensive enrichment analysis using gget" + ) + 
parser.add_argument( + "genes", + help="Gene list file (one gene per line or CSV with genes in first column)", + ) + parser.add_argument( + "-s", + "--species", + default="human", + help="Species (human, mouse, fly, yeast, worm, fish)", + ) + parser.add_argument( + "-b", "--background", help="Background gene list file (optional)" + ) + parser.add_argument( + "-o", "--output", default="enrichment", help="Output prefix (default: enrichment)" + ) + parser.add_argument( + "--no-plot", action="store_true", help="Disable plotting" + ) + + args = parser.parse_args() + + # Read gene list + if not Path(args.genes).exists(): + print(f"Error: File not found: {args.genes}") + sys.exit(1) + + try: + gene_list = read_gene_list(args.genes) + print(f"Read {len(gene_list)} genes from {args.genes}") + + # Read background if provided + background = None + if args.background: + if Path(args.background).exists(): + background = read_gene_list(args.background) + print(f"Read {len(background)} background genes from {args.background}") + else: + print(f"Warning: Background file not found: {args.background}") + + success = enrichment_pipeline( + gene_list, + species=args.species, + background=background, + output_prefix=args.output, + plot=not args.no_plot, + ) + + sys.exit(0 if success else 1) + + except KeyboardInterrupt: + print("\n\nAnalysis interrupted by user") + sys.exit(1) + except Exception as e: + print(f"\n\nError: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/gget/scripts/gene_analysis.py b/scientific-packages/gget/scripts/gene_analysis.py new file mode 100755 index 0000000..bed474f --- /dev/null +++ b/scientific-packages/gget/scripts/gene_analysis.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Gene Analysis Script +Quick analysis of a gene: search, info, sequences, expression, and enrichment +""" + +import argparse +import sys +import gget + + +def analyze_gene(gene_name, species="homo_sapiens", output_prefix=None): + """ + Perform comprehensive analysis of a gene. + + Args: + gene_name: Gene symbol to analyze + species: Species name (default: homo_sapiens) + output_prefix: Prefix for output files (default: gene_name) + """ + if output_prefix is None: + output_prefix = gene_name.lower() + + print(f"Analyzing gene: {gene_name}") + print("=" * 60) + + # Step 1: Search for the gene + print("\n1. Searching for gene...") + search_results = gget.search([gene_name], species=species, limit=1) + + if len(search_results) == 0: + print(f"Error: Gene '{gene_name}' not found in {species}") + return False + + gene_id = search_results["ensembl_id"].iloc[0] + print(f" Found: {gene_id}") + print(f" Description: {search_results['ensembl_description'].iloc[0]}") + + # Step 2: Get detailed information + print("\n2. Getting detailed information...") + gene_info = gget.info([gene_id], pdb=True) + gene_info.to_csv(f"{output_prefix}_info.csv", index=False) + print(f" Saved to: {output_prefix}_info.csv") + + if "uniprot_id" in gene_info.columns and gene_info["uniprot_id"].iloc[0]: + print(f" UniProt ID: {gene_info['uniprot_id'].iloc[0]}") + if "pdb_id" in gene_info.columns and gene_info["pdb_id"].iloc[0]: + print(f" PDB IDs: {gene_info['pdb_id'].iloc[0]}") + + # Step 3: Get sequences + print("\n3. 
Retrieving sequences...") + nucleotide_seq = gget.seq([gene_id]) + protein_seq = gget.seq([gene_id], translate=True) + + with open(f"{output_prefix}_nucleotide.fasta", "w") as f: + f.write(nucleotide_seq) + print(f" Nucleotide sequence saved to: {output_prefix}_nucleotide.fasta") + + with open(f"{output_prefix}_protein.fasta", "w") as f: + f.write(protein_seq) + print(f" Protein sequence saved to: {output_prefix}_protein.fasta") + + # Step 4: Get tissue expression + print("\n4. Getting tissue expression...") + try: + tissue_expr = gget.archs4(gene_name, which="tissue") + tissue_expr.to_csv(f"{output_prefix}_tissue_expression.csv", index=False) + print(f" Saved to: {output_prefix}_tissue_expression.csv") + + # Show top tissues + top_tissues = tissue_expr.nlargest(5, "median") + print("\n Top expressing tissues:") + for _, row in top_tissues.iterrows(): + print(f" {row['tissue']}: median = {row['median']:.2f}") + except Exception as e: + print(f" Warning: Could not retrieve ARCHS4 data: {e}") + + # Step 5: Find correlated genes + print("\n5. Finding correlated genes...") + try: + correlated = gget.archs4(gene_name, which="correlation") + correlated.to_csv(f"{output_prefix}_correlated_genes.csv", index=False) + print(f" Saved to: {output_prefix}_correlated_genes.csv") + + # Show top correlated + print("\n Top 10 correlated genes:") + for _, row in correlated.head(10).iterrows(): + print(f" {row['gene_symbol']}: r = {row['correlation']:.3f}") + except Exception as e: + print(f" Warning: Could not retrieve correlation data: {e}") + + # Step 6: Get disease associations + print("\n6. Getting disease associations...") + try: + diseases = gget.opentargets(gene_id, resource="diseases", limit=10) + diseases.to_csv(f"{output_prefix}_diseases.csv", index=False) + print(f" Saved to: {output_prefix}_diseases.csv") + + print("\n Top 5 disease associations:") + for _, row in diseases.head(5).iterrows(): + print(f" {row['disease_name']}: score = {row['overall_score']:.3f}") + except Exception as e: + print(f" Warning: Could not retrieve disease data: {e}") + + # Step 7: Get drug associations + print("\n7. 
Getting drug associations...") + try: + drugs = gget.opentargets(gene_id, resource="drugs", limit=10) + if len(drugs) > 0: + drugs.to_csv(f"{output_prefix}_drugs.csv", index=False) + print(f" Saved to: {output_prefix}_drugs.csv") + print(f"\n Found {len(drugs)} drug associations") + else: + print(" No drug associations found") + except Exception as e: + print(f" Warning: Could not retrieve drug data: {e}") + + print("\n" + "=" * 60) + print("Analysis complete!") + print(f"\nOutput files (prefix: {output_prefix}):") + print(f" - {output_prefix}_info.csv") + print(f" - {output_prefix}_nucleotide.fasta") + print(f" - {output_prefix}_protein.fasta") + print(f" - {output_prefix}_tissue_expression.csv") + print(f" - {output_prefix}_correlated_genes.csv") + print(f" - {output_prefix}_diseases.csv") + print(f" - {output_prefix}_drugs.csv (if available)") + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Perform comprehensive analysis of a gene using gget" + ) + parser.add_argument("gene", help="Gene symbol to analyze") + parser.add_argument( + "-s", + "--species", + default="homo_sapiens", + help="Species (default: homo_sapiens)", + ) + parser.add_argument( + "-o", "--output", help="Output prefix for files (default: gene name)" + ) + + args = parser.parse_args() + + try: + success = analyze_gene(args.gene, args.species, args.output) + sys.exit(0 if success else 1) + except KeyboardInterrupt: + print("\n\nAnalysis interrupted by user") + sys.exit(1) + except Exception as e: + print(f"\n\nError: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/matplotlib/SKILL.md b/scientific-packages/matplotlib/SKILL.md new file mode 100644 index 0000000..e115b2f --- /dev/null +++ b/scientific-packages/matplotlib/SKILL.md @@ -0,0 +1,355 @@ +--- +name: matplotlib +description: Comprehensive toolkit for creating publication-quality data visualizations in Python. Use this skill when creating plots, charts, or any scientific/statistical visualizations including line plots, scatter plots, bar charts, histograms, heatmaps, 3D plots, and more. Applies to tasks involving data visualization, figure generation, plot customization, or exporting graphics to various formats. +--- + +# Matplotlib + +## Overview + +Matplotlib is Python's foundational visualization library for creating static, animated, and interactive plots. This skill provides guidance on using matplotlib effectively, covering both the pyplot interface (MATLAB-style) and the object-oriented API (Figure/Axes), along with best practices for creating publication-quality visualizations. + +## When to Use This Skill + +Apply this skill when: +- Creating any type of plot or chart (line, scatter, bar, histogram, heatmap, contour, etc.) +- Generating scientific or statistical visualizations +- Customizing plot appearance (colors, styles, labels, legends) +- Creating multi-panel figures with subplots +- Exporting visualizations to various formats (PNG, PDF, SVG, etc.) +- Building interactive plots or animations +- Working with 3D visualizations +- Integrating plots into Jupyter notebooks or GUI applications + +## Core Concepts + +### The Matplotlib Hierarchy + +Matplotlib uses a hierarchical structure of objects: + +1. **Figure** - The top-level container for all plot elements +2. **Axes** - The actual plotting area where data is displayed (one Figure can contain multiple Axes) +3. **Artist** - Everything visible on the figure (lines, text, ticks, etc.) +4. 
**Axis** - The number line objects (x-axis, y-axis) that handle ticks and labels + +### Two Interfaces + +**1. pyplot Interface (Implicit, MATLAB-style)** +```python +import matplotlib.pyplot as plt + +plt.plot([1, 2, 3, 4]) +plt.ylabel('some numbers') +plt.show() +``` +- Convenient for quick, simple plots +- Maintains state automatically +- Good for interactive work and simple scripts + +**2. Object-Oriented Interface (Explicit)** +```python +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.plot([1, 2, 3, 4]) +ax.set_ylabel('some numbers') +plt.show() +``` +- **Recommended for most use cases** +- More explicit control over figure and axes +- Better for complex figures with multiple subplots +- Easier to maintain and debug + +## Common Workflows + +### 1. Basic Plot Creation + +**Single plot workflow:** +```python +import matplotlib.pyplot as plt +import numpy as np + +# Create figure and axes (OO interface - RECOMMENDED) +fig, ax = plt.subplots(figsize=(10, 6)) + +# Generate and plot data +x = np.linspace(0, 2*np.pi, 100) +ax.plot(x, np.sin(x), label='sin(x)') +ax.plot(x, np.cos(x), label='cos(x)') + +# Customize +ax.set_xlabel('x') +ax.set_ylabel('y') +ax.set_title('Trigonometric Functions') +ax.legend() +ax.grid(True, alpha=0.3) + +# Save and/or display +plt.savefig('plot.png', dpi=300, bbox_inches='tight') +plt.show() +``` + +### 2. Multiple Subplots + +**Creating subplot layouts:** +```python +# Method 1: Regular grid +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) +axes[0, 0].plot(x, y1) +axes[0, 1].scatter(x, y2) +axes[1, 0].bar(categories, values) +axes[1, 1].hist(data, bins=30) + +# Method 2: Mosaic layout (more flexible) +fig, axes = plt.subplot_mosaic([['left', 'right_top'], + ['left', 'right_bottom']], + figsize=(10, 8)) +axes['left'].plot(x, y) +axes['right_top'].scatter(x, y) +axes['right_bottom'].hist(data) + +# Method 3: GridSpec (maximum control) +from matplotlib.gridspec import GridSpec +fig = plt.figure(figsize=(12, 8)) +gs = GridSpec(3, 3, figure=fig) +ax1 = fig.add_subplot(gs[0, :]) # Top row, all columns +ax2 = fig.add_subplot(gs[1:, 0]) # Bottom two rows, first column +ax3 = fig.add_subplot(gs[1:, 1:]) # Bottom two rows, last two columns +``` + +### 3. Plot Types and Use Cases + +**Line plots** - Time series, continuous data, trends +```python +ax.plot(x, y, linewidth=2, linestyle='--', marker='o', color='blue') +``` + +**Scatter plots** - Relationships between variables, correlations +```python +ax.scatter(x, y, s=sizes, c=colors, alpha=0.6, cmap='viridis') +``` + +**Bar charts** - Categorical comparisons +```python +ax.bar(categories, values, color='steelblue', edgecolor='black') +# For horizontal bars: +ax.barh(categories, values) +``` + +**Histograms** - Distributions +```python +ax.hist(data, bins=30, edgecolor='black', alpha=0.7) +``` + +**Heatmaps** - Matrix data, correlations +```python +im = ax.imshow(matrix, cmap='coolwarm', aspect='auto') +plt.colorbar(im, ax=ax) +``` + +**Contour plots** - 3D data on 2D plane +```python +contour = ax.contour(X, Y, Z, levels=10) +ax.clabel(contour, inline=True, fontsize=8) +``` + +**Box plots** - Statistical distributions +```python +ax.boxplot([data1, data2, data3], labels=['A', 'B', 'C']) +``` + +**Violin plots** - Distribution densities +```python +ax.violinplot([data1, data2, data3], positions=[1, 2, 3]) +``` + +For comprehensive plot type examples and variations, refer to `references/plot_types.md`. + +### 4. 
Styling and Customization + +**Color specification methods:** +- Named colors: `'red'`, `'blue'`, `'steelblue'` +- Hex codes: `'#FF5733'` +- RGB tuples: `(0.1, 0.2, 0.3)` +- Colormaps: `cmap='viridis'`, `cmap='plasma'`, `cmap='coolwarm'` + +**Using style sheets:** +```python +plt.style.use('seaborn-v0_8-darkgrid') # Apply predefined style +# Available styles: 'ggplot', 'bmh', 'fivethirtyeight', etc. +print(plt.style.available) # List all available styles +``` + +**Customizing with rcParams:** +```python +plt.rcParams['font.size'] = 12 +plt.rcParams['axes.labelsize'] = 14 +plt.rcParams['axes.titlesize'] = 16 +plt.rcParams['xtick.labelsize'] = 10 +plt.rcParams['ytick.labelsize'] = 10 +plt.rcParams['legend.fontsize'] = 12 +plt.rcParams['figure.titlesize'] = 18 +``` + +**Text and annotations:** +```python +ax.text(x, y, 'annotation', fontsize=12, ha='center') +ax.annotate('important point', xy=(x, y), xytext=(x+1, y+1), + arrowprops=dict(arrowstyle='->', color='red')) +``` + +For detailed styling options and colormap guidelines, see `references/styling_guide.md`. + +### 5. Saving Figures + +**Export to various formats:** +```python +# High-resolution PNG for presentations/papers +plt.savefig('figure.png', dpi=300, bbox_inches='tight', facecolor='white') + +# Vector format for publications (scalable) +plt.savefig('figure.pdf', bbox_inches='tight') +plt.savefig('figure.svg', bbox_inches='tight') + +# Transparent background +plt.savefig('figure.png', dpi=300, bbox_inches='tight', transparent=True) +``` + +**Important parameters:** +- `dpi`: Resolution (300 for publications, 150 for web, 72 for screen) +- `bbox_inches='tight'`: Removes excess whitespace +- `facecolor='white'`: Ensures white background (useful for transparent themes) +- `transparent=True`: Transparent background + +### 6. Working with 3D Plots + +```python +from mpl_toolkits.mplot3d import Axes3D + +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') + +# Surface plot +ax.plot_surface(X, Y, Z, cmap='viridis') + +# 3D scatter +ax.scatter(x, y, z, c=colors, marker='o') + +# 3D line plot +ax.plot(x, y, z, linewidth=2) + +# Labels +ax.set_xlabel('X Label') +ax.set_ylabel('Y Label') +ax.set_zlabel('Z Label') +``` + +## Best Practices + +### 1. Interface Selection +- **Use the object-oriented interface** (fig, ax = plt.subplots()) for production code +- Reserve pyplot interface for quick interactive exploration only +- Always create figures explicitly rather than relying on implicit state + +### 2. Figure Size and DPI +- Set figsize at creation: `fig, ax = plt.subplots(figsize=(10, 6))` +- Use appropriate DPI for output medium: + - Screen/notebook: 72-100 dpi + - Web: 150 dpi + - Print/publications: 300 dpi + +### 3. Layout Management +- Use `constrained_layout=True` or `tight_layout()` to prevent overlapping elements +- `fig, ax = plt.subplots(constrained_layout=True)` is recommended for automatic spacing + +### 4. Colormap Selection +- **Sequential** (viridis, plasma, inferno): Ordered data with consistent progression +- **Diverging** (coolwarm, RdBu): Data with meaningful center point (e.g., zero) +- **Qualitative** (tab10, Set3): Categorical/nominal data +- Avoid rainbow colormaps (jet) - they are not perceptually uniform + +### 5. Accessibility +- Use colorblind-friendly colormaps (viridis, cividis) +- Add patterns/hatching for bar charts in addition to colors +- Ensure sufficient contrast between elements +- Include descriptive labels and legends + +### 6. 
Performance +- For large datasets, use `rasterized=True` in plot calls to reduce file size +- Use appropriate data reduction before plotting (e.g., downsample dense time series) +- For animations, use blitting for better performance + +### 7. Code Organization +```python +# Good practice: Clear structure +def create_analysis_plot(data, title): + """Create standardized analysis plot.""" + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + # Plot data + ax.plot(data['x'], data['y'], linewidth=2) + + # Customize + ax.set_xlabel('X Axis Label', fontsize=12) + ax.set_ylabel('Y Axis Label', fontsize=12) + ax.set_title(title, fontsize=14, fontweight='bold') + ax.grid(True, alpha=0.3) + + return fig, ax + +# Use the function +fig, ax = create_analysis_plot(my_data, 'My Analysis') +plt.savefig('analysis.png', dpi=300, bbox_inches='tight') +``` + +## Quick Reference Scripts + +This skill includes helper scripts in the `scripts/` directory: + +### `plot_template.py` +Template script demonstrating various plot types with best practices. Use this as a starting point for creating new visualizations. + +**Usage:** +```bash +python scripts/plot_template.py +``` + +### `style_configurator.py` +Interactive utility to configure matplotlib style preferences and generate custom style sheets. + +**Usage:** +```bash +python scripts/style_configurator.py +``` + +## Detailed References + +For comprehensive information, consult the reference documents: + +- **`references/plot_types.md`** - Complete catalog of plot types with code examples and use cases +- **`references/styling_guide.md`** - Detailed styling options, colormaps, and customization +- **`references/api_reference.md`** - Core classes and methods reference +- **`references/common_issues.md`** - Troubleshooting guide for common problems + +## Integration with Other Tools + +Matplotlib integrates well with: +- **NumPy/Pandas** - Direct plotting from arrays and DataFrames +- **Seaborn** - High-level statistical visualizations built on matplotlib +- **Jupyter** - Interactive plotting with `%matplotlib inline` or `%matplotlib widget` +- **GUI frameworks** - Embedding in Tkinter, Qt, wxPython applications + +## Common Gotchas + +1. **Overlapping elements**: Use `constrained_layout=True` or `tight_layout()` +2. **State confusion**: Use OO interface to avoid pyplot state machine issues +3. **Memory issues with many figures**: Close figures explicitly with `plt.close(fig)` +4. **Font warnings**: Install fonts or suppress warnings with `plt.rcParams['font.sans-serif']` +5. **DPI confusion**: Remember that figsize is in inches, not pixels: `pixels = dpi * inches` + +## Additional Resources + +- Official documentation: https://matplotlib.org/ +- Gallery: https://matplotlib.org/stable/gallery/index.html +- Cheatsheets: https://matplotlib.org/cheatsheets/ +- Tutorials: https://matplotlib.org/stable/tutorials/index.html diff --git a/scientific-packages/matplotlib/references/api_reference.md b/scientific-packages/matplotlib/references/api_reference.md new file mode 100644 index 0000000..9ca3c61 --- /dev/null +++ b/scientific-packages/matplotlib/references/api_reference.md @@ -0,0 +1,412 @@ +# Matplotlib API Reference + +This document provides a quick reference for the most commonly used matplotlib classes and methods. + +## Core Classes + +### Figure + +The top-level container for all plot elements. 
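As a quick orientation before the method lists, a minimal sketch of the typical `Figure` lifecycle (create, plot, title, save, close) using only calls documented below:

```python
import matplotlib.pyplot as plt
import numpy as np

# Create a figure, add one axes, plot, then save and release the memory
fig = plt.figure(figsize=(8, 5), dpi=100)
ax = fig.add_subplot(1, 1, 1)

x = np.linspace(0, 10, 200)
ax.plot(x, np.sin(x), label='sin(x)')
ax.legend()

fig.suptitle('Figure lifecycle example')
fig.savefig('lifecycle.png', dpi=300, bbox_inches='tight')
plt.close(fig)  # free the figure once it has been saved
```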
+ +**Creation:** +```python +fig = plt.figure(figsize=(10, 6), dpi=100, facecolor='white') +fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 6)) +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) +``` + +**Key Methods:** +- `fig.add_subplot(nrows, ncols, index)` - Add a subplot +- `fig.add_axes([left, bottom, width, height])` - Add axes at specific position +- `fig.savefig(filename, dpi=300, bbox_inches='tight')` - Save figure +- `fig.tight_layout()` - Adjust spacing to prevent overlaps +- `fig.suptitle(title)` - Set figure title +- `fig.legend()` - Create figure-level legend +- `fig.colorbar(mappable)` - Add colorbar to figure +- `plt.close(fig)` - Close figure to free memory + +**Key Attributes:** +- `fig.axes` - List of all axes in the figure +- `fig.dpi` - Resolution in dots per inch +- `fig.figsize` - Figure dimensions in inches (width, height) + +### Axes + +The actual plotting area where data is visualized. + +**Creation:** +```python +fig, ax = plt.subplots() # Single axes +ax = fig.add_subplot(111) # Alternative method +``` + +**Plotting Methods:** + +**Line plots:** +- `ax.plot(x, y, **kwargs)` - Line plot +- `ax.step(x, y, where='pre'/'mid'/'post')` - Step plot +- `ax.errorbar(x, y, yerr, xerr)` - Error bars + +**Scatter plots:** +- `ax.scatter(x, y, s=size, c=color, marker='o', alpha=0.5)` - Scatter plot + +**Bar charts:** +- `ax.bar(x, height, width=0.8, align='center')` - Vertical bar chart +- `ax.barh(y, width)` - Horizontal bar chart + +**Statistical plots:** +- `ax.hist(data, bins=10, density=False)` - Histogram +- `ax.boxplot(data, labels=None)` - Box plot +- `ax.violinplot(data)` - Violin plot + +**2D plots:** +- `ax.imshow(array, cmap='viridis', aspect='auto')` - Display image/matrix +- `ax.contour(X, Y, Z, levels=10)` - Contour lines +- `ax.contourf(X, Y, Z, levels=10)` - Filled contours +- `ax.pcolormesh(X, Y, Z)` - Pseudocolor plot + +**Filling:** +- `ax.fill_between(x, y1, y2, alpha=0.3)` - Fill between curves +- `ax.fill_betweenx(y, x1, x2)` - Fill between vertical curves + +**Text and annotations:** +- `ax.text(x, y, text, fontsize=12)` - Add text +- `ax.annotate(text, xy=(x, y), xytext=(x2, y2), arrowprops={})` - Annotate with arrow + +**Customization Methods:** + +**Labels and titles:** +- `ax.set_xlabel(label, fontsize=12)` - Set x-axis label +- `ax.set_ylabel(label, fontsize=12)` - Set y-axis label +- `ax.set_title(title, fontsize=14)` - Set axes title + +**Limits and scales:** +- `ax.set_xlim(left, right)` - Set x-axis limits +- `ax.set_ylim(bottom, top)` - Set y-axis limits +- `ax.set_xscale('linear'/'log'/'symlog')` - Set x-axis scale +- `ax.set_yscale('linear'/'log'/'symlog')` - Set y-axis scale + +**Ticks:** +- `ax.set_xticks(positions)` - Set x-tick positions +- `ax.set_xticklabels(labels)` - Set x-tick labels +- `ax.tick_params(axis='both', labelsize=10)` - Customize tick appearance + +**Grid and spines:** +- `ax.grid(True, alpha=0.3, linestyle='--')` - Add grid +- `ax.spines['top'].set_visible(False)` - Hide top spine +- `ax.spines['right'].set_visible(False)` - Hide right spine + +**Legend:** +- `ax.legend(loc='best', fontsize=10, frameon=True)` - Add legend +- `ax.legend(handles, labels)` - Custom legend + +**Aspect and layout:** +- `ax.set_aspect('equal'/'auto'/ratio)` - Set aspect ratio +- `ax.invert_xaxis()` - Invert x-axis +- `ax.invert_yaxis()` - Invert y-axis + +### pyplot Module + +High-level interface for quick plotting. 
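The functions listed below all operate on an implicit "current" figure and axes tracked by pyplot; a short sketch of that state machine:

```python
import matplotlib.pyplot as plt

# Each call targets the current figure/axes maintained by the pyplot state machine
plt.figure(figsize=(6, 4))
plt.plot([1, 2, 3], [1, 4, 9], marker='o')
plt.title('Implicit pyplot interface')

ax = plt.gca()              # drop down to the current Axes for finer control
ax.set_ylabel('y = x^2')

plt.savefig('pyplot_state.png', dpi=150, bbox_inches='tight')
plt.close()                 # close the current figure
```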
+ +**Figure creation:** +- `plt.figure()` - Create new figure +- `plt.subplots()` - Create figure and axes +- `plt.subplot()` - Add subplot to current figure + +**Plotting (uses current axes):** +- `plt.plot()` - Line plot +- `plt.scatter()` - Scatter plot +- `plt.bar()` - Bar chart +- `plt.hist()` - Histogram +- (All axes methods available) + +**Display and save:** +- `plt.show()` - Display figure +- `plt.savefig()` - Save figure +- `plt.close()` - Close figure + +**Style:** +- `plt.style.use(style_name)` - Apply style sheet +- `plt.style.available` - List available styles + +**State management:** +- `plt.gca()` - Get current axes +- `plt.gcf()` - Get current figure +- `plt.sca(ax)` - Set current axes +- `plt.clf()` - Clear current figure +- `plt.cla()` - Clear current axes + +## Line and Marker Styles + +### Line Styles +- `'-'` or `'solid'` - Solid line +- `'--'` or `'dashed'` - Dashed line +- `'-.'` or `'dashdot'` - Dash-dot line +- `':'` or `'dotted'` - Dotted line +- `''` or `' '` or `'None'` - No line + +### Marker Styles +- `'.'` - Point marker +- `'o'` - Circle marker +- `'v'`, `'^'`, `'<'`, `'>'` - Triangle markers +- `'s'` - Square marker +- `'p'` - Pentagon marker +- `'*'` - Star marker +- `'h'`, `'H'` - Hexagon markers +- `'+'` - Plus marker +- `'x'` - X marker +- `'D'`, `'d'` - Diamond markers + +### Color Specifications + +**Single character shortcuts:** +- `'b'` - Blue +- `'g'` - Green +- `'r'` - Red +- `'c'` - Cyan +- `'m'` - Magenta +- `'y'` - Yellow +- `'k'` - Black +- `'w'` - White + +**Named colors:** +- `'steelblue'`, `'coral'`, `'teal'`, etc. +- See full list: https://matplotlib.org/stable/gallery/color/named_colors.html + +**Other formats:** +- Hex: `'#FF5733'` +- RGB tuple: `(0.1, 0.2, 0.3)` +- RGBA tuple: `(0.1, 0.2, 0.3, 0.5)` + +## Common Parameters + +### Plot Function Parameters + +```python +ax.plot(x, y, + color='blue', # Line color + linewidth=2, # Line width + linestyle='--', # Line style + marker='o', # Marker style + markersize=8, # Marker size + markerfacecolor='red', # Marker fill color + markeredgecolor='black',# Marker edge color + markeredgewidth=1, # Marker edge width + alpha=0.7, # Transparency (0-1) + label='data', # Legend label + zorder=2, # Drawing order + rasterized=True # Rasterize for smaller file size +) +``` + +### Scatter Function Parameters + +```python +ax.scatter(x, y, + s=50, # Size (scalar or array) + c='blue', # Color (scalar, array, or sequence) + marker='o', # Marker style + cmap='viridis', # Colormap (if c is numeric) + alpha=0.5, # Transparency + edgecolors='black', # Edge color + linewidths=1, # Edge width + vmin=0, vmax=1, # Color scale limits + label='data' # Legend label +) +``` + +### Text Parameters + +```python +ax.text(x, y, text, + fontsize=12, # Font size + fontweight='normal', # 'normal', 'bold', 'heavy', 'light' + fontstyle='normal', # 'normal', 'italic', 'oblique' + fontfamily='sans-serif',# Font family + color='black', # Text color + alpha=1.0, # Transparency + ha='center', # Horizontal alignment: 'left', 'center', 'right' + va='center', # Vertical alignment: 'top', 'center', 'bottom', 'baseline' + rotation=0, # Rotation angle in degrees + bbox=dict( # Background box + facecolor='white', + edgecolor='black', + boxstyle='round' + ) +) +``` + +## rcParams Configuration + +Common rcParams settings for global customization: + +```python +# Font settings +plt.rcParams['font.family'] = 'sans-serif' +plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica'] +plt.rcParams['font.size'] = 12 + +# Figure settings 
+plt.rcParams['figure.figsize'] = (10, 6)
+plt.rcParams['figure.dpi'] = 100
+plt.rcParams['figure.facecolor'] = 'white'
+plt.rcParams['savefig.dpi'] = 300
+plt.rcParams['savefig.bbox'] = 'tight'
+
+# Axes settings
+plt.rcParams['axes.labelsize'] = 14
+plt.rcParams['axes.titlesize'] = 16
+plt.rcParams['axes.grid'] = True
+
+# Line settings
+plt.rcParams['lines.linewidth'] = 2
+plt.rcParams['lines.markersize'] = 8
+
+# Tick settings
+plt.rcParams['xtick.labelsize'] = 10
+plt.rcParams['ytick.labelsize'] = 10
+plt.rcParams['xtick.direction'] = 'in'  # 'in', 'out', 'inout'
+plt.rcParams['ytick.direction'] = 'in'
+
+# Legend settings
+plt.rcParams['legend.fontsize'] = 12
+plt.rcParams['legend.frameon'] = True
+plt.rcParams['legend.framealpha'] = 0.8
+
+# Grid settings
+plt.rcParams['grid.alpha'] = 0.3
+plt.rcParams['grid.linestyle'] = '--'
+```
+
+## GridSpec for Complex Layouts
+
+```python
+from matplotlib.gridspec import GridSpec
+
+fig = plt.figure(figsize=(12, 8))
+gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)
+
+# Span multiple cells
+ax1 = fig.add_subplot(gs[0, :])   # Top row, all columns
+ax2 = fig.add_subplot(gs[1:, 0])  # Bottom two rows, first column
+ax3 = fig.add_subplot(gs[1, 1:])  # Middle row, last two columns
+ax4 = fig.add_subplot(gs[2, 1])   # Bottom row, middle column
+ax5 = fig.add_subplot(gs[2, 2])   # Bottom row, right column
+```
+
+## 3D Plotting
+
+```python
+# Importing Axes3D is optional in recent matplotlib; projection='3d' is enough
+from mpl_toolkits.mplot3d import Axes3D
+
+fig = plt.figure()
+ax = fig.add_subplot(111, projection='3d')
+
+# Plot types
+ax.plot(x, y, z)               # 3D line
+ax.scatter(x, y, z)            # 3D scatter
+ax.plot_surface(X, Y, Z)       # 3D surface
+ax.plot_wireframe(X, Y, Z)     # 3D wireframe
+ax.contour(X, Y, Z)            # 3D contour
+ax.bar3d(x, y, z, dx, dy, dz)  # 3D bar
+
+# Customization
+ax.set_xlabel('X')
+ax.set_ylabel('Y')
+ax.set_zlabel('Z')
+ax.view_init(elev=30, azim=45)  # Set viewing angle
+```
+
+## Animation
+
+```python
+import numpy as np
+from matplotlib.animation import FuncAnimation
+
+fig, ax = plt.subplots()
+line, = ax.plot([], [])
+
+def init():
+    ax.set_xlim(0, 2*np.pi)
+    ax.set_ylim(-1, 1)
+    return line,
+
+def update(frame):
+    x = np.linspace(0, 2*np.pi, 100)
+    y = np.sin(x + frame/10)
+    line.set_data(x, y)
+    return line,
+
+anim = FuncAnimation(fig, update, init_func=init,
+                     frames=100, interval=50, blit=True)
+
+# Save animation
+anim.save('animation.gif', writer='pillow', fps=20)
+anim.save('animation.mp4', writer='ffmpeg', fps=20)
+```
+
+## Image Operations
+
+```python
+# Read and display image
+img = plt.imread('image.png')
+ax.imshow(img)
+
+# Display matrix as image
+im = ax.imshow(matrix, cmap='viridis', aspect='auto',
+               interpolation='nearest', origin='lower')
+
+# Colorbar
+cbar = plt.colorbar(im, ax=ax)
+cbar.set_label('Values')
+
+# Image extent (set coordinates)
+ax.imshow(img, extent=[x_min, x_max, y_min, y_max])
+```
+
+## Event Handling
+
+```python
+# Mouse click event
+def on_click(event):
+    if event.inaxes:
+        print(f'Clicked at x={event.xdata:.2f}, y={event.ydata:.2f}')
+
+fig.canvas.mpl_connect('button_press_event', on_click)
+
+# Key press event
+def on_key(event):
+    print(f'Key pressed: {event.key}')
+
+fig.canvas.mpl_connect('key_press_event', on_key)
+```
+
+## Useful Utilities
+
+```python
+# Get current axis limits
+xlims = ax.get_xlim()
+ylims = ax.get_ylim()
+
+# Set equal aspect ratio
+ax.set_aspect('equal', adjustable='box')
+
+# Share axes between subplots
+fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
+
+# Twin axes (two y-axes)
+ax2 = ax1.twinx()
+
+# Remove tick labels
+ax.set_xticklabels([]) +ax.set_yticklabels([]) + +# Scientific notation +ax.ticklabel_format(style='scientific', axis='y', scilimits=(0,0)) + +# Date formatting +import matplotlib.dates as mdates +ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) +ax.xaxis.set_major_locator(mdates.DayLocator(interval=7)) +``` diff --git a/scientific-packages/matplotlib/references/common_issues.md b/scientific-packages/matplotlib/references/common_issues.md new file mode 100644 index 0000000..e1304c7 --- /dev/null +++ b/scientific-packages/matplotlib/references/common_issues.md @@ -0,0 +1,563 @@ +# Matplotlib Common Issues and Solutions + +Troubleshooting guide for frequently encountered matplotlib problems. + +## Display and Backend Issues + +### Issue: Plots Not Showing + +**Problem:** `plt.show()` doesn't display anything + +**Solutions:** +```python +# 1. Check if backend is properly set (for interactive use) +import matplotlib +print(matplotlib.get_backend()) + +# 2. Try different backends +matplotlib.use('TkAgg') # or 'Qt5Agg', 'MacOSX' +import matplotlib.pyplot as plt + +# 3. In Jupyter notebooks, use magic command +%matplotlib inline # Static images +# or +%matplotlib widget # Interactive plots + +# 4. Ensure plt.show() is called +plt.plot([1, 2, 3]) +plt.show() +``` + +### Issue: "RuntimeError: main thread is not in main loop" + +**Problem:** Interactive mode issues with threading + +**Solution:** +```python +# Switch to non-interactive backend +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +# Or turn off interactive mode +plt.ioff() +``` + +### Issue: Figures Not Updating Interactively + +**Problem:** Changes not reflected in interactive windows + +**Solution:** +```python +# Enable interactive mode +plt.ion() + +# Draw after each change +plt.plot(x, y) +plt.draw() +plt.pause(0.001) # Brief pause to update display +``` + +## Layout and Spacing Issues + +### Issue: Overlapping Labels and Titles + +**Problem:** Labels, titles, or tick labels overlap or get cut off + +**Solutions:** +```python +# Solution 1: Constrained layout (RECOMMENDED) +fig, ax = plt.subplots(constrained_layout=True) + +# Solution 2: Tight layout +fig, ax = plt.subplots() +plt.tight_layout() + +# Solution 3: Adjust margins manually +plt.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + +# Solution 4: Save with bbox_inches='tight' +plt.savefig('figure.png', bbox_inches='tight') + +# Solution 5: Rotate long tick labels +ax.set_xticklabels(labels, rotation=45, ha='right') +``` + +### Issue: Colorbar Affects Subplot Size + +**Problem:** Adding colorbar shrinks the plot + +**Solution:** +```python +# Solution 1: Use constrained layout +fig, ax = plt.subplots(constrained_layout=True) +im = ax.imshow(data) +plt.colorbar(im, ax=ax) + +# Solution 2: Manually specify colorbar dimensions +from mpl_toolkits.axes_grid1 import make_axes_locatable +divider = make_axes_locatable(ax) +cax = divider.append_axes("right", size="5%", pad=0.05) +plt.colorbar(im, cax=cax) + +# Solution 3: For multiple subplots, share colorbar +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +for ax in axes: + im = ax.imshow(data) +fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.95) +``` + +### Issue: Subplots Too Close Together + +**Problem:** Multiple subplots overlapping + +**Solution:** +```python +# Solution 1: Use constrained_layout +fig, axes = plt.subplots(2, 2, constrained_layout=True) + +# Solution 2: Adjust spacing with subplots_adjust +fig, axes = plt.subplots(2, 2) 
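+# hspace/wspace are fractions of the average subplot height/width
+# reserved as blank space between subplots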
+plt.subplots_adjust(hspace=0.4, wspace=0.4) + +# Solution 3: Specify spacing in tight_layout +plt.tight_layout(h_pad=2.0, w_pad=2.0) +``` + +## Memory and Performance Issues + +### Issue: Memory Leak with Multiple Figures + +**Problem:** Memory usage grows when creating many figures + +**Solution:** +```python +# Close figures explicitly +fig, ax = plt.subplots() +ax.plot(x, y) +plt.savefig('plot.png') +plt.close(fig) # or plt.close('all') + +# Clear current figure without closing +plt.clf() + +# Clear current axes +plt.cla() +``` + +### Issue: Large File Sizes + +**Problem:** Saved figures are too large + +**Solutions:** +```python +# Solution 1: Reduce DPI +plt.savefig('figure.png', dpi=150) # Instead of 300 + +# Solution 2: Use rasterization for complex plots +ax.plot(x, y, rasterized=True) + +# Solution 3: Use vector format for simple plots +plt.savefig('figure.pdf') # or .svg + +# Solution 4: Compress PNG +plt.savefig('figure.png', dpi=300, optimize=True) +``` + +### Issue: Slow Plotting with Large Datasets + +**Problem:** Plotting takes too long with many points + +**Solutions:** +```python +# Solution 1: Downsample data +from scipy.signal import decimate +y_downsampled = decimate(y, 10) # Keep every 10th point + +# Solution 2: Use rasterization +ax.plot(x, y, rasterized=True) + +# Solution 3: Use line simplification +ax.plot(x, y) +for line in ax.get_lines(): + line.set_rasterized(True) + +# Solution 4: For scatter plots, consider hexbin or 2d histogram +ax.hexbin(x, y, gridsize=50, cmap='viridis') +``` + +## Font and Text Issues + +### Issue: Font Warnings + +**Problem:** "findfont: Font family [...] not found" + +**Solutions:** +```python +# Solution 1: Use available fonts +from matplotlib.font_manager import findfont, FontProperties +print(findfont(FontProperties(family='sans-serif'))) + +# Solution 2: Rebuild font cache +import matplotlib.font_manager +matplotlib.font_manager._rebuild() + +# Solution 3: Suppress warnings +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +# Solution 4: Specify fallback fonts +plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'sans-serif'] +``` + +### Issue: LaTeX Rendering Errors + +**Problem:** Math text not rendering correctly + +**Solutions:** +```python +# Solution 1: Use raw strings with r prefix +ax.set_xlabel(r'$\alpha$') # Not '\alpha' + +# Solution 2: Escape backslashes in regular strings +ax.set_xlabel('$\\alpha$') + +# Solution 3: Disable LaTeX if not installed +plt.rcParams['text.usetex'] = False + +# Solution 4: Use mathtext instead of full LaTeX +# Mathtext is always available, no LaTeX installation needed +ax.text(x, y, r'$\int_0^\infty e^{-x} dx$') +``` + +### Issue: Text Cut Off or Outside Figure + +**Problem:** Labels or annotations appear outside figure bounds + +**Solutions:** +```python +# Solution 1: Use bbox_inches='tight' +plt.savefig('figure.png', bbox_inches='tight') + +# Solution 2: Adjust figure bounds +plt.subplots_adjust(left=0.15, right=0.85, top=0.85, bottom=0.15) + +# Solution 3: Clip text to axes +ax.text(x, y, 'text', clip_on=True) + +# Solution 4: Use constrained_layout +fig, ax = plt.subplots(constrained_layout=True) +``` + +## Color and Colormap Issues + +### Issue: Colorbar Not Matching Plot + +**Problem:** Colorbar shows different range than data + +**Solution:** +```python +# Explicitly set vmin and vmax +im = ax.imshow(data, vmin=0, vmax=1, cmap='viridis') +plt.colorbar(im, ax=ax) + +# Or use the same norm for multiple plots +import matplotlib.colors as mcolors +norm = 
mcolors.Normalize(vmin=data.min(), vmax=data.max()) +im1 = ax1.imshow(data1, norm=norm, cmap='viridis') +im2 = ax2.imshow(data2, norm=norm, cmap='viridis') +``` + +### Issue: Colors Look Wrong + +**Problem:** Unexpected colors in plots + +**Solutions:** +```python +# Solution 1: Check color specification format +ax.plot(x, y, color='blue') # Correct +ax.plot(x, y, color=(0, 0, 1)) # Correct RGB +ax.plot(x, y, color='#0000FF') # Correct hex + +# Solution 2: Verify colormap exists +print(plt.colormaps()) # List available colormaps + +# Solution 3: For scatter plots, ensure c shape matches +ax.scatter(x, y, c=colors) # colors should have same length as x, y + +# Solution 4: Check if alpha is set correctly +ax.plot(x, y, alpha=1.0) # 0=transparent, 1=opaque +``` + +### Issue: Reversed Colormap + +**Problem:** Colormap direction is backwards + +**Solution:** +```python +# Add _r suffix to reverse any colormap +ax.imshow(data, cmap='viridis_r') +``` + +## Axis and Scale Issues + +### Issue: Axis Limits Not Working + +**Problem:** `set_xlim` or `set_ylim` not taking effect + +**Solutions:** +```python +# Solution 1: Set after plotting +ax.plot(x, y) +ax.set_xlim(0, 10) +ax.set_ylim(-1, 1) + +# Solution 2: Disable autoscaling +ax.autoscale(False) +ax.set_xlim(0, 10) + +# Solution 3: Use axis method +ax.axis([xmin, xmax, ymin, ymax]) +``` + +### Issue: Log Scale with Zero or Negative Values + +**Problem:** ValueError when using log scale with data ≤ 0 + +**Solutions:** +```python +# Solution 1: Filter out non-positive values +mask = (data > 0) +ax.plot(x[mask], data[mask]) +ax.set_yscale('log') + +# Solution 2: Use symlog for data with positive and negative values +ax.set_yscale('symlog') + +# Solution 3: Add small offset +ax.plot(x, data + 1e-10) +ax.set_yscale('log') +``` + +### Issue: Dates Not Displaying Correctly + +**Problem:** Date axis shows numbers instead of dates + +**Solution:** +```python +import matplotlib.dates as mdates +import pandas as pd + +# Convert to datetime if needed +dates = pd.to_datetime(date_strings) + +ax.plot(dates, values) + +# Format date axis +ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) +ax.xaxis.set_major_locator(mdates.DayLocator(interval=7)) +plt.xticks(rotation=45) +``` + +## Legend Issues + +### Issue: Legend Covers Data + +**Problem:** Legend obscures important parts of plot + +**Solutions:** +```python +# Solution 1: Use 'best' location +ax.legend(loc='best') + +# Solution 2: Place outside plot area +ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + +# Solution 3: Make legend semi-transparent +ax.legend(framealpha=0.7) + +# Solution 4: Put legend below plot +ax.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=3) +``` + +### Issue: Too Many Items in Legend + +**Problem:** Legend is cluttered with many entries + +**Solutions:** +```python +# Solution 1: Only label selected items +for i, (x, y) in enumerate(data): + label = f'Data {i}' if i % 5 == 0 else None + ax.plot(x, y, label=label) + +# Solution 2: Use multiple columns +ax.legend(ncol=3) + +# Solution 3: Create custom legend with fewer entries +from matplotlib.lines import Line2D +custom_lines = [Line2D([0], [0], color='r'), + Line2D([0], [0], color='b')] +ax.legend(custom_lines, ['Category A', 'Category B']) + +# Solution 4: Use separate legend figure +fig_leg = plt.figure(figsize=(3, 2)) +ax_leg = fig_leg.add_subplot(111) +ax_leg.legend(*ax.get_legend_handles_labels(), loc='center') +ax_leg.axis('off') +``` + +## 3D Plot Issues + +### Issue: 3D Plots Look Flat + 
+**Problem:** Difficult to perceive depth in 3D plots + +**Solutions:** +```python +# Solution 1: Adjust viewing angle +ax.view_init(elev=30, azim=45) + +# Solution 2: Add gridlines +ax.grid(True) + +# Solution 3: Use color for depth +scatter = ax.scatter(x, y, z, c=z, cmap='viridis') + +# Solution 4: Rotate interactively (if using interactive backend) +# User can click and drag to rotate +``` + +### Issue: 3D Axis Labels Cut Off + +**Problem:** 3D axis labels appear outside figure + +**Solution:** +```python +from mpl_toolkits.mplot3d import Axes3D + +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_surface(X, Y, Z) + +# Add padding +fig.tight_layout(pad=3.0) + +# Or save with tight bounding box +plt.savefig('3d_plot.png', bbox_inches='tight', pad_inches=0.5) +``` + +## Image and Colorbar Issues + +### Issue: Images Appear Flipped + +**Problem:** Image orientation is wrong + +**Solution:** +```python +# Set origin parameter +ax.imshow(img, origin='lower') # or 'upper' (default) + +# Or flip array +ax.imshow(np.flipud(img)) +``` + +### Issue: Images Look Pixelated + +**Problem:** Image appears blocky when zoomed + +**Solutions:** +```python +# Solution 1: Use interpolation +ax.imshow(img, interpolation='bilinear') +# Options: 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', etc. + +# Solution 2: Increase DPI when saving +plt.savefig('figure.png', dpi=300) + +# Solution 3: Use vector format if appropriate +plt.savefig('figure.pdf') +``` + +## Common Errors and Fixes + +### "TypeError: 'AxesSubplot' object is not subscriptable" + +**Problem:** Trying to index single axes +```python +# Wrong +fig, ax = plt.subplots() +ax[0].plot(x, y) # Error! + +# Correct +fig, ax = plt.subplots() +ax.plot(x, y) +``` + +### "ValueError: x and y must have same first dimension" + +**Problem:** Data arrays have mismatched lengths +```python +# Check shapes +print(f"x shape: {x.shape}, y shape: {y.shape}") + +# Ensure they match +assert len(x) == len(y), "x and y must have same length" +``` + +### "AttributeError: 'numpy.ndarray' object has no attribute 'plot'" + +**Problem:** Calling plot on array instead of axes +```python +# Wrong +data.plot(x, y) + +# Correct +ax.plot(x, y) +# or for pandas +data.plot(ax=ax) +``` + +## Best Practices to Avoid Issues + +1. **Always use the OO interface** - Avoid pyplot state machine + ```python + fig, ax = plt.subplots() # Good + ax.plot(x, y) + ``` + +2. **Use constrained_layout** - Prevents overlap issues + ```python + fig, ax = plt.subplots(constrained_layout=True) + ``` + +3. **Close figures explicitly** - Prevents memory leaks + ```python + plt.close(fig) + ``` + +4. **Set figure size at creation** - Better than resizing later + ```python + fig, ax = plt.subplots(figsize=(10, 6)) + ``` + +5. **Use raw strings for math text** - Avoids escape issues + ```python + ax.set_xlabel(r'$\alpha$') + ``` + +6. **Check data shapes before plotting** - Catch size mismatches early + ```python + assert len(x) == len(y) + ``` + +7. **Use appropriate DPI** - 300 for print, 150 for web + ```python + plt.savefig('figure.png', dpi=300) + ``` + +8. 
**Test with different backends** - If display issues occur + ```python + import matplotlib + matplotlib.use('TkAgg') + ``` diff --git a/scientific-packages/matplotlib/references/plot_types.md b/scientific-packages/matplotlib/references/plot_types.md new file mode 100644 index 0000000..2aad9aa --- /dev/null +++ b/scientific-packages/matplotlib/references/plot_types.md @@ -0,0 +1,476 @@ +# Matplotlib Plot Types Guide + +Comprehensive guide to different plot types in matplotlib with examples and use cases. + +## 1. Line Plots + +**Use cases:** Time series, continuous data, trends, function visualization + +### Basic Line Plot +```python +fig, ax = plt.subplots(figsize=(10, 6)) +ax.plot(x, y, linewidth=2, label='Data') +ax.set_xlabel('X axis') +ax.set_ylabel('Y axis') +ax.legend() +``` + +### Multiple Lines +```python +ax.plot(x, y1, label='Dataset 1', linewidth=2) +ax.plot(x, y2, label='Dataset 2', linewidth=2, linestyle='--') +ax.plot(x, y3, label='Dataset 3', linewidth=2, linestyle=':') +ax.legend() +``` + +### Line with Markers +```python +ax.plot(x, y, marker='o', markersize=8, linestyle='-', + linewidth=2, markerfacecolor='red', markeredgecolor='black') +``` + +### Step Plot +```python +ax.step(x, y, where='mid', linewidth=2, label='Step function') +# where options: 'pre', 'post', 'mid' +``` + +### Error Bars +```python +ax.errorbar(x, y, yerr=error, fmt='o-', linewidth=2, + capsize=5, capthick=2, label='With uncertainty') +``` + +## 2. Scatter Plots + +**Use cases:** Correlations, relationships between variables, clusters, outliers + +### Basic Scatter +```python +ax.scatter(x, y, s=50, alpha=0.6) +``` + +### Sized and Colored Scatter +```python +scatter = ax.scatter(x, y, s=sizes*100, c=colors, + cmap='viridis', alpha=0.6, edgecolors='black') +plt.colorbar(scatter, ax=ax, label='Color variable') +``` + +### Categorical Scatter +```python +for category in categories: + mask = data['category'] == category + ax.scatter(data[mask]['x'], data[mask]['y'], + label=category, s=50, alpha=0.7) +ax.legend() +``` + +## 3. Bar Charts + +**Use cases:** Categorical comparisons, discrete data, counts + +### Vertical Bar Chart +```python +ax.bar(categories, values, color='steelblue', + edgecolor='black', linewidth=1.5) +ax.set_ylabel('Values') +``` + +### Horizontal Bar Chart +```python +ax.barh(categories, values, color='coral', + edgecolor='black', linewidth=1.5) +ax.set_xlabel('Values') +``` + +### Grouped Bar Chart +```python +x = np.arange(len(categories)) +width = 0.35 + +ax.bar(x - width/2, values1, width, label='Group 1') +ax.bar(x + width/2, values2, width, label='Group 2') +ax.set_xticks(x) +ax.set_xticklabels(categories) +ax.legend() +``` + +### Stacked Bar Chart +```python +ax.bar(categories, values1, label='Part 1') +ax.bar(categories, values2, bottom=values1, label='Part 2') +ax.bar(categories, values3, bottom=values1+values2, label='Part 3') +ax.legend() +``` + +### Bar Chart with Error Bars +```python +ax.bar(categories, values, yerr=errors, capsize=5, + color='steelblue', edgecolor='black') +``` + +### Bar Chart with Patterns +```python +bars1 = ax.bar(x - width/2, values1, width, label='Group 1', + color='white', edgecolor='black', hatch='//') +bars2 = ax.bar(x + width/2, values2, width, label='Group 2', + color='white', edgecolor='black', hatch='\\\\') +``` + +## 4. 
Histograms + +**Use cases:** Distributions, frequency analysis + +### Basic Histogram +```python +ax.hist(data, bins=30, edgecolor='black', alpha=0.7) +ax.set_xlabel('Value') +ax.set_ylabel('Frequency') +``` + +### Multiple Overlapping Histograms +```python +ax.hist(data1, bins=30, alpha=0.5, label='Dataset 1') +ax.hist(data2, bins=30, alpha=0.5, label='Dataset 2') +ax.legend() +``` + +### Normalized Histogram (Density) +```python +ax.hist(data, bins=30, density=True, alpha=0.7, + edgecolor='black', label='Empirical') + +# Overlay theoretical distribution +from scipy.stats import norm +x = np.linspace(data.min(), data.max(), 100) +ax.plot(x, norm.pdf(x, data.mean(), data.std()), + 'r-', linewidth=2, label='Normal fit') +ax.legend() +``` + +### 2D Histogram (Hexbin) +```python +hexbin = ax.hexbin(x, y, gridsize=30, cmap='Blues') +plt.colorbar(hexbin, ax=ax, label='Counts') +``` + +### 2D Histogram (hist2d) +```python +h = ax.hist2d(x, y, bins=30, cmap='Blues') +plt.colorbar(h[3], ax=ax, label='Counts') +``` + +## 5. Box and Violin Plots + +**Use cases:** Statistical distributions, outlier detection, comparing distributions + +### Box Plot +```python +ax.boxplot([data1, data2, data3], + labels=['Group A', 'Group B', 'Group C'], + showmeans=True, meanline=True) +ax.set_ylabel('Values') +``` + +### Horizontal Box Plot +```python +ax.boxplot([data1, data2, data3], vert=False, + labels=['Group A', 'Group B', 'Group C']) +ax.set_xlabel('Values') +``` + +### Violin Plot +```python +parts = ax.violinplot([data1, data2, data3], + positions=[1, 2, 3], + showmeans=True, showmedians=True) +ax.set_xticks([1, 2, 3]) +ax.set_xticklabels(['Group A', 'Group B', 'Group C']) +``` + +## 6. Heatmaps + +**Use cases:** Matrix data, correlations, intensity maps + +### Basic Heatmap +```python +im = ax.imshow(matrix, cmap='coolwarm', aspect='auto') +plt.colorbar(im, ax=ax, label='Values') +ax.set_xlabel('X') +ax.set_ylabel('Y') +``` + +### Heatmap with Annotations +```python +im = ax.imshow(matrix, cmap='coolwarm') +plt.colorbar(im, ax=ax) + +# Add text annotations +for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + text = ax.text(j, i, f'{matrix[i, j]:.2f}', + ha='center', va='center', color='black') +``` + +### Correlation Matrix +```python +corr = data.corr() +im = ax.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1) +plt.colorbar(im, ax=ax, label='Correlation') + +# Set tick labels +ax.set_xticks(range(len(corr))) +ax.set_yticks(range(len(corr))) +ax.set_xticklabels(corr.columns, rotation=45, ha='right') +ax.set_yticklabels(corr.columns) +``` + +## 7. Contour Plots + +**Use cases:** 3D data on 2D plane, topography, function visualization + +### Contour Lines +```python +contour = ax.contour(X, Y, Z, levels=10, cmap='viridis') +ax.clabel(contour, inline=True, fontsize=8) +plt.colorbar(contour, ax=ax) +``` + +### Filled Contours +```python +contourf = ax.contourf(X, Y, Z, levels=20, cmap='viridis') +plt.colorbar(contourf, ax=ax) +``` + +### Combined Contours +```python +contourf = ax.contourf(X, Y, Z, levels=20, cmap='viridis', alpha=0.8) +contour = ax.contour(X, Y, Z, levels=10, colors='black', + linewidths=0.5, alpha=0.4) +ax.clabel(contour, inline=True, fontsize=8) +plt.colorbar(contourf, ax=ax) +``` + +## 8. 
Pie Charts + +**Use cases:** Proportions, percentages (use sparingly) + +### Basic Pie Chart +```python +ax.pie(sizes, labels=labels, autopct='%1.1f%%', + startangle=90, colors=colors) +ax.axis('equal') # Equal aspect ratio ensures circular pie +``` + +### Exploded Pie Chart +```python +explode = (0.1, 0, 0, 0) # Explode first slice +ax.pie(sizes, explode=explode, labels=labels, + autopct='%1.1f%%', shadow=True, startangle=90) +ax.axis('equal') +``` + +### Donut Chart +```python +ax.pie(sizes, labels=labels, autopct='%1.1f%%', + wedgeprops=dict(width=0.5), startangle=90) +ax.axis('equal') +``` + +## 9. Polar Plots + +**Use cases:** Cyclic data, directional data, radar charts + +### Basic Polar Plot +```python +theta = np.linspace(0, 2*np.pi, 100) +r = np.abs(np.sin(2*theta)) + +ax = plt.subplot(111, projection='polar') +ax.plot(theta, r, linewidth=2) +``` + +### Radar Chart +```python +categories = ['A', 'B', 'C', 'D', 'E'] +values = [4, 3, 5, 2, 4] + +# Add first value to the end to close the polygon +angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False) +values_closed = np.concatenate((values, [values[0]])) +angles_closed = np.concatenate((angles, [angles[0]])) + +ax = plt.subplot(111, projection='polar') +ax.plot(angles_closed, values_closed, 'o-', linewidth=2) +ax.fill(angles_closed, values_closed, alpha=0.25) +ax.set_xticks(angles) +ax.set_xticklabels(categories) +``` + +## 10. Stream and Quiver Plots + +**Use cases:** Vector fields, flow visualization + +### Quiver Plot (Vector Field) +```python +ax.quiver(X, Y, U, V, alpha=0.8) +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_aspect('equal') +``` + +### Stream Plot +```python +ax.streamplot(X, Y, U, V, density=1.5, color='k', linewidth=1) +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_aspect('equal') +``` + +## 11. Fill Between + +**Use cases:** Uncertainty bounds, confidence intervals, areas under curves + +### Fill Between Two Curves +```python +ax.plot(x, y, 'k-', linewidth=2, label='Mean') +ax.fill_between(x, y - std, y + std, alpha=0.3, + label='±1 std dev') +ax.legend() +``` + +### Fill Between with Condition +```python +ax.plot(x, y1, label='Line 1') +ax.plot(x, y2, label='Line 2') +ax.fill_between(x, y1, y2, where=(y2 >= y1), + alpha=0.3, label='y2 > y1', interpolate=True) +ax.legend() +``` + +## 12. 3D Plots + +**Use cases:** Three-dimensional data visualization + +### 3D Scatter +```python +from mpl_toolkits.mplot3d import Axes3D + +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +scatter = ax.scatter(x, y, z, c=colors, cmap='viridis', + marker='o', s=50) +plt.colorbar(scatter, ax=ax) +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_zlabel('Z') +``` + +### 3D Surface Plot +```python +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +surf = ax.plot_surface(X, Y, Z, cmap='viridis', + edgecolor='none', alpha=0.9) +plt.colorbar(surf, ax=ax) +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_zlabel('Z') +``` + +### 3D Wireframe +```python +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +ax.plot_wireframe(X, Y, Z, color='black', linewidth=0.5) +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_zlabel('Z') +``` + +### 3D Contour +```python +fig = plt.figure(figsize=(10, 8)) +ax = fig.add_subplot(111, projection='3d') +ax.contour(X, Y, Z, levels=15, cmap='viridis') +ax.set_xlabel('X') +ax.set_ylabel('Y') +ax.set_zlabel('Z') +``` + +## 13. 
Specialized Plots + +### Stem Plot +```python +ax.stem(x, y, linefmt='C0-', markerfmt='C0o', basefmt='k-') +ax.set_xlabel('X') +ax.set_ylabel('Y') +``` + +### Filled Polygon +```python +vertices = [(0, 0), (1, 0), (1, 1), (0, 1)] +from matplotlib.patches import Polygon +polygon = Polygon(vertices, closed=True, edgecolor='black', + facecolor='lightblue', alpha=0.5) +ax.add_patch(polygon) +ax.set_xlim(-0.5, 1.5) +ax.set_ylim(-0.5, 1.5) +``` + +### Staircase Plot +```python +ax.stairs(values, edges, fill=True, alpha=0.5) +``` + +### Broken Barh (Gantt-style) +```python +ax.broken_barh([(10, 50), (100, 20), (130, 10)], (10, 9), + facecolors='tab:blue') +ax.broken_barh([(10, 20), (50, 50), (120, 30)], (20, 9), + facecolors='tab:orange') +ax.set_ylim(5, 35) +ax.set_xlim(0, 200) +ax.set_xlabel('Time') +ax.set_yticks([15, 25]) +ax.set_yticklabels(['Task 1', 'Task 2']) +``` + +## 14. Time Series Plots + +### Basic Time Series +```python +import pandas as pd +import matplotlib.dates as mdates + +ax.plot(dates, values, linewidth=2) +ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) +ax.xaxis.set_major_locator(mdates.DayLocator(interval=7)) +plt.xticks(rotation=45) +ax.set_xlabel('Date') +ax.set_ylabel('Value') +``` + +### Time Series with Shaded Regions +```python +ax.plot(dates, values, linewidth=2) +# Shade weekends or specific periods +ax.axvspan(start_date, end_date, alpha=0.2, color='gray') +``` + +## Plot Selection Guide + +| Data Type | Recommended Plot | Alternative Options | +|-----------|-----------------|---------------------| +| Single continuous variable | Histogram, KDE | Box plot, Violin plot | +| Two continuous variables | Scatter plot | Hexbin, 2D histogram | +| Time series | Line plot | Area plot, Step plot | +| Categorical vs continuous | Bar chart, Box plot | Violin plot, Strip plot | +| Two categorical variables | Heatmap | Grouped bar chart | +| Three continuous variables | 3D scatter, Contour | Color-coded scatter | +| Proportions | Bar chart | Pie chart (use sparingly) | +| Distributions comparison | Box plot, Violin plot | Overlaid histograms | +| Correlation matrix | Heatmap | Clustered heatmap | +| Vector field | Quiver plot, Stream plot | - | +| Function visualization | Line plot, Contour | 3D surface | diff --git a/scientific-packages/matplotlib/references/styling_guide.md b/scientific-packages/matplotlib/references/styling_guide.md new file mode 100644 index 0000000..8f9fbaf --- /dev/null +++ b/scientific-packages/matplotlib/references/styling_guide.md @@ -0,0 +1,589 @@ +# Matplotlib Styling Guide + +Comprehensive guide for styling and customizing matplotlib visualizations. + +## Colormaps + +### Colormap Categories + +**1. Perceptually Uniform Sequential** +Best for ordered data that progresses from low to high values. +- `viridis` (default, colorblind-friendly) +- `plasma` +- `inferno` +- `magma` +- `cividis` (optimized for colorblind viewers) + +**Usage:** +```python +im = ax.imshow(data, cmap='viridis') +scatter = ax.scatter(x, y, c=values, cmap='plasma') +``` + +**2. Sequential** +Traditional colormaps for ordered data. +- `Blues`, `Greens`, `Reds`, `Oranges`, `Purples` +- `YlOrBr`, `YlOrRd`, `OrRd`, `PuRd` +- `BuPu`, `GnBu`, `PuBu`, `YlGnBu` + +**3. Diverging** +Best for data with a meaningful center point (e.g., zero, mean). 
+- `coolwarm` (blue to red) +- `RdBu` (red-blue) +- `RdYlBu` (red-yellow-blue) +- `RdYlGn` (red-yellow-green) +- `PiYG`, `PRGn`, `BrBG`, `PuOr`, `RdGy` + +**Usage:** +```python +# Center colormap at zero +im = ax.imshow(data, cmap='coolwarm', vmin=-1, vmax=1) +``` + +**4. Qualitative** +Best for categorical/nominal data without inherent ordering. +- `tab10` (10 distinct colors) +- `tab20` (20 distinct colors) +- `Set1`, `Set2`, `Set3` +- `Pastel1`, `Pastel2` +- `Dark2`, `Accent`, `Paired` + +**Usage:** +```python +colors = plt.cm.tab10(np.linspace(0, 1, n_categories)) +for i, category in enumerate(categories): + ax.plot(x, y[i], color=colors[i], label=category) +``` + +**5. Cyclic** +Best for cyclic data (e.g., phase, angle). +- `twilight` +- `twilight_shifted` +- `hsv` + +### Colormap Best Practices + +1. **Avoid `jet` colormap** - Not perceptually uniform, misleading +2. **Use perceptually uniform colormaps** - `viridis`, `plasma`, `cividis` +3. **Consider colorblind users** - Use `viridis`, `cividis`, or test with colorblind simulators +4. **Match colormap to data type**: + - Sequential: increasing/decreasing data + - Diverging: data with meaningful center + - Qualitative: categories +5. **Reverse colormaps** - Add `_r` suffix: `viridis_r`, `coolwarm_r` + +### Creating Custom Colormaps + +```python +from matplotlib.colors import LinearSegmentedColormap + +# From color list +colors = ['blue', 'white', 'red'] +n_bins = 100 +cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins) + +# From RGB values +colors = [(0, 0, 1), (1, 1, 1), (1, 0, 0)] # RGB tuples +cmap = LinearSegmentedColormap.from_list('custom', colors) + +# Use the custom colormap +ax.imshow(data, cmap=cmap) +``` + +### Discrete Colormaps + +```python +import matplotlib.colors as mcolors + +# Create discrete colormap from continuous +cmap = plt.cm.viridis +bounds = np.linspace(0, 10, 11) +norm = mcolors.BoundaryNorm(bounds, cmap.N) +im = ax.imshow(data, cmap=cmap, norm=norm) +``` + +## Style Sheets + +### Using Built-in Styles + +```python +# List available styles +print(plt.style.available) + +# Apply a style +plt.style.use('seaborn-v0_8-darkgrid') + +# Apply multiple styles (later styles override earlier ones) +plt.style.use(['seaborn-v0_8-whitegrid', 'seaborn-v0_8-poster']) + +# Temporarily use a style +with plt.style.context('ggplot'): + fig, ax = plt.subplots() + ax.plot(x, y) +``` + +### Popular Built-in Styles + +- `default` - Matplotlib's default style +- `classic` - Classic matplotlib look (pre-2.0) +- `seaborn-v0_8-*` - Seaborn-inspired styles + - `seaborn-v0_8-darkgrid`, `seaborn-v0_8-whitegrid` + - `seaborn-v0_8-dark`, `seaborn-v0_8-white` + - `seaborn-v0_8-ticks`, `seaborn-v0_8-poster`, `seaborn-v0_8-talk` +- `ggplot` - ggplot2-inspired style +- `bmh` - Bayesian Methods for Hackers style +- `fivethirtyeight` - FiveThirtyEight style +- `grayscale` - Grayscale style + +### Creating Custom Style Sheets + +Create a file named `custom_style.mplstyle`: + +``` +# custom_style.mplstyle + +# Figure +figure.figsize: 10, 6 +figure.dpi: 100 +figure.facecolor: white + +# Font +font.family: sans-serif +font.sans-serif: Arial, Helvetica +font.size: 12 + +# Axes +axes.labelsize: 14 +axes.titlesize: 16 +axes.facecolor: white +axes.edgecolor: black +axes.linewidth: 1.5 +axes.grid: True +axes.axisbelow: True + +# Grid +grid.color: gray +grid.linestyle: -- +grid.linewidth: 0.5 +grid.alpha: 0.3 + +# Lines +lines.linewidth: 2 +lines.markersize: 8 + +# Ticks +xtick.labelsize: 10 +ytick.labelsize: 10 +xtick.direction: in 
+ytick.direction: in +xtick.major.size: 6 +ytick.major.size: 6 +xtick.minor.size: 3 +ytick.minor.size: 3 + +# Legend +legend.fontsize: 12 +legend.frameon: True +legend.framealpha: 0.8 +legend.fancybox: True + +# Savefig +savefig.dpi: 300 +savefig.bbox: tight +savefig.facecolor: white +``` + +Load and use: +```python +plt.style.use('path/to/custom_style.mplstyle') +``` + +## rcParams Configuration + +### Global Configuration + +```python +import matplotlib.pyplot as plt + +# Configure globally +plt.rcParams['figure.figsize'] = (10, 6) +plt.rcParams['font.size'] = 12 +plt.rcParams['axes.labelsize'] = 14 + +# Or update multiple at once +plt.rcParams.update({ + 'figure.figsize': (10, 6), + 'font.size': 12, + 'axes.labelsize': 14, + 'axes.titlesize': 16, + 'lines.linewidth': 2 +}) +``` + +### Temporary Configuration + +```python +# Context manager for temporary changes +with plt.rc_context({'font.size': 14, 'lines.linewidth': 2.5}): + fig, ax = plt.subplots() + ax.plot(x, y) +``` + +### Common rcParams + +**Figure settings:** +```python +plt.rcParams['figure.figsize'] = (10, 6) +plt.rcParams['figure.dpi'] = 100 +plt.rcParams['figure.facecolor'] = 'white' +plt.rcParams['figure.edgecolor'] = 'white' +plt.rcParams['figure.autolayout'] = False +plt.rcParams['figure.constrained_layout.use'] = True +``` + +**Font settings:** +```python +plt.rcParams['font.family'] = 'sans-serif' +plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans'] +plt.rcParams['font.size'] = 12 +plt.rcParams['font.weight'] = 'normal' +``` + +**Axes settings:** +```python +plt.rcParams['axes.facecolor'] = 'white' +plt.rcParams['axes.edgecolor'] = 'black' +plt.rcParams['axes.linewidth'] = 1.5 +plt.rcParams['axes.grid'] = True +plt.rcParams['axes.labelsize'] = 14 +plt.rcParams['axes.titlesize'] = 16 +plt.rcParams['axes.labelweight'] = 'normal' +plt.rcParams['axes.spines.top'] = True +plt.rcParams['axes.spines.right'] = True +``` + +**Line settings:** +```python +plt.rcParams['lines.linewidth'] = 2 +plt.rcParams['lines.linestyle'] = '-' +plt.rcParams['lines.marker'] = 'None' +plt.rcParams['lines.markersize'] = 6 +``` + +**Save settings:** +```python +plt.rcParams['savefig.dpi'] = 300 +plt.rcParams['savefig.format'] = 'png' +plt.rcParams['savefig.bbox'] = 'tight' +plt.rcParams['savefig.pad_inches'] = 0.1 +plt.rcParams['savefig.transparent'] = False +``` + +## Color Palettes + +### Named Color Sets + +```python +# Tableau colors +tableau_colors = plt.cm.tab10.colors + +# CSS4 colors (subset) +css_colors = ['steelblue', 'coral', 'teal', 'goldenrod', 'crimson'] + +# Manual definition +custom_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'] +``` + +### Color Cycles + +```python +# Set default color cycle +from cycler import cycler +colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'] +plt.rcParams['axes.prop_cycle'] = cycler(color=colors) + +# Or combine color and line style +plt.rcParams['axes.prop_cycle'] = cycler(color=colors) + cycler(linestyle=['-', '--', ':', '-.']) +``` + +### Palette Generation + +```python +# Evenly spaced colors from colormap +n_colors = 5 +colors = plt.cm.viridis(np.linspace(0, 1, n_colors)) + +# Use in plot +for i, (x, y) in enumerate(data): + ax.plot(x, y, color=colors[i]) +``` + +## Typography + +### Font Configuration + +```python +# Set font family +plt.rcParams['font.family'] = 'serif' +plt.rcParams['font.serif'] = ['Times New Roman', 'DejaVu Serif'] + +# Or sans-serif +plt.rcParams['font.family'] = 'sans-serif' +plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica'] + 
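+# Each family list is tried in order; matplotlib falls back to the next
+# entry (and ultimately to its bundled DejaVu fonts) if one is not installed
+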
+# Or monospace +plt.rcParams['font.family'] = 'monospace' +plt.rcParams['font.monospace'] = ['Courier New', 'DejaVu Sans Mono'] +``` + +### Font Properties in Text + +```python +from matplotlib import font_manager + +# Specify font properties +ax.text(x, y, 'Text', + fontsize=14, + fontweight='bold', # 'normal', 'bold', 'heavy', 'light' + fontstyle='italic', # 'normal', 'italic', 'oblique' + fontfamily='serif') + +# Use specific font file +prop = font_manager.FontProperties(fname='path/to/font.ttf') +ax.text(x, y, 'Text', fontproperties=prop) +``` + +### Mathematical Text + +```python +# LaTeX-style math +ax.set_title(r'$\alpha > \beta$') +ax.set_xlabel(r'$\mu \pm \sigma$') +ax.text(x, y, r'$\int_0^\infty e^{-x} dx = 1$') + +# Subscripts and superscripts +ax.set_ylabel(r'$y = x^2 + 2x + 1$') +ax.text(x, y, r'$x_1, x_2, \ldots, x_n$') + +# Greek letters +ax.text(x, y, r'$\alpha, \beta, \gamma, \delta, \epsilon$') +``` + +### Using Full LaTeX + +```python +# Enable full LaTeX rendering (requires LaTeX installation) +plt.rcParams['text.usetex'] = True +plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}' + +ax.set_title(r'\textbf{Bold Title}') +ax.set_xlabel(r'Time $t$ (s)') +``` + +## Spines and Grids + +### Spine Customization + +```python +# Hide specific spines +ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) + +# Move spine position +ax.spines['left'].set_position(('outward', 10)) +ax.spines['bottom'].set_position(('data', 0)) + +# Change spine color and width +ax.spines['left'].set_color('red') +ax.spines['bottom'].set_linewidth(2) +``` + +### Grid Customization + +```python +# Basic grid +ax.grid(True) + +# Customized grid +ax.grid(True, which='major', linestyle='--', linewidth=0.8, alpha=0.3) +ax.grid(True, which='minor', linestyle=':', linewidth=0.5, alpha=0.2) + +# Grid for specific axis +ax.grid(True, axis='x') # Only vertical lines +ax.grid(True, axis='y') # Only horizontal lines + +# Grid behind or in front of data +ax.set_axisbelow(True) # Grid behind data +``` + +## Legend Customization + +### Legend Positioning + +```python +# Location strings +ax.legend(loc='best') # Automatic best position +ax.legend(loc='upper right') +ax.legend(loc='upper left') +ax.legend(loc='lower right') +ax.legend(loc='lower left') +ax.legend(loc='center') +ax.legend(loc='upper center') +ax.legend(loc='lower center') +ax.legend(loc='center left') +ax.legend(loc='center right') + +# Precise positioning (bbox_to_anchor) +ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # Outside plot area +ax.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=3) # Below plot +``` + +### Legend Styling + +```python +ax.legend( + fontsize=12, + frameon=True, # Show frame + framealpha=0.9, # Frame transparency + fancybox=True, # Rounded corners + shadow=True, # Shadow effect + ncol=2, # Number of columns + title='Legend Title', # Legend title + title_fontsize=14, # Title font size + edgecolor='black', # Frame edge color + facecolor='white' # Frame background color +) +``` + +### Custom Legend Entries + +```python +from matplotlib.lines import Line2D + +# Create custom legend handles +custom_lines = [Line2D([0], [0], color='red', lw=2), + Line2D([0], [0], color='blue', lw=2, linestyle='--'), + Line2D([0], [0], marker='o', color='w', markerfacecolor='green', markersize=10)] + +ax.legend(custom_lines, ['Label 1', 'Label 2', 'Label 3']) +``` + +## Layout and Spacing + +### Constrained Layout + +```python +# Preferred method (automatic adjustment) +fig, axes = plt.subplots(2, 
2, constrained_layout=True) +``` + +### Tight Layout + +```python +# Alternative method +fig, axes = plt.subplots(2, 2) +plt.tight_layout(pad=1.5, h_pad=2.0, w_pad=2.0) +``` + +### Manual Adjustment + +```python +# Fine-grained control +plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, + hspace=0.3, wspace=0.4) +``` + +## Professional Publication Style + +Example configuration for publication-quality figures: + +```python +# Publication style configuration +plt.rcParams.update({ + # Figure + 'figure.figsize': (8, 6), + 'figure.dpi': 100, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'savefig.pad_inches': 0.1, + + # Font + 'font.family': 'sans-serif', + 'font.sans-serif': ['Arial', 'Helvetica'], + 'font.size': 11, + + # Axes + 'axes.labelsize': 12, + 'axes.titlesize': 14, + 'axes.linewidth': 1.5, + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, + + # Lines + 'lines.linewidth': 2, + 'lines.markersize': 8, + + # Ticks + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'xtick.major.size': 6, + 'ytick.major.size': 6, + 'xtick.major.width': 1.5, + 'ytick.major.width': 1.5, + 'xtick.direction': 'in', + 'ytick.direction': 'in', + + # Legend + 'legend.fontsize': 10, + 'legend.frameon': True, + 'legend.framealpha': 1.0, + 'legend.edgecolor': 'black' +}) +``` + +## Dark Theme + +```python +# Dark background style +plt.style.use('dark_background') + +# Or manual configuration +plt.rcParams.update({ + 'figure.facecolor': '#1e1e1e', + 'axes.facecolor': '#1e1e1e', + 'axes.edgecolor': 'white', + 'axes.labelcolor': 'white', + 'text.color': 'white', + 'xtick.color': 'white', + 'ytick.color': 'white', + 'grid.color': 'gray', + 'legend.facecolor': '#1e1e1e', + 'legend.edgecolor': 'white' +}) +``` + +## Color Accessibility + +### Colorblind-Friendly Palettes + +```python +# Use colorblind-friendly colormaps +colorblind_friendly = ['viridis', 'plasma', 'cividis'] + +# Colorblind-friendly discrete colors +cb_colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', + '#CA9161', '#949494', '#ECE133', '#56B4E9'] + +# Test with simulation tools or use these validated palettes +``` + +### High Contrast + +```python +# Ensure sufficient contrast +plt.rcParams['axes.edgecolor'] = 'black' +plt.rcParams['axes.linewidth'] = 2 +plt.rcParams['xtick.major.width'] = 2 +plt.rcParams['ytick.major.width'] = 2 +``` diff --git a/scientific-packages/matplotlib/scripts/plot_template.py b/scientific-packages/matplotlib/scripts/plot_template.py new file mode 100644 index 0000000..88721c1 --- /dev/null +++ b/scientific-packages/matplotlib/scripts/plot_template.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +""" +Matplotlib Plot Template + +Comprehensive template demonstrating various plot types and best practices. +Use this as a starting point for creating publication-quality visualizations. 
+ +Usage: + python plot_template.py [--plot-type TYPE] [--style STYLE] [--output FILE] + +Plot types: + line, scatter, bar, histogram, heatmap, contour, box, violin, 3d, all +""" + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec +import argparse + + +def set_publication_style(): + """Configure matplotlib for publication-quality figures.""" + plt.rcParams.update({ + 'figure.figsize': (10, 6), + 'figure.dpi': 100, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 14, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'lines.linewidth': 2, + 'axes.linewidth': 1.5, + }) + + +def generate_sample_data(): + """Generate sample data for demonstrations.""" + np.random.seed(42) + x = np.linspace(0, 10, 100) + y1 = np.sin(x) + y2 = np.cos(x) + scatter_x = np.random.randn(200) + scatter_y = np.random.randn(200) + categories = ['A', 'B', 'C', 'D', 'E'] + bar_values = np.random.randint(10, 100, len(categories)) + hist_data = np.random.normal(0, 1, 1000) + matrix = np.random.rand(10, 10) + + X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100)) + Z = np.sin(np.sqrt(X**2 + Y**2)) + + return { + 'x': x, 'y1': y1, 'y2': y2, + 'scatter_x': scatter_x, 'scatter_y': scatter_y, + 'categories': categories, 'bar_values': bar_values, + 'hist_data': hist_data, 'matrix': matrix, + 'X': X, 'Y': Y, 'Z': Z + } + + +def create_line_plot(data, ax=None): + """Create line plot with best practices.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + ax.plot(data['x'], data['y1'], label='sin(x)', linewidth=2, marker='o', + markevery=10, markersize=6) + ax.plot(data['x'], data['y2'], label='cos(x)', linewidth=2, linestyle='--') + + ax.set_xlabel('x') + ax.set_ylabel('y') + ax.set_title('Line Plot Example') + ax.legend(loc='best', framealpha=0.9) + ax.grid(True, alpha=0.3, linestyle='--') + + # Remove top and right spines for cleaner look + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + if ax is None: + return fig + return ax + + +def create_scatter_plot(data, ax=None): + """Create scatter plot with color and size variations.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + # Color based on distance from origin + colors = np.sqrt(data['scatter_x']**2 + data['scatter_y']**2) + sizes = 50 * (1 + np.abs(data['scatter_x'])) + + scatter = ax.scatter(data['scatter_x'], data['scatter_y'], + c=colors, s=sizes, alpha=0.6, + cmap='viridis', edgecolors='black', linewidth=0.5) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_title('Scatter Plot Example') + ax.grid(True, alpha=0.3, linestyle='--') + + # Add colorbar + cbar = plt.colorbar(scatter, ax=ax) + cbar.set_label('Distance from origin') + + if ax is None: + return fig + return ax + + +def create_bar_chart(data, ax=None): + """Create bar chart with error bars and styling.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + x_pos = np.arange(len(data['categories'])) + errors = np.random.randint(5, 15, len(data['categories'])) + + bars = ax.bar(x_pos, data['bar_values'], yerr=errors, + color='steelblue', edgecolor='black', linewidth=1.5, + capsize=5, alpha=0.8) + + # Color bars by value + colors = plt.cm.viridis(data['bar_values'] / data['bar_values'].max()) + for bar, color in zip(bars, colors): + bar.set_facecolor(color) + + ax.set_xlabel('Category') + ax.set_ylabel('Values') + 
ax.set_title('Bar Chart Example') + ax.set_xticks(x_pos) + ax.set_xticklabels(data['categories']) + ax.grid(True, axis='y', alpha=0.3, linestyle='--') + + # Remove top and right spines + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + if ax is None: + return fig + return ax + + +def create_histogram(data, ax=None): + """Create histogram with density overlay.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + n, bins, patches = ax.hist(data['hist_data'], bins=30, density=True, + alpha=0.7, edgecolor='black', color='steelblue') + + # Overlay theoretical normal distribution + from scipy.stats import norm + mu, std = norm.fit(data['hist_data']) + x_theory = np.linspace(data['hist_data'].min(), data['hist_data'].max(), 100) + ax.plot(x_theory, norm.pdf(x_theory, mu, std), 'r-', linewidth=2, + label=f'Normal fit (μ={mu:.2f}, σ={std:.2f})') + + ax.set_xlabel('Value') + ax.set_ylabel('Density') + ax.set_title('Histogram with Normal Fit') + ax.legend() + ax.grid(True, axis='y', alpha=0.3, linestyle='--') + + if ax is None: + return fig + return ax + + +def create_heatmap(data, ax=None): + """Create heatmap with colorbar and annotations.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True) + + im = ax.imshow(data['matrix'], cmap='coolwarm', aspect='auto', + vmin=0, vmax=1) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label('Value') + + # Optional: Add text annotations + # for i in range(data['matrix'].shape[0]): + # for j in range(data['matrix'].shape[1]): + # text = ax.text(j, i, f'{data["matrix"][i, j]:.2f}', + # ha='center', va='center', color='black', fontsize=8) + + ax.set_xlabel('X Index') + ax.set_ylabel('Y Index') + ax.set_title('Heatmap Example') + + if ax is None: + return fig + return ax + + +def create_contour_plot(data, ax=None): + """Create contour plot with filled contours and labels.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True) + + # Filled contours + contourf = ax.contourf(data['X'], data['Y'], data['Z'], + levels=20, cmap='viridis', alpha=0.8) + + # Contour lines + contour = ax.contour(data['X'], data['Y'], data['Z'], + levels=10, colors='black', linewidths=0.5, alpha=0.4) + + # Add labels to contour lines + ax.clabel(contour, inline=True, fontsize=8) + + # Add colorbar + cbar = plt.colorbar(contourf, ax=ax) + cbar.set_label('Z value') + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_title('Contour Plot Example') + ax.set_aspect('equal') + + if ax is None: + return fig + return ax + + +def create_box_plot(data, ax=None): + """Create box plot comparing distributions.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + # Generate multiple distributions + box_data = [np.random.normal(0, std, 100) for std in range(1, 5)] + + bp = ax.boxplot(box_data, labels=['Group 1', 'Group 2', 'Group 3', 'Group 4'], + patch_artist=True, showmeans=True, + boxprops=dict(facecolor='lightblue', edgecolor='black'), + medianprops=dict(color='red', linewidth=2), + meanprops=dict(marker='D', markerfacecolor='green', markersize=8)) + + ax.set_xlabel('Groups') + ax.set_ylabel('Values') + ax.set_title('Box Plot Example') + ax.grid(True, axis='y', alpha=0.3, linestyle='--') + + if ax is None: + return fig + return ax + + +def create_violin_plot(data, ax=None): + """Create violin plot showing distribution shapes.""" + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6), constrained_layout=True) + + # 
Generate multiple distributions + violin_data = [np.random.normal(0, std, 100) for std in range(1, 5)] + + parts = ax.violinplot(violin_data, positions=range(1, 5), + showmeans=True, showmedians=True) + + # Customize colors + for pc in parts['bodies']: + pc.set_facecolor('lightblue') + pc.set_alpha(0.7) + pc.set_edgecolor('black') + + ax.set_xlabel('Groups') + ax.set_ylabel('Values') + ax.set_title('Violin Plot Example') + ax.set_xticks(range(1, 5)) + ax.set_xticklabels(['Group 1', 'Group 2', 'Group 3', 'Group 4']) + ax.grid(True, axis='y', alpha=0.3, linestyle='--') + + if ax is None: + return fig + return ax + + +def create_3d_plot(): + """Create 3D surface plot.""" + from mpl_toolkits.mplot3d import Axes3D + + fig = plt.figure(figsize=(12, 9)) + ax = fig.add_subplot(111, projection='3d') + + # Generate data + X = np.linspace(-5, 5, 50) + Y = np.linspace(-5, 5, 50) + X, Y = np.meshgrid(X, Y) + Z = np.sin(np.sqrt(X**2 + Y**2)) + + # Create surface plot + surf = ax.plot_surface(X, Y, Z, cmap='viridis', + edgecolor='none', alpha=0.9) + + # Add colorbar + fig.colorbar(surf, ax=ax, shrink=0.5) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_title('3D Surface Plot Example') + + # Set viewing angle + ax.view_init(elev=30, azim=45) + + plt.tight_layout() + return fig + + +def create_comprehensive_figure(): + """Create a comprehensive figure with multiple subplots.""" + data = generate_sample_data() + + fig = plt.figure(figsize=(16, 12), constrained_layout=True) + gs = GridSpec(3, 3, figure=fig) + + # Create subplots + ax1 = fig.add_subplot(gs[0, :2]) # Line plot - top left, spans 2 columns + create_line_plot(data, ax1) + + ax2 = fig.add_subplot(gs[0, 2]) # Bar chart - top right + create_bar_chart(data, ax2) + + ax3 = fig.add_subplot(gs[1, 0]) # Scatter plot - middle left + create_scatter_plot(data, ax3) + + ax4 = fig.add_subplot(gs[1, 1]) # Histogram - middle center + create_histogram(data, ax4) + + ax5 = fig.add_subplot(gs[1, 2]) # Box plot - middle right + create_box_plot(data, ax5) + + ax6 = fig.add_subplot(gs[2, :2]) # Contour plot - bottom left, spans 2 columns + create_contour_plot(data, ax6) + + ax7 = fig.add_subplot(gs[2, 2]) # Heatmap - bottom right + create_heatmap(data, ax7) + + fig.suptitle('Comprehensive Matplotlib Template', fontsize=18, fontweight='bold') + + return fig + + +def main(): + """Main function to run the template.""" + parser = argparse.ArgumentParser(description='Matplotlib plot template') + parser.add_argument('--plot-type', type=str, default='all', + choices=['line', 'scatter', 'bar', 'histogram', 'heatmap', + 'contour', 'box', 'violin', '3d', 'all'], + help='Type of plot to create') + parser.add_argument('--style', type=str, default='default', + help='Matplotlib style to use') + parser.add_argument('--output', type=str, default='plot.png', + help='Output filename') + + args = parser.parse_args() + + # Set style + if args.style != 'default': + plt.style.use(args.style) + else: + set_publication_style() + + # Generate data + data = generate_sample_data() + + # Create plot based on type + plot_functions = { + 'line': create_line_plot, + 'scatter': create_scatter_plot, + 'bar': create_bar_chart, + 'histogram': create_histogram, + 'heatmap': create_heatmap, + 'contour': create_contour_plot, + 'box': create_box_plot, + 'violin': create_violin_plot, + } + + if args.plot_type == '3d': + fig = create_3d_plot() + elif args.plot_type == 'all': + fig = create_comprehensive_figure() + else: + fig = plot_functions[args.plot_type](data) + + # Save figure 
+ plt.savefig(args.output, dpi=300, bbox_inches='tight') + print(f"Plot saved to {args.output}") + + # Display + plt.show() + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/matplotlib/scripts/style_configurator.py b/scientific-packages/matplotlib/scripts/style_configurator.py new file mode 100644 index 0000000..1a0aca2 --- /dev/null +++ b/scientific-packages/matplotlib/scripts/style_configurator.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +""" +Matplotlib Style Configurator + +Interactive utility to configure matplotlib style preferences and generate +custom style sheets. Creates a preview of the style and optionally saves +it as a .mplstyle file. + +Usage: + python style_configurator.py [--preset PRESET] [--output FILE] [--preview] + +Presets: + publication, presentation, web, dark, minimal +""" + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec +import argparse +import os + + +# Predefined style presets +STYLE_PRESETS = { + 'publication': { + 'figure.figsize': (8, 6), + 'figure.dpi': 100, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'font.family': 'sans-serif', + 'font.sans-serif': ['Arial', 'Helvetica'], + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 14, + 'axes.linewidth': 1.5, + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, + 'lines.linewidth': 2, + 'lines.markersize': 8, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'xtick.direction': 'in', + 'ytick.direction': 'in', + 'xtick.major.size': 6, + 'ytick.major.size': 6, + 'xtick.major.width': 1.5, + 'ytick.major.width': 1.5, + 'legend.fontsize': 10, + 'legend.frameon': True, + 'legend.framealpha': 1.0, + 'legend.edgecolor': 'black', + }, + 'presentation': { + 'figure.figsize': (12, 8), + 'figure.dpi': 100, + 'savefig.dpi': 150, + 'font.size': 16, + 'axes.labelsize': 20, + 'axes.titlesize': 24, + 'axes.linewidth': 2, + 'lines.linewidth': 3, + 'lines.markersize': 12, + 'xtick.labelsize': 16, + 'ytick.labelsize': 16, + 'legend.fontsize': 16, + 'axes.grid': True, + 'grid.alpha': 0.3, + }, + 'web': { + 'figure.figsize': (10, 6), + 'figure.dpi': 96, + 'savefig.dpi': 150, + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 14, + 'lines.linewidth': 2, + 'axes.grid': True, + 'grid.alpha': 0.2, + 'grid.linestyle': '--', + }, + 'dark': { + 'figure.facecolor': '#1e1e1e', + 'figure.edgecolor': '#1e1e1e', + 'axes.facecolor': '#1e1e1e', + 'axes.edgecolor': 'white', + 'axes.labelcolor': 'white', + 'text.color': 'white', + 'xtick.color': 'white', + 'ytick.color': 'white', + 'grid.color': 'gray', + 'grid.alpha': 0.3, + 'axes.grid': True, + 'legend.facecolor': '#1e1e1e', + 'legend.edgecolor': 'white', + 'savefig.facecolor': '#1e1e1e', + }, + 'minimal': { + 'figure.figsize': (10, 6), + 'axes.spines.top': False, + 'axes.spines.right': False, + 'axes.spines.left': False, + 'axes.spines.bottom': False, + 'axes.grid': False, + 'xtick.bottom': True, + 'ytick.left': True, + 'axes.axisbelow': True, + 'lines.linewidth': 2.5, + 'font.size': 12, + } +} + + +def generate_preview_data(): + """Generate sample data for style preview.""" + np.random.seed(42) + x = np.linspace(0, 10, 100) + y1 = np.sin(x) + 0.1 * np.random.randn(100) + y2 = np.cos(x) + 0.1 * np.random.randn(100) + scatter_x = np.random.randn(100) + scatter_y = 2 * scatter_x + np.random.randn(100) + categories = ['A', 'B', 'C', 'D', 'E'] + bar_values = [25, 40, 30, 55, 45] + + return { + 'x': x, 'y1': y1, 'y2': y2, + 'scatter_x': scatter_x, 'scatter_y': scatter_y, + 
'categories': categories, 'bar_values': bar_values + } + + +def create_style_preview(style_dict=None): + """Create a preview figure demonstrating the style.""" + if style_dict: + plt.rcParams.update(style_dict) + + data = generate_preview_data() + + fig = plt.figure(figsize=(14, 10)) + gs = GridSpec(2, 2, figure=fig, hspace=0.3, wspace=0.3) + + # Line plot + ax1 = fig.add_subplot(gs[0, 0]) + ax1.plot(data['x'], data['y1'], label='sin(x)', marker='o', markevery=10) + ax1.plot(data['x'], data['y2'], label='cos(x)', linestyle='--') + ax1.set_xlabel('X axis') + ax1.set_ylabel('Y axis') + ax1.set_title('Line Plot') + ax1.legend() + ax1.grid(True, alpha=0.3) + + # Scatter plot + ax2 = fig.add_subplot(gs[0, 1]) + colors = np.sqrt(data['scatter_x']**2 + data['scatter_y']**2) + scatter = ax2.scatter(data['scatter_x'], data['scatter_y'], + c=colors, cmap='viridis', alpha=0.6, s=50) + ax2.set_xlabel('X axis') + ax2.set_ylabel('Y axis') + ax2.set_title('Scatter Plot') + cbar = plt.colorbar(scatter, ax=ax2) + cbar.set_label('Distance') + ax2.grid(True, alpha=0.3) + + # Bar chart + ax3 = fig.add_subplot(gs[1, 0]) + bars = ax3.bar(data['categories'], data['bar_values'], + edgecolor='black', linewidth=1) + # Color bars with gradient + colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(bars))) + for bar, color in zip(bars, colors): + bar.set_facecolor(color) + ax3.set_xlabel('Categories') + ax3.set_ylabel('Values') + ax3.set_title('Bar Chart') + ax3.grid(True, axis='y', alpha=0.3) + + # Multiple line plot with fills + ax4 = fig.add_subplot(gs[1, 1]) + ax4.plot(data['x'], data['y1'], label='Signal 1', linewidth=2) + ax4.fill_between(data['x'], data['y1'] - 0.2, data['y1'] + 0.2, + alpha=0.3, label='±1 std') + ax4.plot(data['x'], data['y2'], label='Signal 2', linewidth=2) + ax4.fill_between(data['x'], data['y2'] - 0.2, data['y2'] + 0.2, + alpha=0.3) + ax4.set_xlabel('X axis') + ax4.set_ylabel('Y axis') + ax4.set_title('Time Series with Uncertainty') + ax4.legend() + ax4.grid(True, alpha=0.3) + + fig.suptitle('Style Preview', fontsize=16, fontweight='bold') + + return fig + + +def save_style_file(style_dict, filename): + """Save style dictionary as .mplstyle file.""" + with open(filename, 'w') as f: + f.write("# Custom matplotlib style\n") + f.write("# Generated by style_configurator.py\n\n") + + # Group settings by category + categories = { + 'Figure': ['figure.'], + 'Font': ['font.'], + 'Axes': ['axes.'], + 'Lines': ['lines.'], + 'Markers': ['markers.'], + 'Ticks': ['tick.', 'xtick.', 'ytick.'], + 'Grid': ['grid.'], + 'Legend': ['legend.'], + 'Savefig': ['savefig.'], + 'Text': ['text.'], + } + + for category, prefixes in categories.items(): + category_items = {k: v for k, v in style_dict.items() + if any(k.startswith(p) for p in prefixes)} + if category_items: + f.write(f"# {category}\n") + for key, value in sorted(category_items.items()): + # Format value appropriately + if isinstance(value, (list, tuple)): + value_str = ', '.join(str(v) for v in value) + elif isinstance(value, bool): + value_str = str(value) + else: + value_str = str(value) + f.write(f"{key}: {value_str}\n") + f.write("\n") + + print(f"Style saved to {filename}") + + +def print_style_info(style_dict): + """Print information about the style.""" + print("\n" + "="*60) + print("STYLE CONFIGURATION") + print("="*60) + + categories = { + 'Figure Settings': ['figure.'], + 'Font Settings': ['font.'], + 'Axes Settings': ['axes.'], + 'Line Settings': ['lines.'], + 'Grid Settings': ['grid.'], + 'Legend Settings': ['legend.'], + } + + for category, 
prefixes in categories.items(): + category_items = {k: v for k, v in style_dict.items() + if any(k.startswith(p) for p in prefixes)} + if category_items: + print(f"\n{category}:") + for key, value in sorted(category_items.items()): + print(f" {key}: {value}") + + print("\n" + "="*60 + "\n") + + +def list_available_presets(): + """Print available style presets.""" + print("\nAvailable style presets:") + print("-" * 40) + descriptions = { + 'publication': 'Optimized for academic publications', + 'presentation': 'Large fonts for presentations', + 'web': 'Optimized for web display', + 'dark': 'Dark background theme', + 'minimal': 'Minimal, clean style', + } + for preset, desc in descriptions.items(): + print(f" {preset:15s} - {desc}") + print("-" * 40 + "\n") + + +def interactive_mode(): + """Run interactive mode to customize style settings.""" + print("\n" + "="*60) + print("MATPLOTLIB STYLE CONFIGURATOR - Interactive Mode") + print("="*60) + + list_available_presets() + + preset = input("Choose a preset to start from (or 'custom' for default): ").strip().lower() + + if preset in STYLE_PRESETS: + style_dict = STYLE_PRESETS[preset].copy() + print(f"\nStarting from '{preset}' preset") + else: + style_dict = {} + print("\nStarting from default matplotlib style") + + print("\nCommon settings you might want to customize:") + print(" 1. Figure size") + print(" 2. Font sizes") + print(" 3. Line widths") + print(" 4. Grid settings") + print(" 5. Color scheme") + print(" 6. Done, show preview") + + while True: + choice = input("\nSelect option (1-6): ").strip() + + if choice == '1': + width = input(" Figure width (inches, default 10): ").strip() or '10' + height = input(" Figure height (inches, default 6): ").strip() or '6' + style_dict['figure.figsize'] = (float(width), float(height)) + + elif choice == '2': + base = input(" Base font size (default 12): ").strip() or '12' + style_dict['font.size'] = float(base) + style_dict['axes.labelsize'] = float(base) + 2 + style_dict['axes.titlesize'] = float(base) + 4 + + elif choice == '3': + lw = input(" Line width (default 2): ").strip() or '2' + style_dict['lines.linewidth'] = float(lw) + + elif choice == '4': + grid = input(" Enable grid? 
(y/n): ").strip().lower() + style_dict['axes.grid'] = grid == 'y' + if style_dict['axes.grid']: + alpha = input(" Grid transparency (0-1, default 0.3): ").strip() or '0.3' + style_dict['grid.alpha'] = float(alpha) + + elif choice == '5': + print(" Theme options: 1=Light, 2=Dark") + theme = input(" Select theme (1-2): ").strip() + if theme == '2': + style_dict.update(STYLE_PRESETS['dark']) + + elif choice == '6': + break + + return style_dict + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser( + description='Matplotlib style configurator', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Show available presets + python style_configurator.py --list + + # Preview a preset + python style_configurator.py --preset publication --preview + + # Save a preset as .mplstyle file + python style_configurator.py --preset publication --output my_style.mplstyle + + # Interactive mode + python style_configurator.py --interactive + """ + ) + parser.add_argument('--preset', type=str, choices=list(STYLE_PRESETS.keys()), + help='Use a predefined style preset') + parser.add_argument('--output', type=str, + help='Save style to .mplstyle file') + parser.add_argument('--preview', action='store_true', + help='Show style preview') + parser.add_argument('--list', action='store_true', + help='List available presets') + parser.add_argument('--interactive', action='store_true', + help='Run in interactive mode') + + args = parser.parse_args() + + if args.list: + list_available_presets() + # Also show currently available matplotlib styles + print("\nBuilt-in matplotlib styles:") + print("-" * 40) + for style in sorted(plt.style.available): + print(f" {style}") + return + + if args.interactive: + style_dict = interactive_mode() + elif args.preset: + style_dict = STYLE_PRESETS[args.preset].copy() + print(f"Using '{args.preset}' preset") + else: + print("No preset or interactive mode specified. Showing default preview.") + style_dict = {} + + if style_dict: + print_style_info(style_dict) + + if args.output: + save_style_file(style_dict, args.output) + + if args.preview or args.interactive: + print("Creating style preview...") + fig = create_style_preview(style_dict if style_dict else None) + + if args.output: + preview_filename = args.output.replace('.mplstyle', '_preview.png') + plt.savefig(preview_filename, dpi=150, bbox_inches='tight') + print(f"Preview saved to {preview_filename}") + + plt.show() + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/medchem/SKILL.md b/scientific-packages/medchem/SKILL.md new file mode 100644 index 0000000..1f76c24 --- /dev/null +++ b/scientific-packages/medchem/SKILL.md @@ -0,0 +1,398 @@ +--- +name: medchem +description: Python library for molecular filtering and prioritization in drug discovery. Use when applying medicinal chemistry rules (Rule of Five, CNS, leadlike), detecting structural alerts (PAINS, NIBR, Lilly demerits), analyzing chemical groups, calculating molecular complexity, or filtering compound libraries. Works with SMILES strings and RDKit mol objects, with built-in parallelization for large datasets. +--- + +# Medchem + +## Overview + +Medchem is a Python library for molecular filtering and prioritization in drug discovery workflows. It provides hundreds of well-established and novel molecular filters, structural alerts, and medicinal chemistry rules to efficiently triage and prioritize compound libraries at scale. + +**Key Principle:** Rules and filters are always context-specific. 
Avoid blindly applying filters—marketed drugs often don't pass standard medchem filters, and prodrugs may intentionally violate rules. Use these tools as guidelines combined with domain expertise. + +## Installation + +Install medchem via conda or pip: + +```bash +# Via conda +micromamba install -c conda-forge medchem + +# Via pip +pip install medchem +``` + +## Core Capabilities + +### 1. Medicinal Chemistry Rules + +Apply established drug-likeness rules to molecules using the `medchem.rules` module. + +**Available Rules:** +- Rule of Five (Lipinski) +- Rule of Oprea +- Rule of CNS +- Rule of leadlike (soft and strict) +- Rule of three +- Rule of Reos +- Rule of drug +- Rule of Veber +- Golden triangle +- PAINS filters + +**Single Rule Application:** + +```python +import medchem as mc + +# Apply Rule of Five to a SMILES string +smiles = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin +passes = mc.rules.basic_rules.rule_of_five(smiles) +# Returns: True + +# Check specific rules +passes_oprea = mc.rules.basic_rules.rule_of_oprea(smiles) +passes_cns = mc.rules.basic_rules.rule_of_cns(smiles) +``` + +**Multiple Rules with RuleFilters:** + +```python +import datamol as dm +import medchem as mc + +# Load molecules +mols = [dm.to_mol(smiles) for smiles in smiles_list] + +# Create filter with multiple rules +rfilter = mc.rules.RuleFilters( + rule_list=[ + "rule_of_five", + "rule_of_oprea", + "rule_of_cns", + "rule_of_leadlike_soft" + ] +) + +# Apply filters with parallelization +results = rfilter( + mols=mols, + n_jobs=-1, # Use all CPU cores + progress=True +) +``` + +**Result Format:** +Results are returned as dictionaries with pass/fail status and detailed information for each rule. + +### 2. Structural Alert Filters + +Detect potentially problematic structural patterns using the `medchem.structural` module. + +**Available Filters:** + +1. **Common Alerts** - General structural alerts derived from ChEMBL curation and literature +2. **NIBR Filters** - Novartis Institutes for BioMedical Research filter set +3. **Lilly Demerits** - Eli Lilly's demerit-based system (275 rules, molecules rejected at >100 demerits) + +**Common Alerts:** + +```python +import medchem as mc + +# Create filter +alert_filter = mc.structural.CommonAlertsFilters() + +# Check single molecule +mol = dm.to_mol("c1ccccc1") +has_alerts, details = alert_filter.check_mol(mol) + +# Batch filtering with parallelization +results = alert_filter( + mols=mol_list, + n_jobs=-1, + progress=True +) +``` + +**NIBR Filters:** + +```python +import medchem as mc + +# Apply NIBR filters +nibr_filter = mc.structural.NIBRFilters() +results = nibr_filter(mols=mol_list, n_jobs=-1) +``` + +**Lilly Demerits:** + +```python +import medchem as mc + +# Calculate Lilly demerits +lilly = mc.structural.LillyDemeritsFilters() +results = lilly(mols=mol_list, n_jobs=-1) + +# Each result includes demerit score and whether it passes (≤100 demerits) +``` + +### 3. Functional API for High-Level Operations + +The `medchem.functional` module provides convenient functions for common workflows. + +**Quick Filtering:** + +```python +import medchem as mc + +# Apply NIBR filters to a list +filter_ok = mc.functional.nibr_filter( + mols=mol_list, + n_jobs=-1 +) + +# Apply common alerts +alert_results = mc.functional.common_alerts_filter( + mols=mol_list, + n_jobs=-1 +) +``` + +### 4. Chemical Groups Detection + +Identify specific chemical groups and functional groups using `medchem.groups`. 
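+
+As a minimal, hedged sketch (the example SMILES and the expected flags are illustrative, not library output), group detection can be used to flag covalent warheads in a small library; the available group names and the full usage pattern follow below:
+
+```python
+import datamol as dm
+import medchem as mc
+
+# Illustrative molecules: methyl vinyl ketone, ethanol, benzene
+smiles = ["CC(=O)C=C", "CCO", "c1ccccc1"]
+mols = [dm.to_mol(smi) for smi in smiles]
+
+# Flag molecules containing Michael acceptor motifs
+group = mc.groups.ChemicalGroup(groups=["michael_acceptors"])
+flags = group.has_match(mols)  # expected to be roughly [True, False, False]
+```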
+ +**Available Groups:** +- Hinge binders +- Phosphate binders +- Michael acceptors +- Reactive groups +- Custom SMARTS patterns + +**Usage:** + +```python +import medchem as mc + +# Create group detector +group = mc.groups.ChemicalGroup(groups=["hinge_binders"]) + +# Check for matches +has_matches = group.has_match(mol_list) + +# Get detailed match information +matches = group.get_matches(mol) +``` + +### 5. Named Catalogs + +Access curated collections of chemical structures through `medchem.catalogs`. + +**Available Catalogs:** +- Functional groups +- Protecting groups +- Common reagents +- Standard fragments + +**Usage:** + +```python +import medchem as mc + +# Access named catalogs +catalogs = mc.catalogs.NamedCatalogs + +# Use catalog for matching +catalog = catalogs.get("functional_groups") +matches = catalog.get_matches(mol) +``` + +### 6. Molecular Complexity + +Calculate complexity metrics that approximate synthetic accessibility using `medchem.complexity`. + +**Common Metrics:** +- Bertz complexity +- Whitlock complexity +- Barone complexity + +**Usage:** + +```python +import medchem as mc + +# Calculate complexity +complexity_score = mc.complexity.calculate_complexity(mol) + +# Filter by complexity threshold +complex_filter = mc.complexity.ComplexityFilter(max_complexity=500) +results = complex_filter(mols=mol_list) +``` + +### 7. Constraints Filtering + +Apply custom property-based constraints using `medchem.constraints`. + +**Example Constraints:** +- Molecular weight ranges +- LogP bounds +- TPSA limits +- Rotatable bond counts + +**Usage:** + +```python +import medchem as mc + +# Define constraints +constraints = mc.constraints.Constraints( + mw_range=(200, 500), + logp_range=(-2, 5), + tpsa_max=140, + rotatable_bonds_max=10 +) + +# Apply constraints +results = constraints(mols=mol_list, n_jobs=-1) +``` + +### 8. Medchem Query Language + +Use a specialized query language for complex filtering criteria. + +**Query Examples:** +``` +# Molecules passing Ro5 AND not having common alerts +"rule_of_five AND NOT common_alerts" + +# CNS-like molecules with low complexity +"rule_of_cns AND complexity < 400" + +# Leadlike molecules without Lilly demerits +"rule_of_leadlike AND lilly_demerits == 0" +``` + +**Usage:** + +```python +import medchem as mc + +# Parse and apply query +query = mc.query.parse("rule_of_five AND NOT common_alerts") +results = query.apply(mols=mol_list, n_jobs=-1) +``` + +## Workflow Patterns + +### Pattern 1: Initial Triage of Compound Library + +Filter a large compound collection to identify drug-like candidates. + +```python +import datamol as dm +import medchem as mc +import pandas as pd + +# Load compound library +df = pd.read_csv("compounds.csv") +mols = [dm.to_mol(smi) for smi in df["smiles"]] + +# Apply primary filters +rule_filter = mc.rules.RuleFilters(rule_list=["rule_of_five", "rule_of_veber"]) +rule_results = rule_filter(mols=mols, n_jobs=-1, progress=True) + +# Apply structural alerts +alert_filter = mc.structural.CommonAlertsFilters() +alert_results = alert_filter(mols=mols, n_jobs=-1, progress=True) + +# Combine results +df["passes_rules"] = rule_results["pass"] +df["has_alerts"] = alert_results["has_alerts"] +df["drug_like"] = df["passes_rules"] & ~df["has_alerts"] + +# Save filtered compounds +filtered_df = df[df["drug_like"]] +filtered_df.to_csv("filtered_compounds.csv", index=False) +``` + +### Pattern 2: Lead Optimization Filtering + +Apply stricter criteria during lead optimization. 
+ +```python +import medchem as mc + +# Create comprehensive filter +filters = { + "rules": mc.rules.RuleFilters(rule_list=["rule_of_leadlike_strict"]), + "alerts": mc.structural.NIBRFilters(), + "lilly": mc.structural.LillyDemeritsFilters(), + "complexity": mc.complexity.ComplexityFilter(max_complexity=400) +} + +# Apply all filters +results = {} +for name, filt in filters.items(): + results[name] = filt(mols=candidate_mols, n_jobs=-1) + +# Identify compounds passing all filters +passes_all = all(r["pass"] for r in results.values()) +``` + +### Pattern 3: Identify Specific Chemical Groups + +Find molecules containing specific functional groups or scaffolds. + +```python +import medchem as mc + +# Create group detector for multiple groups +group_detector = mc.groups.ChemicalGroup( + groups=["hinge_binders", "phosphate_binders"] +) + +# Screen library +matches = group_detector.get_all_matches(mol_list) + +# Filter molecules with desired groups +mol_with_groups = [mol for mol, match in zip(mol_list, matches) if match] +``` + +## Best Practices + +1. **Context Matters**: Don't blindly apply filters. Understand the biological target and chemical space. + +2. **Combine Multiple Filters**: Use rules, structural alerts, and domain knowledge together for better decisions. + +3. **Use Parallelization**: For large datasets (>1000 molecules), always use `n_jobs=-1` for parallel processing. + +4. **Iterative Refinement**: Start with broad filters (Ro5), then apply more specific criteria (CNS, leadlike) as needed. + +5. **Document Filtering Decisions**: Track which molecules were filtered out and why for reproducibility. + +6. **Validate Results**: Remember that marketed drugs often fail standard filters—use these as guidelines, not absolute rules. + +7. **Consider Prodrugs**: Molecules designed as prodrugs may intentionally violate standard medicinal chemistry rules. + +## Resources + +### references/api_guide.md +Comprehensive API reference covering all medchem modules with detailed function signatures, parameters, and return types. + +### references/rules_catalog.md +Complete catalog of available rules, filters, and alerts with descriptions, thresholds, and literature references. + +### scripts/filter_molecules.py +Production-ready script for batch filtering workflows. Supports multiple input formats (CSV, SDF, SMILES), configurable filter combinations, and detailed reporting. + +**Usage:** +```bash +python scripts/filter_molecules.py input.csv --rules rule_of_five,rule_of_cns --alerts nibr --output filtered.csv +``` + +## Documentation + +Official documentation: https://medchem-docs.datamol.io/ +GitHub repository: https://github.com/datamol-io/medchem diff --git a/scientific-packages/medchem/references/api_guide.md b/scientific-packages/medchem/references/api_guide.md new file mode 100644 index 0000000..b67214f --- /dev/null +++ b/scientific-packages/medchem/references/api_guide.md @@ -0,0 +1,600 @@ +# Medchem API Reference + +Comprehensive reference for all medchem modules and functions. + +## Module: medchem.rules + +### Class: RuleFilters + +Filter molecules based on multiple medicinal chemistry rules. + +**Constructor:** +```python +RuleFilters(rule_list: List[str]) +``` + +**Parameters:** +- `rule_list`: List of rule names to apply. See available rules below. 
+ +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1, progress: bool = False) -> Dict +``` +- `mols`: List of RDKit molecule objects +- `n_jobs`: Number of parallel jobs (-1 uses all cores) +- `progress`: Show progress bar +- **Returns**: Dictionary with results for each rule + +**Example:** +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_five", "rule_of_cns"]) +results = rfilter(mols=mol_list, n_jobs=-1, progress=True) +``` + +### Module: medchem.rules.basic_rules + +Individual rule functions that can be applied to single molecules. + +#### rule_of_five() + +```python +rule_of_five(mol: Union[str, Chem.Mol]) -> bool +``` + +Lipinski's Rule of Five for oral bioavailability. + +**Criteria:** +- Molecular weight ≤ 500 Da +- LogP ≤ 5 +- H-bond donors ≤ 5 +- H-bond acceptors ≤ 10 + +**Parameters:** +- `mol`: SMILES string or RDKit molecule object + +**Returns:** True if molecule passes all criteria + +#### rule_of_three() + +```python +rule_of_three(mol: Union[str, Chem.Mol]) -> bool +``` + +Rule of Three for fragment screening libraries. + +**Criteria:** +- Molecular weight ≤ 300 Da +- LogP ≤ 3 +- H-bond donors ≤ 3 +- H-bond acceptors ≤ 3 +- Rotatable bonds ≤ 3 +- Polar surface area ≤ 60 Ų + +#### rule_of_oprea() + +```python +rule_of_oprea(mol: Union[str, Chem.Mol]) -> bool +``` + +Oprea's lead-like criteria for hit-to-lead optimization. + +**Criteria:** +- Molecular weight: 200-350 Da +- LogP: -2 to 4 +- Rotatable bonds ≤ 7 +- Rings ≤ 4 + +#### rule_of_cns() + +```python +rule_of_cns(mol: Union[str, Chem.Mol]) -> bool +``` + +CNS drug-likeness rules. + +**Criteria:** +- Molecular weight ≤ 450 Da +- LogP: -1 to 5 +- H-bond donors ≤ 2 +- TPSA ≤ 90 Ų + +#### rule_of_leadlike_soft() + +```python +rule_of_leadlike_soft(mol: Union[str, Chem.Mol]) -> bool +``` + +Soft lead-like criteria (more permissive). + +**Criteria:** +- Molecular weight: 250-450 Da +- LogP: -3 to 4 +- Rotatable bonds ≤ 10 + +#### rule_of_leadlike_strict() + +```python +rule_of_leadlike_strict(mol: Union[str, Chem.Mol]) -> bool +``` + +Strict lead-like criteria (more restrictive). + +**Criteria:** +- Molecular weight: 200-350 Da +- LogP: -2 to 3.5 +- Rotatable bonds ≤ 7 +- Rings: 1-3 + +#### rule_of_veber() + +```python +rule_of_veber(mol: Union[str, Chem.Mol]) -> bool +``` + +Veber's rules for oral bioavailability. + +**Criteria:** +- Rotatable bonds ≤ 10 +- TPSA ≤ 140 Ų + +#### rule_of_reos() + +```python +rule_of_reos(mol: Union[str, Chem.Mol]) -> bool +``` + +Rapid Elimination Of Swill (REOS) filter. + +**Criteria:** +- Molecular weight: 200-500 Da +- LogP: -5 to 5 +- H-bond donors: 0-5 +- H-bond acceptors: 0-10 + +#### rule_of_drug() + +```python +rule_of_drug(mol: Union[str, Chem.Mol]) -> bool +``` + +Combined drug-likeness criteria. + +**Criteria:** +- Passes Rule of Five +- Passes Veber rules +- No PAINS substructures + +#### golden_triangle() + +```python +golden_triangle(mol: Union[str, Chem.Mol]) -> bool +``` + +Golden Triangle for drug-likeness balance. + +**Criteria:** +- 200 ≤ MW ≤ 50×LogP + 400 +- LogP: -2 to 5 + +#### pains_filter() + +```python +pains_filter(mol: Union[str, Chem.Mol]) -> bool +``` + +Pan Assay INterference compoundS (PAINS) filter. + +**Returns:** True if molecule does NOT contain PAINS substructures + +--- + +## Module: medchem.structural + +### Class: CommonAlertsFilters + +Filter for common structural alerts derived from ChEMBL and literature. 
+ +**Constructor:** +```python +CommonAlertsFilters() +``` + +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1, progress: bool = False) -> List[Dict] +``` + +Apply common alerts filter to a list of molecules. + +**Returns:** List of dictionaries with keys: +- `has_alerts`: Boolean indicating if molecule has alerts +- `alert_details`: List of matched alert patterns +- `num_alerts`: Number of alerts found + +```python +check_mol(mol: Chem.Mol) -> Tuple[bool, List[str]] +``` + +Check a single molecule for structural alerts. + +**Returns:** Tuple of (has_alerts, list_of_alert_names) + +### Class: NIBRFilters + +Novartis NIBR medicinal chemistry filters. + +**Constructor:** +```python +NIBRFilters() +``` + +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1, progress: bool = False) -> List[bool] +``` + +Apply NIBR filters to molecules. + +**Returns:** List of booleans (True if molecule passes) + +### Class: LillyDemeritsFilters + +Eli Lilly's demerit-based structural alert system (275 rules). + +**Constructor:** +```python +LillyDemeritsFilters() +``` + +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1, progress: bool = False) -> List[Dict] +``` + +Calculate Lilly demerits for molecules. + +**Returns:** List of dictionaries with keys: +- `demerits`: Total demerit score +- `passes`: Boolean (True if demerits ≤ 100) +- `matched_patterns`: List of matched patterns with scores + +--- + +## Module: medchem.functional + +High-level functional API for common operations. + +### nibr_filter() + +```python +nibr_filter(mols: List[Chem.Mol], n_jobs: int = 1) -> List[bool] +``` + +Apply NIBR filters using functional API. + +**Parameters:** +- `mols`: List of molecules +- `n_jobs`: Parallelization level + +**Returns:** List of pass/fail booleans + +### common_alerts_filter() + +```python +common_alerts_filter(mols: List[Chem.Mol], n_jobs: int = 1) -> List[Dict] +``` + +Apply common alerts filter using functional API. + +**Returns:** List of results dictionaries + +### lilly_demerits_filter() + +```python +lilly_demerits_filter(mols: List[Chem.Mol], n_jobs: int = 1) -> List[Dict] +``` + +Calculate Lilly demerits using functional API. + +--- + +## Module: medchem.groups + +### Class: ChemicalGroup + +Detect specific chemical groups in molecules. + +**Constructor:** +```python +ChemicalGroup(groups: List[str], custom_smarts: Optional[Dict[str, str]] = None) +``` + +**Parameters:** +- `groups`: List of predefined group names +- `custom_smarts`: Dictionary mapping custom group names to SMARTS patterns + +**Predefined Groups:** +- `"hinge_binders"`: Kinase hinge binding motifs +- `"phosphate_binders"`: Phosphate binding groups +- `"michael_acceptors"`: Michael acceptor electrophiles +- `"reactive_groups"`: General reactive functionalities + +**Methods:** + +```python +has_match(mols: List[Chem.Mol]) -> List[bool] +``` + +Check if molecules contain any of the specified groups. + +```python +get_matches(mol: Chem.Mol) -> Dict[str, List[Tuple]] +``` + +Get detailed match information for a single molecule. + +**Returns:** Dictionary mapping group names to lists of atom indices + +```python +get_all_matches(mols: List[Chem.Mol]) -> List[Dict] +``` + +Get match information for all molecules. + +**Example:** +```python +group = mc.groups.ChemicalGroup(groups=["hinge_binders", "phosphate_binders"]) +matches = group.get_all_matches(mol_list) +``` + +--- + +## Module: medchem.catalogs + +### Class: NamedCatalogs + +Access to curated chemical catalogs. 
+ +**Available Catalogs:** +- `"functional_groups"`: Common functional groups +- `"protecting_groups"`: Protecting group structures +- `"reagents"`: Common reagents +- `"fragments"`: Standard fragments + +**Usage:** +```python +catalog = mc.catalogs.NamedCatalogs.get("functional_groups") +matches = catalog.get_matches(mol) +``` + +--- + +## Module: medchem.complexity + +Calculate molecular complexity metrics. + +### calculate_complexity() + +```python +calculate_complexity(mol: Chem.Mol, method: str = "bertz") -> float +``` + +Calculate complexity score for a molecule. + +**Parameters:** +- `mol`: RDKit molecule +- `method`: Complexity metric ("bertz", "whitlock", "barone") + +**Returns:** Complexity score (higher = more complex) + +### Class: ComplexityFilter + +Filter molecules by complexity threshold. + +**Constructor:** +```python +ComplexityFilter(max_complexity: float, method: str = "bertz") +``` + +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1) -> List[bool] +``` + +Filter molecules exceeding complexity threshold. + +--- + +## Module: medchem.constraints + +### Class: Constraints + +Apply custom property-based constraints. + +**Constructor:** +```python +Constraints( + mw_range: Optional[Tuple[float, float]] = None, + logp_range: Optional[Tuple[float, float]] = None, + tpsa_max: Optional[float] = None, + tpsa_range: Optional[Tuple[float, float]] = None, + hbd_max: Optional[int] = None, + hba_max: Optional[int] = None, + rotatable_bonds_max: Optional[int] = None, + rings_range: Optional[Tuple[int, int]] = None, + aromatic_rings_max: Optional[int] = None, +) +``` + +**Parameters:** All parameters are optional. Specify only the constraints needed. + +**Methods:** + +```python +__call__(mols: List[Chem.Mol], n_jobs: int = 1) -> List[Dict] +``` + +Apply constraints to molecules. + +**Returns:** List of dictionaries with keys: +- `passes`: Boolean indicating if all constraints pass +- `violations`: List of constraint names that failed + +**Example:** +```python +constraints = mc.constraints.Constraints( + mw_range=(200, 500), + logp_range=(-2, 5), + tpsa_max=140 +) +results = constraints(mols=mol_list, n_jobs=-1) +``` + +--- + +## Module: medchem.query + +Query language for complex filtering. + +### parse() + +```python +parse(query: str) -> Query +``` + +Parse a medchem query string into a Query object. + +**Query Syntax:** +- Operators: `AND`, `OR`, `NOT` +- Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=` +- Properties: `complexity`, `lilly_demerits`, `mw`, `logp`, `tpsa` +- Rules: `rule_of_five`, `rule_of_cns`, etc. +- Filters: `common_alerts`, `nibr_filter`, `pains_filter` + +**Example Queries:** +```python +"rule_of_five AND NOT common_alerts" +"rule_of_cns AND complexity < 400" +"mw > 200 AND mw < 500 AND logp < 5" +"(rule_of_five OR rule_of_oprea) AND NOT pains_filter" +``` + +### Class: Query + +**Methods:** + +```python +apply(mols: List[Chem.Mol], n_jobs: int = 1) -> List[bool] +``` + +Apply parsed query to molecules. + +**Example:** +```python +query = mc.query.parse("rule_of_five AND NOT common_alerts") +results = query.apply(mols=mol_list, n_jobs=-1) +passing_mols = [mol for mol, passes in zip(mol_list, results) if passes] +``` + +--- + +## Module: medchem.utils + +Utility functions for working with molecules. + +### batch_process() + +```python +batch_process( + mols: List[Chem.Mol], + func: Callable, + n_jobs: int = 1, + progress: bool = False, + batch_size: Optional[int] = None +) -> List +``` + +Process molecules in parallel batches. 
+ +**Parameters:** +- `mols`: List of molecules +- `func`: Function to apply to each molecule +- `n_jobs`: Number of parallel workers +- `progress`: Show progress bar +- `batch_size`: Size of processing batches + +### standardize_mol() + +```python +standardize_mol(mol: Chem.Mol) -> Chem.Mol +``` + +Standardize molecule representation (sanitize, neutralize charges, etc.). + +--- + +## Common Patterns + +### Pattern: Parallel Processing + +All filters support parallelization: + +```python +# Use all CPU cores +results = filter_object(mols=mol_list, n_jobs=-1, progress=True) + +# Use specific number of cores +results = filter_object(mols=mol_list, n_jobs=4, progress=True) +``` + +### Pattern: Combining Multiple Filters + +```python +import medchem as mc + +# Apply multiple filters +rule_filter = mc.rules.RuleFilters(rule_list=["rule_of_five"]) +alert_filter = mc.structural.CommonAlertsFilters() +lilly_filter = mc.structural.LillyDemeritsFilters() + +# Get results +rule_results = rule_filter(mols=mol_list, n_jobs=-1) +alert_results = alert_filter(mols=mol_list, n_jobs=-1) +lilly_results = lilly_filter(mols=mol_list, n_jobs=-1) + +# Combine criteria +passing_mols = [ + mol for i, mol in enumerate(mol_list) + if rule_results[i]["passes"] + and not alert_results[i]["has_alerts"] + and lilly_results[i]["passes"] +] +``` + +### Pattern: Working with DataFrames + +```python +import pandas as pd +import datamol as dm +import medchem as mc + +# Load data +df = pd.read_csv("molecules.csv") +df["mol"] = df["smiles"].apply(dm.to_mol) + +# Apply filters +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_five", "rule_of_cns"]) +results = rfilter(mols=df["mol"].tolist(), n_jobs=-1) + +# Add results to dataframe +df["passes_ro5"] = [r["rule_of_five"] for r in results] +df["passes_cns"] = [r["rule_of_cns"] for r in results] + +# Filter dataframe +filtered_df = df[df["passes_ro5"] & df["passes_cns"]] +``` diff --git a/scientific-packages/medchem/references/rules_catalog.md b/scientific-packages/medchem/references/rules_catalog.md new file mode 100644 index 0000000..a7754c6 --- /dev/null +++ b/scientific-packages/medchem/references/rules_catalog.md @@ -0,0 +1,604 @@ +# Medchem Rules and Filters Catalog + +Comprehensive catalog of all available medicinal chemistry rules, structural alerts, and filters in medchem. + +## Table of Contents + +1. [Drug-Likeness Rules](#drug-likeness-rules) +2. [Lead-Likeness Rules](#lead-likeness-rules) +3. [Fragment Rules](#fragment-rules) +4. [CNS Rules](#cns-rules) +5. [Structural Alert Filters](#structural-alert-filters) +6. 
[Chemical Group Patterns](#chemical-group-patterns) + +--- + +## Drug-Likeness Rules + +### Rule of Five (Lipinski) + +**Reference:** Lipinski et al., Adv Drug Deliv Rev (1997) 23:3-25 + +**Purpose:** Predict oral bioavailability + +**Criteria:** +- Molecular Weight ≤ 500 Da +- LogP ≤ 5 +- Hydrogen Bond Donors ≤ 5 +- Hydrogen Bond Acceptors ≤ 10 + +**Usage:** +```python +mc.rules.basic_rules.rule_of_five(mol) +``` + +**Notes:** +- One of the most widely used filters in drug discovery +- About 90% of orally active drugs comply with these rules +- Exceptions exist, especially for natural products and antibiotics + +--- + +### Rule of Veber + +**Reference:** Veber et al., J Med Chem (2002) 45:2615-2623 + +**Purpose:** Additional criteria for oral bioavailability + +**Criteria:** +- Rotatable Bonds ≤ 10 +- Topological Polar Surface Area (TPSA) ≤ 140 Ų + +**Usage:** +```python +mc.rules.basic_rules.rule_of_veber(mol) +``` + +**Notes:** +- Complements Rule of Five +- TPSA correlates with cell permeability +- Rotatable bonds affect molecular flexibility + +--- + +### Rule of Drug + +**Purpose:** Combined drug-likeness assessment + +**Criteria:** +- Passes Rule of Five +- Passes Veber rules +- Does not contain PAINS substructures + +**Usage:** +```python +mc.rules.basic_rules.rule_of_drug(mol) +``` + +--- + +### REOS (Rapid Elimination Of Swill) + +**Reference:** Walters & Murcko, Adv Drug Deliv Rev (2002) 54:255-271 + +**Purpose:** Filter out compounds unlikely to be drugs + +**Criteria:** +- Molecular Weight: 200-500 Da +- LogP: -5 to 5 +- Hydrogen Bond Donors: 0-5 +- Hydrogen Bond Acceptors: 0-10 + +**Usage:** +```python +mc.rules.basic_rules.rule_of_reos(mol) +``` + +--- + +### Golden Triangle + +**Reference:** Johnson et al., J Med Chem (2009) 52:5487-5500 + +**Purpose:** Balance lipophilicity and molecular weight + +**Criteria:** +- 200 ≤ MW ≤ 50 × LogP + 400 +- LogP: -2 to 5 + +**Usage:** +```python +mc.rules.basic_rules.golden_triangle(mol) +``` + +**Notes:** +- Defines optimal physicochemical space +- Visual representation resembles a triangle on MW vs LogP plot + +--- + +## Lead-Likeness Rules + +### Rule of Oprea + +**Reference:** Oprea et al., J Chem Inf Comput Sci (2001) 41:1308-1315 + +**Purpose:** Identify lead-like compounds for optimization + +**Criteria:** +- Molecular Weight: 200-350 Da +- LogP: -2 to 4 +- Rotatable Bonds ≤ 7 +- Number of Rings ≤ 4 + +**Usage:** +```python +mc.rules.basic_rules.rule_of_oprea(mol) +``` + +**Rationale:** Lead compounds should have "room to grow" during optimization + +--- + +### Rule of Leadlike (Soft) + +**Purpose:** Permissive lead-like criteria + +**Criteria:** +- Molecular Weight: 250-450 Da +- LogP: -3 to 4 +- Rotatable Bonds ≤ 10 + +**Usage:** +```python +mc.rules.basic_rules.rule_of_leadlike_soft(mol) +``` + +--- + +### Rule of Leadlike (Strict) + +**Purpose:** Restrictive lead-like criteria + +**Criteria:** +- Molecular Weight: 200-350 Da +- LogP: -2 to 3.5 +- Rotatable Bonds ≤ 7 +- Number of Rings: 1-3 + +**Usage:** +```python +mc.rules.basic_rules.rule_of_leadlike_strict(mol) +``` + +--- + +## Fragment Rules + +### Rule of Three + +**Reference:** Congreve et al., Drug Discov Today (2003) 8:876-877 + +**Purpose:** Screen fragment libraries for fragment-based drug discovery + +**Criteria:** +- Molecular Weight ≤ 300 Da +- LogP ≤ 3 +- Hydrogen Bond Donors ≤ 3 +- Hydrogen Bond Acceptors ≤ 3 +- Rotatable Bonds ≤ 3 +- Polar Surface Area ≤ 60 Ų + +**Usage:** +```python +mc.rules.basic_rules.rule_of_three(mol) +``` + +**Notes:** +- Fragments are 
grown into leads during optimization +- Lower complexity allows more starting points + +--- + +## CNS Rules + +### Rule of CNS + +**Purpose:** Central nervous system drug-likeness + +**Criteria:** +- Molecular Weight ≤ 450 Da +- LogP: -1 to 5 +- Hydrogen Bond Donors ≤ 2 +- TPSA ≤ 90 Ų + +**Usage:** +```python +mc.rules.basic_rules.rule_of_cns(mol) +``` + +**Rationale:** +- Blood-brain barrier penetration requires specific properties +- Lower TPSA and HBD count improve BBB permeability +- Tight constraints reflect CNS challenges + +--- + +## Structural Alert Filters + +### PAINS (Pan Assay INterference compoundS) + +**Reference:** Baell & Holloway, J Med Chem (2010) 53:2719-2740 + +**Purpose:** Identify compounds that interfere with assays + +**Categories:** +- Catechols +- Quinones +- Rhodanines +- Hydroxyphenylhydrazones +- Alkyl/aryl aldehydes +- Michael acceptors (specific patterns) + +**Usage:** +```python +mc.rules.basic_rules.pains_filter(mol) +# Returns True if NO PAINS found +``` + +**Notes:** +- PAINS compounds show activity in multiple assays through non-specific mechanisms +- Common false positives in screening campaigns +- Should be deprioritized in lead selection + +--- + +### Common Alerts Filters + +**Source:** Derived from ChEMBL curation and medicinal chemistry literature + +**Purpose:** Flag common problematic structural patterns + +**Alert Categories:** +1. **Reactive Groups** + - Epoxides + - Aziridines + - Acid halides + - Isocyanates + +2. **Metabolic Liabilities** + - Hydrazines + - Thioureas + - Anilines (certain patterns) + +3. **Aggregators** + - Polyaromatic systems + - Long aliphatic chains + +4. **Toxicophores** + - Nitro aromatics + - Aromatic N-oxides + - Certain heterocycles + +**Usage:** +```python +alert_filter = mc.structural.CommonAlertsFilters() +has_alerts, details = alert_filter.check_mol(mol) +``` + +**Return Format:** +```python +{ + "has_alerts": True, + "alert_details": ["reactive_epoxide", "metabolic_hydrazine"], + "num_alerts": 2 +} +``` + +--- + +### NIBR Filters + +**Source:** Novartis Institutes for BioMedical Research + +**Purpose:** Industrial medicinal chemistry filtering rules + +**Features:** +- Proprietary filter set developed from Novartis experience +- Balances drug-likeness with practical medicinal chemistry +- Includes both structural alerts and property filters + +**Usage:** +```python +nibr_filter = mc.structural.NIBRFilters() +results = nibr_filter(mols=mol_list, n_jobs=-1) +``` + +**Return Format:** Boolean list (True = passes) + +--- + +### Lilly Demerits Filter + +**Reference:** Based on Eli Lilly medicinal chemistry rules + +**Source:** 275 structural patterns accumulated over 18 years + +**Purpose:** Identify assay interference and problematic functionalities + +**Mechanism:** +- Each matched pattern adds demerits +- Molecules with >100 demerits are rejected +- Some patterns add 10-50 demerits, others add 100+ (instant rejection) + +**Demerit Categories:** + +1. **High Demerits (>50):** + - Known toxic groups + - Highly reactive functionalities + - Strong metal chelators + +2. **Medium Demerits (20-50):** + - Metabolic liabilities + - Aggregation-prone structures + - Frequent hitters + +3. 
**Low Demerits (5-20):** + - Minor concerns + - Context-dependent issues + +**Usage:** +```python +lilly_filter = mc.structural.LillyDemeritsFilters() +results = lilly_filter(mols=mol_list, n_jobs=-1) +``` + +**Return Format:** +```python +{ + "demerits": 35, + "passes": True, # (demerits ≤ 100) + "matched_patterns": [ + {"pattern": "phenolic_ester", "demerits": 20}, + {"pattern": "aniline_derivative", "demerits": 15} + ] +} +``` + +--- + +## Chemical Group Patterns + +### Hinge Binders + +**Purpose:** Identify kinase hinge-binding motifs + +**Common Patterns:** +- Aminopyridines +- Aminopyrimidines +- Indazoles +- Benzimidazoles + +**Usage:** +```python +group = mc.groups.ChemicalGroup(groups=["hinge_binders"]) +has_hinge = group.has_match(mol_list) +``` + +**Application:** Kinase inhibitor design + +--- + +### Phosphate Binders + +**Purpose:** Identify phosphate-binding groups + +**Common Patterns:** +- Basic amines in specific geometries +- Guanidinium groups +- Arginine mimetics + +**Usage:** +```python +group = mc.groups.ChemicalGroup(groups=["phosphate_binders"]) +``` + +**Application:** Kinase inhibitors, phosphatase inhibitors + +--- + +### Michael Acceptors + +**Purpose:** Identify electrophilic Michael acceptor groups + +**Common Patterns:** +- α,β-Unsaturated carbonyls +- α,β-Unsaturated nitriles +- Vinyl sulfones +- Acrylamides + +**Usage:** +```python +group = mc.groups.ChemicalGroup(groups=["michael_acceptors"]) +``` + +**Notes:** +- Can be desirable for covalent inhibitors +- Often flagged as reactive alerts in screening + +--- + +### Reactive Groups + +**Purpose:** Identify generally reactive functionalities + +**Common Patterns:** +- Epoxides +- Aziridines +- Acyl halides +- Isocyanates +- Sulfonyl chlorides + +**Usage:** +```python +group = mc.groups.ChemicalGroup(groups=["reactive_groups"]) +``` + +--- + +## Custom SMARTS Patterns + +Define custom structural patterns using SMARTS: + +```python +custom_patterns = { + "my_warhead": "[C;H0](=O)C(F)(F)F", # Trifluoromethyl ketone + "my_scaffold": "c1ccc2c(c1)ncc(n2)N", # Aminobenzimidazole +} + +group = mc.groups.ChemicalGroup( + groups=["hinge_binders"], + custom_smarts=custom_patterns +) +``` + +--- + +## Filter Selection Guidelines + +### Initial Screening (High-Throughput) + +Recommended filters: +- Rule of Five +- PAINS filter +- Common Alerts (permissive settings) + +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_five", "pains_filter"]) +alert_filter = mc.structural.CommonAlertsFilters() +``` + +--- + +### Hit-to-Lead + +Recommended filters: +- Rule of Oprea or Leadlike (soft) +- NIBR filters +- Lilly Demerits + +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_oprea"]) +nibr_filter = mc.structural.NIBRFilters() +lilly_filter = mc.structural.LillyDemeritsFilters() +``` + +--- + +### Lead Optimization + +Recommended filters: +- Rule of Drug +- Leadlike (strict) +- Full structural alert analysis +- Complexity filters + +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_drug", "rule_of_leadlike_strict"]) +alert_filter = mc.structural.CommonAlertsFilters() +complexity_filter = mc.complexity.ComplexityFilter(max_complexity=400) +``` + +--- + +### CNS Targets + +Recommended filters: +- Rule of CNS +- Reduced PAINS criteria (CNS-focused) +- BBB permeability constraints + +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_cns"]) +constraints = mc.constraints.Constraints( + tpsa_max=90, + hbd_max=2, + mw_range=(300, 450) +) +``` + +--- + +### Fragment-Based Drug Discovery + 
+Recommended filters: +- Rule of Three +- Minimal complexity +- Basic reactive group check + +```python +rfilter = mc.rules.RuleFilters(rule_list=["rule_of_three"]) +complexity_filter = mc.complexity.ComplexityFilter(max_complexity=250) +``` + +--- + +## Important Considerations + +### False Positives and False Negatives + +**Filters are guidelines, not absolutes:** + +1. **False Positives** (good drugs flagged): + - ~10% of marketed drugs fail Rule of Five + - Natural products often violate standard rules + - Prodrugs intentionally break rules + - Antibiotics and antivirals frequently non-compliant + +2. **False Negatives** (bad compounds passing): + - Passing filters doesn't guarantee success + - Target-specific issues not captured + - In vivo properties not fully predicted + +### Context-Specific Application + +**Different contexts require different criteria:** + +- **Target Class:** Kinases vs GPCRs vs ion channels have different optimal spaces +- **Modality:** Small molecules vs PROTACs vs molecular glues +- **Administration Route:** Oral vs IV vs topical +- **Disease Area:** CNS vs oncology vs infectious disease +- **Stage:** Screening vs hit-to-lead vs lead optimization + +### Complementing with Machine Learning + +Modern approaches combine rules with ML: + +```python +# Rule-based pre-filtering +rule_results = mc.rules.RuleFilters(rule_list=["rule_of_five"])(mols) +filtered_mols = [mol for mol, r in zip(mols, rule_results) if r["passes"]] + +# ML model scoring on filtered set +ml_scores = ml_model.predict(filtered_mols) + +# Combined decision +final_candidates = [ + mol for mol, score in zip(filtered_mols, ml_scores) + if score > threshold +] +``` + +--- + +## References + +1. Lipinski CA et al. Adv Drug Deliv Rev (1997) 23:3-25 +2. Veber DF et al. J Med Chem (2002) 45:2615-2623 +3. Oprea TI et al. J Chem Inf Comput Sci (2001) 41:1308-1315 +4. Congreve M et al. Drug Discov Today (2003) 8:876-877 +5. Baell JB & Holloway GA. J Med Chem (2010) 53:2719-2740 +6. Johnson TW et al. J Med Chem (2009) 52:5487-5500 +7. Walters WP & Murcko MA. Adv Drug Deliv Rev (2002) 54:255-271 +8. Hann MM & Oprea TI. Curr Opin Chem Biol (2004) 8:255-263 +9. Rishton GM. Drug Discov Today (1997) 2:382-384 diff --git a/scientific-packages/medchem/scripts/filter_molecules.py b/scientific-packages/medchem/scripts/filter_molecules.py new file mode 100644 index 0000000..e9423cd --- /dev/null +++ b/scientific-packages/medchem/scripts/filter_molecules.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 +""" +Batch molecular filtering using medchem library. + +This script provides a production-ready workflow for filtering compound libraries +using medchem rules, structural alerts, and custom constraints. 
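+Filter results are appended as new columns alongside the original input data, and a text summary report is written next to the output file unless --no-summary is passed.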
+ +Usage: + python filter_molecules.py input.csv --rules rule_of_five,rule_of_cns --alerts nibr --output filtered.csv + python filter_molecules.py input.sdf --rules rule_of_drug --lilly --complexity 400 --output results.csv + python filter_molecules.py smiles.txt --nibr --pains --n-jobs -1 --output clean.csv +""" + +import argparse +import sys +from pathlib import Path +from typing import List, Dict, Optional, Tuple +import json + +try: + import pandas as pd + import datamol as dm + import medchem as mc + from rdkit import Chem + from tqdm import tqdm +except ImportError as e: + print(f"Error: Missing required package: {e}") + print("Install dependencies: pip install medchem datamol pandas tqdm") + sys.exit(1) + + +def load_molecules(input_file: Path, smiles_column: str = "smiles") -> Tuple[pd.DataFrame, List[Chem.Mol]]: + """ + Load molecules from various file formats. + + Supports: + - CSV/TSV with SMILES column + - SDF files + - Plain text files with one SMILES per line + + Returns: + Tuple of (DataFrame with metadata, list of RDKit molecules) + """ + suffix = input_file.suffix.lower() + + if suffix == ".sdf": + print(f"Loading SDF file: {input_file}") + supplier = Chem.SDMolSupplier(str(input_file)) + mols = [mol for mol in supplier if mol is not None] + + # Create DataFrame from SDF properties + data = [] + for mol in mols: + props = mol.GetPropsAsDict() + props["smiles"] = Chem.MolToSmiles(mol) + data.append(props) + df = pd.DataFrame(data) + + elif suffix in [".csv", ".tsv"]: + print(f"Loading CSV/TSV file: {input_file}") + sep = "\t" if suffix == ".tsv" else "," + df = pd.read_csv(input_file, sep=sep) + + if smiles_column not in df.columns: + print(f"Error: Column '{smiles_column}' not found in file") + print(f"Available columns: {', '.join(df.columns)}") + sys.exit(1) + + print(f"Converting SMILES to molecules...") + mols = [dm.to_mol(smi) for smi in tqdm(df[smiles_column], desc="Parsing")] + + elif suffix == ".txt": + print(f"Loading text file: {input_file}") + with open(input_file) as f: + smiles_list = [line.strip() for line in f if line.strip()] + + df = pd.DataFrame({"smiles": smiles_list}) + print(f"Converting SMILES to molecules...") + mols = [dm.to_mol(smi) for smi in tqdm(smiles_list, desc="Parsing")] + + else: + print(f"Error: Unsupported file format: {suffix}") + print("Supported formats: .csv, .tsv, .sdf, .txt") + sys.exit(1) + + # Filter out invalid molecules + valid_indices = [i for i, mol in enumerate(mols) if mol is not None] + if len(valid_indices) < len(mols): + n_invalid = len(mols) - len(valid_indices) + print(f"Warning: {n_invalid} invalid molecules removed") + df = df.iloc[valid_indices].reset_index(drop=True) + mols = [mols[i] for i in valid_indices] + + print(f"Loaded {len(mols)} valid molecules") + return df, mols + + +def apply_rule_filters(mols: List[Chem.Mol], rules: List[str], n_jobs: int) -> pd.DataFrame: + """Apply medicinal chemistry rule filters.""" + print(f"\nApplying rule filters: {', '.join(rules)}") + + rfilter = mc.rules.RuleFilters(rule_list=rules) + results = rfilter(mols=mols, n_jobs=n_jobs, progress=True) + + # Convert to DataFrame + df_results = pd.DataFrame(results) + + # Add summary column + df_results["passes_all_rules"] = df_results.all(axis=1) + + return df_results + + +def apply_structural_alerts(mols: List[Chem.Mol], alert_type: str, n_jobs: int) -> pd.DataFrame: + """Apply structural alert filters.""" + print(f"\nApplying {alert_type} structural alerts...") + + if alert_type == "common": + alert_filter = 
mc.structural.CommonAlertsFilters() + results = alert_filter(mols=mols, n_jobs=n_jobs, progress=True) + + df_results = pd.DataFrame({ + "has_common_alerts": [r["has_alerts"] for r in results], + "num_common_alerts": [r["num_alerts"] for r in results], + "common_alert_details": [", ".join(r["alert_details"]) if r["alert_details"] else "" for r in results] + }) + + elif alert_type == "nibr": + nibr_filter = mc.structural.NIBRFilters() + results = nibr_filter(mols=mols, n_jobs=n_jobs, progress=True) + + df_results = pd.DataFrame({ + "passes_nibr": results + }) + + elif alert_type == "lilly": + lilly_filter = mc.structural.LillyDemeritsFilters() + results = lilly_filter(mols=mols, n_jobs=n_jobs, progress=True) + + df_results = pd.DataFrame({ + "lilly_demerits": [r["demerits"] for r in results], + "passes_lilly": [r["passes"] for r in results], + "lilly_patterns": [", ".join([p["pattern"] for p in r["matched_patterns"]]) for r in results] + }) + + elif alert_type == "pains": + results = [mc.rules.basic_rules.pains_filter(mol) for mol in tqdm(mols, desc="PAINS")] + + df_results = pd.DataFrame({ + "passes_pains": results + }) + + else: + raise ValueError(f"Unknown alert type: {alert_type}") + + return df_results + + +def apply_complexity_filter(mols: List[Chem.Mol], max_complexity: float, method: str = "bertz") -> pd.DataFrame: + """Calculate molecular complexity.""" + print(f"\nCalculating molecular complexity (method={method}, max={max_complexity})...") + + complexity_scores = [ + mc.complexity.calculate_complexity(mol, method=method) + for mol in tqdm(mols, desc="Complexity") + ] + + df_results = pd.DataFrame({ + "complexity_score": complexity_scores, + "passes_complexity": [score <= max_complexity for score in complexity_scores] + }) + + return df_results + + +def apply_constraints(mols: List[Chem.Mol], constraints: Dict, n_jobs: int) -> pd.DataFrame: + """Apply custom property constraints.""" + print(f"\nApplying constraints: {constraints}") + + constraint_filter = mc.constraints.Constraints(**constraints) + results = constraint_filter(mols=mols, n_jobs=n_jobs, progress=True) + + df_results = pd.DataFrame({ + "passes_constraints": [r["passes"] for r in results], + "constraint_violations": [", ".join(r["violations"]) if r["violations"] else "" for r in results] + }) + + return df_results + + +def apply_chemical_groups(mols: List[Chem.Mol], groups: List[str]) -> pd.DataFrame: + """Detect chemical groups.""" + print(f"\nDetecting chemical groups: {', '.join(groups)}") + + group_detector = mc.groups.ChemicalGroup(groups=groups) + results = group_detector.get_all_matches(mols) + + df_results = pd.DataFrame() + for group in groups: + df_results[f"has_{group}"] = [bool(r.get(group)) for r in results] + + return df_results + + +def generate_summary(df: pd.DataFrame, output_file: Path): + """Generate filtering summary report.""" + summary_file = output_file.parent / f"{output_file.stem}_summary.txt" + + with open(summary_file, "w") as f: + f.write("=" * 80 + "\n") + f.write("MEDCHEM FILTERING SUMMARY\n") + f.write("=" * 80 + "\n\n") + + f.write(f"Total molecules processed: {len(df)}\n\n") + + # Rule results + rule_cols = [col for col in df.columns if col.startswith("rule_") or col == "passes_all_rules"] + if rule_cols: + f.write("RULE FILTERS:\n") + f.write("-" * 40 + "\n") + for col in rule_cols: + if col in df.columns and df[col].dtype == bool: + n_pass = df[col].sum() + pct = 100 * n_pass / len(df) + f.write(f" {col}: {n_pass} passed ({pct:.1f}%)\n") + f.write("\n") + + # Structural alerts + 
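        # Report pass rates for whichever alert filters (common alerts, NIBR, Lilly, PAINS) were applied +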
alert_cols = [col for col in df.columns if "alert" in col.lower() or "nibr" in col.lower() or "lilly" in col.lower() or "pains" in col.lower()] + if alert_cols: + f.write("STRUCTURAL ALERTS:\n") + f.write("-" * 40 + "\n") + if "has_common_alerts" in df.columns: + n_clean = (~df["has_common_alerts"]).sum() + pct = 100 * n_clean / len(df) + f.write(f" No common alerts: {n_clean} ({pct:.1f}%)\n") + if "passes_nibr" in df.columns: + n_pass = df["passes_nibr"].sum() + pct = 100 * n_pass / len(df) + f.write(f" Passes NIBR: {n_pass} ({pct:.1f}%)\n") + if "passes_lilly" in df.columns: + n_pass = df["passes_lilly"].sum() + pct = 100 * n_pass / len(df) + f.write(f" Passes Lilly: {n_pass} ({pct:.1f}%)\n") + avg_demerits = df["lilly_demerits"].mean() + f.write(f" Average Lilly demerits: {avg_demerits:.1f}\n") + if "passes_pains" in df.columns: + n_pass = df["passes_pains"].sum() + pct = 100 * n_pass / len(df) + f.write(f" Passes PAINS: {n_pass} ({pct:.1f}%)\n") + f.write("\n") + + # Complexity + if "complexity_score" in df.columns: + f.write("COMPLEXITY:\n") + f.write("-" * 40 + "\n") + avg_complexity = df["complexity_score"].mean() + f.write(f" Average complexity: {avg_complexity:.1f}\n") + if "passes_complexity" in df.columns: + n_pass = df["passes_complexity"].sum() + pct = 100 * n_pass / len(df) + f.write(f" Within threshold: {n_pass} ({pct:.1f}%)\n") + f.write("\n") + + # Constraints + if "passes_constraints" in df.columns: + f.write("CONSTRAINTS:\n") + f.write("-" * 40 + "\n") + n_pass = df["passes_constraints"].sum() + pct = 100 * n_pass / len(df) + f.write(f" Passes all constraints: {n_pass} ({pct:.1f}%)\n") + f.write("\n") + + # Overall pass rate + pass_cols = [col for col in df.columns if col.startswith("passes_")] + if pass_cols: + df["passes_all_filters"] = df[pass_cols].all(axis=1) + n_pass = df["passes_all_filters"].sum() + pct = 100 * n_pass / len(df) + f.write("OVERALL:\n") + f.write("-" * 40 + "\n") + f.write(f" Molecules passing all filters: {n_pass} ({pct:.1f}%)\n") + + f.write("\n" + "=" * 80 + "\n") + + print(f"\nSummary report saved to: {summary_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Batch molecular filtering using medchem", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + + # Input/Output + parser.add_argument("input", type=Path, help="Input file (CSV, TSV, SDF, or TXT)") + parser.add_argument("--output", "-o", type=Path, required=True, help="Output CSV file") + parser.add_argument("--smiles-column", default="smiles", help="Name of SMILES column (default: smiles)") + + # Rule filters + parser.add_argument("--rules", help="Comma-separated list of rules (e.g., rule_of_five,rule_of_cns)") + + # Structural alerts + parser.add_argument("--common-alerts", action="store_true", help="Apply common structural alerts") + parser.add_argument("--nibr", action="store_true", help="Apply NIBR filters") + parser.add_argument("--lilly", action="store_true", help="Apply Lilly demerits filter") + parser.add_argument("--pains", action="store_true", help="Apply PAINS filter") + + # Complexity + parser.add_argument("--complexity", type=float, help="Maximum complexity threshold") + parser.add_argument("--complexity-method", default="bertz", choices=["bertz", "whitlock", "barone"], + help="Complexity calculation method") + + # Constraints + parser.add_argument("--mw-range", help="Molecular weight range (e.g., 200,500)") + parser.add_argument("--logp-range", help="LogP range (e.g., -2,5)") + parser.add_argument("--tpsa-max", type=float, 
help="Maximum TPSA") + parser.add_argument("--hbd-max", type=int, help="Maximum H-bond donors") + parser.add_argument("--hba-max", type=int, help="Maximum H-bond acceptors") + parser.add_argument("--rotatable-bonds-max", type=int, help="Maximum rotatable bonds") + + # Chemical groups + parser.add_argument("--groups", help="Comma-separated chemical groups to detect") + + # Processing options + parser.add_argument("--n-jobs", type=int, default=-1, help="Number of parallel jobs (-1 = all cores)") + parser.add_argument("--no-summary", action="store_true", help="Don't generate summary report") + parser.add_argument("--filter-output", action="store_true", help="Only output molecules passing all filters") + + args = parser.parse_args() + + # Load molecules + df, mols = load_molecules(args.input, args.smiles_column) + + # Apply filters + result_dfs = [df] + + # Rules + if args.rules: + rule_list = [r.strip() for r in args.rules.split(",")] + df_rules = apply_rule_filters(mols, rule_list, args.n_jobs) + result_dfs.append(df_rules) + + # Structural alerts + if args.common_alerts: + df_alerts = apply_structural_alerts(mols, "common", args.n_jobs) + result_dfs.append(df_alerts) + + if args.nibr: + df_nibr = apply_structural_alerts(mols, "nibr", args.n_jobs) + result_dfs.append(df_nibr) + + if args.lilly: + df_lilly = apply_structural_alerts(mols, "lilly", args.n_jobs) + result_dfs.append(df_lilly) + + if args.pains: + df_pains = apply_structural_alerts(mols, "pains", args.n_jobs) + result_dfs.append(df_pains) + + # Complexity + if args.complexity: + df_complexity = apply_complexity_filter(mols, args.complexity, args.complexity_method) + result_dfs.append(df_complexity) + + # Constraints + constraints = {} + if args.mw_range: + mw_min, mw_max = map(float, args.mw_range.split(",")) + constraints["mw_range"] = (mw_min, mw_max) + if args.logp_range: + logp_min, logp_max = map(float, args.logp_range.split(",")) + constraints["logp_range"] = (logp_min, logp_max) + if args.tpsa_max: + constraints["tpsa_max"] = args.tpsa_max + if args.hbd_max: + constraints["hbd_max"] = args.hbd_max + if args.hba_max: + constraints["hba_max"] = args.hba_max + if args.rotatable_bonds_max: + constraints["rotatable_bonds_max"] = args.rotatable_bonds_max + + if constraints: + df_constraints = apply_constraints(mols, constraints, args.n_jobs) + result_dfs.append(df_constraints) + + # Chemical groups + if args.groups: + group_list = [g.strip() for g in args.groups.split(",")] + df_groups = apply_chemical_groups(mols, group_list) + result_dfs.append(df_groups) + + # Combine results + df_final = pd.concat(result_dfs, axis=1) + + # Filter output if requested + if args.filter_output: + pass_cols = [col for col in df_final.columns if col.startswith("passes_")] + if pass_cols: + df_final["passes_all"] = df_final[pass_cols].all(axis=1) + df_final = df_final[df_final["passes_all"]] + print(f"\nFiltered to {len(df_final)} molecules passing all filters") + + # Save results + args.output.parent.mkdir(parents=True, exist_ok=True) + df_final.to_csv(args.output, index=False) + print(f"\nResults saved to: {args.output}") + + # Generate summary + if not args.no_summary: + generate_summary(df_final, args.output) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/molfeat/SKILL.md b/scientific-packages/molfeat/SKILL.md new file mode 100644 index 0000000..1ec60e6 --- /dev/null +++ b/scientific-packages/molfeat/SKILL.md @@ -0,0 +1,516 @@ +--- +name: molfeat +description: Comprehensive molecular featurization 
toolkit for converting chemical structures into numerical representations for machine learning. Use this skill when working with molecular data, SMILES strings, chemical fingerprints, molecular descriptors, or building QSAR/QSPR models. Provides access to 100+ featurizers including traditional fingerprints (ECFP, MACCS), molecular descriptors (RDKit, Mordred), and pretrained deep learning models (ChemBERTa, ChemGPT, GNN models) for cheminformatics and drug discovery tasks. +--- + +# Molfeat - Molecular Featurization Hub + +## Overview + +Molfeat is a comprehensive Python library for molecular featurization that unifies pre-trained embeddings and hand-crafted featurizers into a single, fast, and user-friendly package. Convert chemical structures (SMILES strings or RDKit molecules) into numerical representations suitable for machine learning tasks including QSAR modeling, virtual screening, similarity searching, and deep learning applications. + +**Key Capabilities:** +- 100+ featurizers including fingerprints, descriptors, and pretrained models +- Fast parallel processing with simple API +- Scikit-learn compatible transformers +- Built-in caching and state persistence +- Integration with PyTorch, TensorFlow, and graph neural networks + +## When to Use This Skill + +Apply molfeat when working with: +- **Molecular machine learning**: Building QSAR/QSPR models, property prediction +- **Virtual screening**: Ranking compound libraries for biological activity +- **Similarity searching**: Finding structurally similar molecules +- **Chemical space analysis**: Clustering, visualization, dimensionality reduction +- **Deep learning**: Training neural networks on molecular data +- **Featurization pipelines**: Converting SMILES to ML-ready representations +- **Cheminformatics**: Any task requiring molecular feature extraction + +## Installation + +```bash +# Recommended: Using conda/mamba +mamba install -c conda-forge molfeat + +# Alternative: Using pip +pip install molfeat + +# With all optional dependencies +pip install "molfeat[all]" +``` + +**Optional dependencies for specific featurizers:** +- `molfeat[dgl]` - GNN models (GIN variants) +- `molfeat[graphormer]` - Graphormer models +- `molfeat[transformer]` - ChemBERTa, ChemGPT, MolT5 +- `molfeat[fcd]` - FCD descriptors +- `molfeat[map4]` - MAP4 fingerprints + +## Core Concepts + +Molfeat organizes featurization into three hierarchical classes: + +### 1. Calculators (`molfeat.calc`) + +Callable objects that convert individual molecules into feature vectors. Accept RDKit `Chem.Mol` objects or SMILES strings. + +**Use calculators for:** +- Single molecule featurization +- Custom processing loops +- Direct feature computation + +**Example:** +```python +from molfeat.calc import FPCalculator + +calc = FPCalculator("ecfp", radius=3, fpSize=2048) +features = calc("CCO") # Returns numpy array (2048,) +``` + +### 2. Transformers (`molfeat.trans`) + +Scikit-learn compatible transformers that wrap calculators for batch processing with parallelization. + +**Use transformers for:** +- Batch featurization of molecular datasets +- Integration with scikit-learn pipelines +- Parallel processing (automatic CPU utilization) + +**Example:** +```python +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator + +transformer = MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1) +features = transformer(smiles_list) # Parallel processing +``` + +### 3. 
Pretrained Transformers (`molfeat.trans.pretrained`) + +Specialized transformers for deep learning models with batched inference and caching. + +**Use pretrained transformers for:** +- State-of-the-art molecular embeddings +- Transfer learning from large chemical datasets +- Deep learning feature extraction + +**Example:** +```python +from molfeat.trans.pretrained import PretrainedMolTransformer + +transformer = PretrainedMolTransformer("ChemBERTa-77M-MLM", n_jobs=-1) +embeddings = transformer(smiles_list) # Deep learning embeddings +``` + +## Quick Start Workflow + +### Basic Featurization + +```python +import datamol as dm +from molfeat.calc import FPCalculator +from molfeat.trans import MoleculeTransformer + +# Load molecular data +smiles = ["CCO", "CC(=O)O", "c1ccccc1", "CC(C)O"] + +# Create calculator and transformer +calc = FPCalculator("ecfp", radius=3) +transformer = MoleculeTransformer(calc, n_jobs=-1) + +# Featurize molecules +features = transformer(smiles) +print(f"Shape: {features.shape}") # (4, 2048) +``` + +### Save and Load Configuration + +```python +# Save featurizer configuration for reproducibility +transformer.to_state_yaml_file("featurizer_config.yml") + +# Reload exact configuration +loaded = MoleculeTransformer.from_state_yaml_file("featurizer_config.yml") +``` + +### Handle Errors Gracefully + +```python +# Process dataset with potentially invalid SMILES +transformer = MoleculeTransformer( + calc, + n_jobs=-1, + ignore_errors=True, # Continue on failures + verbose=True # Log error details +) + +features = transformer(smiles_with_errors) +# Returns None for failed molecules +``` + +## Choosing the Right Featurizer + +### For Traditional Machine Learning (RF, SVM, XGBoost) + +**Start with fingerprints:** +```python +# ECFP - Most popular, general-purpose +FPCalculator("ecfp", radius=3, fpSize=2048) + +# MACCS - Fast, good for scaffold hopping +FPCalculator("maccs") + +# MAP4 - Efficient for large-scale screening +FPCalculator("map4") +``` + +**For interpretable models:** +```python +# RDKit 2D descriptors (200+ named properties) +from molfeat.calc import RDKitDescriptors2D +RDKitDescriptors2D() + +# Mordred (1800+ comprehensive descriptors) +from molfeat.calc import MordredDescriptors +MordredDescriptors() +``` + +**Combine multiple featurizers:** +```python +from molfeat.trans import FeatConcat + +concat = FeatConcat([ + FPCalculator("maccs"), # 167 dimensions + FPCalculator("ecfp") # 2048 dimensions +]) # Result: 2215-dimensional combined features +``` + +### For Deep Learning + +**Transformer-based embeddings:** +```python +# ChemBERTa - Pre-trained on 77M PubChem compounds +PretrainedMolTransformer("ChemBERTa-77M-MLM") + +# ChemGPT - Autoregressive language model +PretrainedMolTransformer("ChemGPT-1.2B") +``` + +**Graph neural networks:** +```python +# GIN models with different pre-training objectives +PretrainedMolTransformer("gin-supervised-masking") +PretrainedMolTransformer("gin-supervised-infomax") + +# Graphormer for quantum chemistry +PretrainedMolTransformer("Graphormer-pcqm4mv2") +``` + +### For Similarity Searching + +```python +# ECFP - General purpose, most widely used +FPCalculator("ecfp") + +# MACCS - Fast, scaffold-based similarity +FPCalculator("maccs") + +# MAP4 - Efficient for large databases +FPCalculator("map4") + +# USR/USRCAT - 3D shape similarity +from molfeat.calc import USRDescriptors +USRDescriptors() +``` + +### For Pharmacophore-Based Approaches + +```python +# FCFP - Functional group based +FPCalculator("fcfp") + +# CATS - 
Pharmacophore pair distributions +from molfeat.calc import CATSCalculator +CATSCalculator(mode="2D") + +# Gobbi - Explicit pharmacophore features +FPCalculator("gobbi2D") +``` + +## Common Workflows + +### Building a QSAR Model + +```python +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator +from sklearn.ensemble import RandomForestRegressor +from sklearn.model_selection import cross_val_score + +# Featurize molecules +transformer = MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1) +X = transformer(smiles_train) + +# Train model +model = RandomForestRegressor(n_estimators=100) +scores = cross_val_score(model, X, y_train, cv=5) +print(f"R² = {scores.mean():.3f}") + +# Save configuration for deployment +transformer.to_state_yaml_file("production_featurizer.yml") +``` + +### Virtual Screening Pipeline + +```python +from sklearn.ensemble import RandomForestClassifier + +# Train on known actives/inactives +transformer = MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1) +X_train = transformer(train_smiles) +clf = RandomForestClassifier(n_estimators=500) +clf.fit(X_train, train_labels) + +# Screen large library +X_screen = transformer(screening_library) # e.g., 1M compounds +predictions = clf.predict_proba(X_screen)[:, 1] + +# Rank and select top hits +top_indices = predictions.argsort()[::-1][:1000] +top_hits = [screening_library[i] for i in top_indices] +``` + +### Similarity Search + +```python +from sklearn.metrics.pairwise import cosine_similarity + +# Query molecule +calc = FPCalculator("ecfp") +query_fp = calc(query_smiles).reshape(1, -1) + +# Database fingerprints +transformer = MoleculeTransformer(calc, n_jobs=-1) +database_fps = transformer(database_smiles) + +# Compute similarity +similarities = cosine_similarity(query_fp, database_fps)[0] +top_similar = similarities.argsort()[-10:][::-1] +``` + +### Scikit-learn Pipeline Integration + +```python +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestClassifier + +# Create end-to-end pipeline +pipeline = Pipeline([ + ('featurizer', MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1)), + ('classifier', RandomForestClassifier(n_estimators=100)) +]) + +# Train and predict directly on SMILES +pipeline.fit(smiles_train, y_train) +predictions = pipeline.predict(smiles_test) +``` + +### Comparing Multiple Featurizers + +```python +featurizers = { + 'ECFP': FPCalculator("ecfp"), + 'MACCS': FPCalculator("maccs"), + 'Descriptors': RDKitDescriptors2D(), + 'ChemBERTa': PretrainedMolTransformer("ChemBERTa-77M-MLM") +} + +results = {} +for name, feat in featurizers.items(): + transformer = MoleculeTransformer(feat, n_jobs=-1) + X = transformer(smiles) + # Evaluate with your ML model + score = evaluate_model(X, y) + results[name] = score +``` + +## Discovering Available Featurizers + +Use the ModelStore to explore all available featurizers: + +```python +from molfeat.store.modelstore import ModelStore + +store = ModelStore() + +# List all available models +all_models = store.available_models +print(f"Total featurizers: {len(all_models)}") + +# Search for specific models +chemberta_models = store.search(name="ChemBERTa") +for model in chemberta_models: + print(f"- {model.name}: {model.description}") + +# Get usage information +model_card = store.search(name="ChemBERTa-77M-MLM")[0] +model_card.usage() # Display usage examples + +# Load model +transformer = store.load("ChemBERTa-77M-MLM") +``` + +## Advanced Features + +### Custom Preprocessing + +```python +class 
CustomTransformer(MoleculeTransformer): + def preprocess(self, mol): + """Custom preprocessing pipeline""" + if isinstance(mol, str): + mol = dm.to_mol(mol) + mol = dm.standardize_mol(mol) + mol = dm.remove_salts(mol) + return mol + +transformer = CustomTransformer(FPCalculator("ecfp"), n_jobs=-1) +``` + +### Batch Processing Large Datasets + +```python +def featurize_in_chunks(smiles_list, transformer, chunk_size=10000): + """Process large datasets in chunks to manage memory""" + all_features = [] + for i in range(0, len(smiles_list), chunk_size): + chunk = smiles_list[i:i+chunk_size] + features = transformer(chunk) + all_features.append(features) + return np.vstack(all_features) +``` + +### Caching Expensive Embeddings + +```python +import pickle + +cache_file = "embeddings_cache.pkl" +transformer = PretrainedMolTransformer("ChemBERTa-77M-MLM", n_jobs=-1) + +try: + with open(cache_file, "rb") as f: + embeddings = pickle.load(f) +except FileNotFoundError: + embeddings = transformer(smiles_list) + with open(cache_file, "wb") as f: + pickle.dump(embeddings, f) +``` + +## Performance Tips + +1. **Use parallelization**: Set `n_jobs=-1` to utilize all CPU cores +2. **Batch processing**: Process multiple molecules at once instead of loops +3. **Choose appropriate featurizers**: Fingerprints are faster than deep learning models +4. **Cache pretrained models**: Leverage built-in caching for repeated use +5. **Use float32**: Set `dtype=np.float32` when precision allows +6. **Handle errors efficiently**: Use `ignore_errors=True` for large datasets + +## Common Featurizers Reference + +**Quick reference for frequently used featurizers:** + +| Featurizer | Type | Dimensions | Speed | Use Case | +|------------|------|------------|-------|----------| +| `ecfp` | Fingerprint | 2048 | Fast | General purpose | +| `maccs` | Fingerprint | 167 | Very fast | Scaffold similarity | +| `desc2D` | Descriptors | 200+ | Fast | Interpretable models | +| `mordred` | Descriptors | 1800+ | Medium | Comprehensive features | +| `map4` | Fingerprint | 1024 | Fast | Large-scale screening | +| `ChemBERTa-77M-MLM` | Deep learning | 768 | Slow* | Transfer learning | +| `gin-supervised-masking` | GNN | Variable | Slow* | Graph-based models | + +*First run is slow; subsequent runs benefit from caching + +## Resources + +This skill includes comprehensive reference documentation: + +### references/api_reference.md +Complete API documentation covering: +- `molfeat.calc` - All calculator classes and parameters +- `molfeat.trans` - Transformer classes and methods +- `molfeat.store` - ModelStore usage +- Common patterns and integration examples +- Performance optimization tips + +**When to load:** Reference when implementing specific calculators, understanding transformer parameters, or integrating with scikit-learn/PyTorch. + +### references/available_featurizers.md +Comprehensive catalog of all 100+ featurizers organized by category: +- Transformer-based language models (ChemBERTa, ChemGPT) +- Graph neural networks (GIN, Graphormer) +- Molecular descriptors (RDKit, Mordred) +- Fingerprints (ECFP, MACCS, MAP4, and 15+ others) +- Pharmacophore descriptors (CATS, Gobbi) +- Shape descriptors (USR, ElectroShape) +- Scaffold-based descriptors + +**When to load:** Reference when selecting the optimal featurizer for a specific task, exploring available options, or understanding featurizer characteristics. 
+ +**Search tip:** Use grep to find specific featurizer types: +```bash +grep -i "chembert" references/available_featurizers.md +grep -i "pharmacophore" references/available_featurizers.md +``` + +### references/examples.md +Practical code examples for common scenarios: +- Installation and quick start +- Calculator and transformer examples +- Pretrained model usage +- Scikit-learn and PyTorch integration +- Virtual screening workflows +- QSAR model building +- Similarity searching +- Troubleshooting and best practices + +**When to load:** Reference when implementing specific workflows, troubleshooting issues, or learning molfeat patterns. + +## Troubleshooting + +### Invalid Molecules +Enable error handling to skip invalid SMILES: +```python +transformer = MoleculeTransformer( + calc, + ignore_errors=True, + verbose=True +) +``` + +### Memory Issues with Large Datasets +Process in chunks or use streaming approaches for datasets > 100K molecules. + +### Pretrained Model Dependencies +Some models require additional packages. Install specific extras: +```bash +pip install "molfeat[transformer]" # For ChemBERTa/ChemGPT +pip install "molfeat[dgl]" # For GIN models +``` + +### Reproducibility +Save exact configurations and document versions: +```python +transformer.to_state_yaml_file("config.yml") +import molfeat +print(f"molfeat version: {molfeat.__version__}") +``` + +## Additional Resources + +- **Official Documentation**: https://molfeat-docs.datamol.io/ +- **GitHub Repository**: https://github.com/datamol-io/molfeat +- **PyPI Package**: https://pypi.org/project/molfeat/ +- **Tutorial**: https://portal.valencelabs.com/datamol/post/types-of-featurizers-b1e8HHrbFMkbun6 diff --git a/scientific-packages/molfeat/references/api_reference.md b/scientific-packages/molfeat/references/api_reference.md new file mode 100644 index 0000000..752b14a --- /dev/null +++ b/scientific-packages/molfeat/references/api_reference.md @@ -0,0 +1,428 @@ +# Molfeat API Reference + +## Core Modules + +Molfeat is organized into several key modules that provide different aspects of molecular featurization: + +- **`molfeat.store`** - Manages model loading, listing, and registration +- **`molfeat.calc`** - Provides calculators for single-molecule featurization +- **`molfeat.trans`** - Offers scikit-learn compatible transformers for batch processing +- **`molfeat.utils`** - Utility functions for data handling +- **`molfeat.viz`** - Visualization tools for molecular features + +--- + +## molfeat.calc - Calculators + +Calculators are callable objects that convert individual molecules into feature vectors. They accept either RDKit `Chem.Mol` objects or SMILES strings as input. + +### SerializableCalculator (Base Class) + +Base abstract class for all calculators. When subclassing, must implement: +- `__call__()` - Required method for featurization +- `__len__()` - Optional, returns output length +- `columns` - Optional property, returns feature names +- `batch_compute()` - Optional, for efficient batch processing + +**State Management Methods:** +- `to_state_json()` - Save calculator state as JSON +- `to_state_yaml()` - Save calculator state as YAML +- `from_state_dict()` - Load calculator from state dictionary +- `to_state_dict()` - Export calculator state as dictionary + +### FPCalculator + +Computes molecular fingerprints. Supports 15+ fingerprint methods. 
+ +**Supported Fingerprint Types:** + +**Structural Fingerprints:** +- `ecfp` - Extended-connectivity fingerprints (circular) +- `fcfp` - Functional-class fingerprints +- `rdkit` - RDKit topological fingerprints +- `maccs` - MACCS keys (166-bit structural keys) +- `avalon` - Avalon fingerprints +- `pattern` - Pattern fingerprints +- `layered` - Layered fingerprints + +**Atom-based Fingerprints:** +- `atompair` - Atom pair fingerprints +- `atompair-count` - Counted atom pairs +- `topological` - Topological torsion fingerprints +- `topological-count` - Counted topological torsions + +**Specialized Fingerprints:** +- `map4` - MinHashed atom-pair fingerprint up to 4 bonds +- `secfp` - SMILES extended connectivity fingerprint +- `erg` - Extended reduced graphs +- `estate` - Electrotopological state indices + +**Parameters:** +- `method` (str) - Fingerprint type name +- `radius` (int) - Radius for circular fingerprints (default: 3) +- `fpSize` (int) - Fingerprint size (default: 2048) +- `includeChirality` (bool) - Include chirality information +- `counting` (bool) - Use count vectors instead of binary + +**Usage:** +```python +from molfeat.calc import FPCalculator + +# Create fingerprint calculator +calc = FPCalculator("ecfp", radius=3, fpSize=2048) + +# Compute fingerprint for single molecule +fp = calc("CCO") # Returns numpy array + +# Get fingerprint length +length = len(calc) # 2048 + +# Get feature names +names = calc.columns +``` + +**Common Fingerprint Dimensions:** +- MACCS: 167 dimensions +- ECFP (default): 2048 dimensions +- MAP4 (default): 1024 dimensions + +### Descriptor Calculators + +**RDKitDescriptors2D** +Computes 2D molecular descriptors using RDKit. + +```python +from molfeat.calc import RDKitDescriptors2D + +calc = RDKitDescriptors2D() +descriptors = calc("CCO") # Returns 200+ descriptors +``` + +**RDKitDescriptors3D** +Computes 3D molecular descriptors (requires conformer generation). + +**MordredDescriptors** +Calculates over 1800 molecular descriptors using Mordred. + +```python +from molfeat.calc import MordredDescriptors + +calc = MordredDescriptors() +descriptors = calc("CCO") +``` + +### Pharmacophore Calculators + +**Pharmacophore2D** +RDKit's 2D pharmacophore fingerprint generation. + +**Pharmacophore3D** +Consensus pharmacophore fingerprints from multiple conformers. + +**CATSCalculator** +Computes Chemically Advanced Template Search (CATS) descriptors - pharmacophore point pair distributions. + +**Parameters:** +- `mode` - "2D" or "3D" distance calculations +- `dist_bins` - Distance bins for pair distributions +- `scale` - Scaling mode: "raw", "num", or "count" + +```python +from molfeat.calc import CATSCalculator + +calc = CATSCalculator(mode="2D", scale="raw") +cats = calc("CCO") # Returns 21 descriptors by default +``` + +### Shape Descriptors + +**USRDescriptors** +Ultrafast shape recognition descriptors (multiple variants). + +**ElectroShapeDescriptors** +Electrostatic shape descriptors combining shape, chirality, and electrostatics. + +### Graph-Based Calculators + +**ScaffoldKeyCalculator** +Computes 40+ scaffold-based molecular properties. + +**AtomCalculator** +Atom-level featurization for graph neural networks. + +**BondCalculator** +Bond-level featurization for graph neural networks. + +### Utility Function + +**get_calculator()** +Factory function to instantiate calculators by name. 
+ +```python +from molfeat.calc import get_calculator + +# Instantiate any calculator by name +calc = get_calculator("ecfp", radius=3) +calc = get_calculator("maccs") +calc = get_calculator("desc2D") +``` + +Raises `ValueError` for unsupported featurizers. + +--- + +## molfeat.trans - Transformers + +Transformers wrap calculators into complete featurization pipelines for batch processing. + +### MoleculeTransformer + +Scikit-learn compatible transformer for batch molecular featurization. + +**Key Parameters:** +- `featurizer` - Calculator or featurizer to use +- `n_jobs` (int) - Number of parallel jobs (-1 for all cores) +- `dtype` - Output data type (numpy float32/64, torch tensors) +- `verbose` (bool) - Enable verbose logging +- `ignore_errors` (bool) - Continue on failures (returns None for failed molecules) + +**Essential Methods:** +- `transform(mols)` - Processes batches and returns representations +- `_transform(mol)` - Handles individual molecule featurization +- `__call__(mols)` - Convenience wrapper around transform() +- `preprocess(mol)` - Prepares input molecules (not automatically applied) +- `to_state_yaml_file(path)` - Save transformer configuration +- `from_state_yaml_file(path)` - Load transformer configuration + +**Usage:** +```python +from molfeat.calc import FPCalculator +from molfeat.trans import MoleculeTransformer +import datamol as dm + +# Load molecules +smiles = dm.data.freesolv().sample(100).smiles.values + +# Create transformer +calc = FPCalculator("ecfp") +transformer = MoleculeTransformer(calc, n_jobs=-1) + +# Featurize batch +features = transformer(smiles) # Returns numpy array (100, 2048) + +# Save configuration +transformer.to_state_yaml_file("ecfp_config.yml") + +# Reload +transformer = MoleculeTransformer.from_state_yaml_file("ecfp_config.yml") +``` + +**Performance:** Testing on 642 molecules showed 3.4x speedup using 4 parallel jobs versus single-threaded processing. + +### FeatConcat + +Concatenates multiple featurizers into unified representations. + +```python +from molfeat.trans import FeatConcat +from molfeat.calc import FPCalculator + +# Combine multiple fingerprints +concat = FeatConcat([ + FPCalculator("maccs"), # 167 dimensions + FPCalculator("ecfp") # 2048 dimensions +]) + +# Result: 2167-dimensional features +transformer = MoleculeTransformer(concat, n_jobs=-1) +features = transformer(smiles) +``` + +### PretrainedMolTransformer + +Subclass of `MoleculeTransformer` for pre-trained deep learning models. + +**Unique Features:** +- `_embed()` - Batched inference for neural networks +- `_convert()` - Transforms SMILES/molecules into model-compatible formats + - SELFIES strings for language models + - DGL graphs for graph neural networks +- Integrated caching system for efficient storage + +**Usage:** +```python +from molfeat.trans.pretrained import PretrainedMolTransformer + +# Load pretrained model +transformer = PretrainedMolTransformer("ChemBERTa-77M-MLM", n_jobs=-1) + +# Generate embeddings +embeddings = transformer(smiles) +``` + +### PrecomputedMolTransformer + +Transformer for cached/precomputed features. + +--- + +## molfeat.store - Model Store + +Manages featurizer discovery, loading, and registration. + +### ModelStore + +Central hub for accessing available featurizers. 
+ +**Key Methods:** +- `available_models` - Property listing all available featurizers +- `search(name=None, **kwargs)` - Search for specific featurizers +- `load(name, **kwargs)` - Load a featurizer by name +- `register(name, card)` - Register custom featurizer + +**Usage:** +```python +from molfeat.store.modelstore import ModelStore + +# Initialize store +store = ModelStore() + +# List all available models +all_models = store.available_models +print(f"Found {len(all_models)} featurizers") + +# Search for specific model +results = store.search(name="ChemBERTa-77M-MLM") +if results: + model_card = results[0] + + # View usage information + model_card.usage() + + # Load the model + transformer = model_card.load() + +# Direct loading +transformer = store.load("ChemBERTa-77M-MLM") +``` + +**ModelCard Attributes:** +- `name` - Model identifier +- `description` - Model description +- `version` - Model version +- `authors` - Model authors +- `tags` - Categorization tags +- `usage()` - Display usage examples +- `load(**kwargs)` - Load the model + +--- + +## Common Patterns + +### Error Handling + +```python +# Enable error tolerance +featurizer = MoleculeTransformer( + calc, + n_jobs=-1, + verbose=True, + ignore_errors=True +) + +# Failed molecules return None +features = featurizer(smiles_with_errors) +``` + +### Data Type Control + +```python +# NumPy float32 (default) +features = transformer(smiles, enforce_dtype=True) + +# PyTorch tensors +import torch +transformer = MoleculeTransformer(calc, dtype=torch.float32) +features = transformer(smiles) +``` + +### Persistence and Reproducibility + +```python +# Save transformer state +transformer.to_state_yaml_file("config.yml") +transformer.to_state_json_file("config.json") + +# Load from saved state +transformer = MoleculeTransformer.from_state_yaml_file("config.yml") +transformer = MoleculeTransformer.from_state_json_file("config.json") +``` + +### Preprocessing + +```python +# Manual preprocessing +mol = transformer.preprocess("CCO") + +# Transform with preprocessing +features = transformer.transform(smiles_list) +``` + +--- + +## Integration Examples + +### Scikit-learn Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestClassifier +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator + +# Create pipeline +pipeline = Pipeline([ + ('featurizer', MoleculeTransformer(FPCalculator("ecfp"))), + ('classifier', RandomForestClassifier()) +]) + +# Fit and predict +pipeline.fit(smiles_train, y_train) +predictions = pipeline.predict(smiles_test) +``` + +### PyTorch Integration + +```python +import torch +from torch.utils.data import Dataset, DataLoader +from molfeat.trans import MoleculeTransformer + +class MoleculeDataset(Dataset): + def __init__(self, smiles, labels, transformer): + self.smiles = smiles + self.labels = labels + self.transformer = transformer + + def __len__(self): + return len(self.smiles) + + def __getitem__(self, idx): + features = self.transformer(self.smiles[idx]) + return torch.tensor(features), torch.tensor(self.labels[idx]) + +# Create dataset and dataloader +transformer = MoleculeTransformer(FPCalculator("ecfp")) +dataset = MoleculeDataset(smiles, labels, transformer) +loader = DataLoader(dataset, batch_size=32) +``` + +--- + +## Performance Tips + +1. **Parallelization**: Use `n_jobs=-1` to utilize all CPU cores +2. **Batch Processing**: Process multiple molecules at once instead of loops +3. 
**Caching**: Leverage built-in caching for pretrained models +4. **Data Types**: Use float32 instead of float64 when precision allows +5. **Error Handling**: Set `ignore_errors=True` for large datasets with potential invalid molecules diff --git a/scientific-packages/molfeat/references/available_featurizers.md b/scientific-packages/molfeat/references/available_featurizers.md new file mode 100644 index 0000000..08e0019 --- /dev/null +++ b/scientific-packages/molfeat/references/available_featurizers.md @@ -0,0 +1,333 @@ +# Available Featurizers in Molfeat + +This document provides a comprehensive catalog of all featurizers available in molfeat, organized by category. + +## Transformer-Based Language Models + +Pre-trained transformer models for molecular embeddings using SMILES/SELFIES representations. + +### RoBERTa-style Models +- **Roberta-Zinc480M-102M** - RoBERTa masked language model trained on ~480M SMILES strings from ZINC database +- **ChemBERTa-77M-MLM** - Masked language model based on RoBERTa trained on 77M PubChem compounds +- **ChemBERTa-77M-MTR** - Multitask regression version trained on PubChem compounds + +### GPT-style Autoregressive Models +- **GPT2-Zinc480M-87M** - GPT-2 autoregressive language model trained on ~480M SMILES from ZINC +- **ChemGPT-1.2B** - Large transformer (1.2B parameters) pretrained on PubChem10M +- **ChemGPT-19M** - Medium transformer (19M parameters) pretrained on PubChem10M +- **ChemGPT-4.7M** - Small transformer (4.7M parameters) pretrained on PubChem10M + +### Specialized Transformer Models +- **MolT5** - Self-supervised framework for molecule captioning and text-based generation + +## Graph Neural Networks (GNNs) + +Pre-trained graph neural network models operating on molecular graph structures. + +### GIN (Graph Isomorphism Network) Variants +All pre-trained on ChEMBL molecules with different objectives: +- **gin-supervised-masking** - Supervised with node masking objective +- **gin-supervised-infomax** - Supervised with graph-level mutual information maximization +- **gin-supervised-edgepred** - Supervised with edge prediction objective +- **gin-supervised-contextpred** - Supervised with context prediction objective + +### Other Graph-Based Models +- **JTVAE_zinc_no_kl** - Junction-tree VAE for molecule generation (trained on ZINC) +- **Graphormer-pcqm4mv2** - Graph transformer pretrained on PCQM4Mv2 quantum chemistry dataset for HOMO-LUMO gap prediction + +## Molecular Descriptors + +Calculators for physico-chemical properties and molecular characteristics. 
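+
+For example, the descriptor calculators in this section can be used like any other calculator (a minimal sketch; the SMILES string is arbitrary and chosen only for illustration):
+
+```python
+from molfeat.calc import RDKitDescriptors2D, get_calculator
+
+# Direct instantiation: ~200 named RDKit 2D descriptors
+calc = RDKitDescriptors2D()
+values = calc("CCO")          # numpy array of descriptor values
+names = calc.columns          # matching descriptor names
+
+# The same family of calculators can also be obtained by name (see the API reference)
+calc = get_calculator("desc2D")
+```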
+ +### 2D Descriptors +- **desc2D** / **rdkit2D** - 200+ RDKit 2D molecular descriptors including: + - Molecular weight, logP, TPSA + - H-bond donors/acceptors + - Rotatable bonds + - Ring counts and aromaticity + - Molecular complexity metrics + +### 3D Descriptors +- **desc3D** / **rdkit3D** - RDKit 3D molecular descriptors (requires conformer generation) + - Inertial moments + - PMI (Principal Moments of Inertia) ratios + - Asphericity, eccentricity + - Radius of gyration + +### Comprehensive Descriptor Sets +- **mordred** - Over 1800 molecular descriptors covering: + - Constitutional descriptors + - Topological indices + - Connectivity indices + - Information content + - 2D/3D autocorrelations + - WHIM descriptors + - GETAWAY descriptors + - And many more + +### Electrotopological Descriptors +- **estate** - Electrotopological state (E-State) indices encoding: + - Atomic environment information + - Electronic and topological properties + - Heteroatom contributions + +## Molecular Fingerprints + +Binary or count-based fixed-length vectors representing molecular substructures. + +### Circular Fingerprints (ECFP-style) +- **ecfp** / **ecfp:2** / **ecfp:4** / **ecfp:6** - Extended-connectivity fingerprints + - Radius variants (2, 4, 6 correspond to diameter) + - Default: radius=3, 2048 bits + - Most popular for similarity searching +- **ecfp-count** - Count version of ECFP (non-binary) +- **fcfp** / **fcfp-count** - Functional-class circular fingerprints + - Similar to ECFP but uses functional groups + - Better for pharmacophore-based similarity + +### Path-Based Fingerprints +- **rdkit** - RDKit topological fingerprints based on linear paths +- **pattern** - Pattern fingerprints (similar to MACCS but automated) +- **layered** - Layered fingerprints with multiple substructure layers + +### Key-Based Fingerprints +- **maccs** - MACCS keys (166-bit structural keys) + - Fixed set of predefined substructures + - Good for scaffold hopping + - Fast computation +- **avalon** - Avalon fingerprints + - Similar to MACCS but more features + - Optimized for similarity searching + +### Atom-Pair Fingerprints +- **atompair** - Atom pair fingerprints + - Encodes pairs of atoms and distance between them + - Good for 3D similarity +- **atompair-count** - Count version of atom pairs + +### Topological Torsion Fingerprints +- **topological** - Topological torsion fingerprints + - Encodes sequences of 4 connected atoms + - Captures local topology +- **topological-count** - Count version of topological torsions + +### MinHashed Fingerprints +- **map4** - MinHashed Atom-Pair fingerprint up to 4 bonds + - Combines atom-pair and ECFP concepts + - Default: 1024 dimensions + - Fast and efficient for large datasets +- **secfp** - SMILES Extended Connectivity Fingerprint + - Operates directly on SMILES strings + - Captures both substructure and atom-pair information + +### Extended Reduced Graph +- **erg** - Extended Reduced Graph + - Uses pharmacophoric points instead of atoms + - Reduces graph complexity while preserving key features + +## Pharmacophore Descriptors + +Features based on pharmacologically relevant functional groups and their spatial relationships. 
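+
+A minimal usage sketch for the pharmacophore calculators below (the molecule is arbitrary; see the API reference for parameter details):
+
+```python
+from molfeat.calc import CATSCalculator
+
+# 2D CATS: pharmacophore point-pair distributions over topological distances
+calc = CATSCalculator(mode="2D", scale="raw")
+descriptors = calc("CC(=O)Oc1ccccc1C(=O)O")   # aspirin -> 21 values by default
+```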
+ +### CATS (Chemically Advanced Template Search) +- **cats2D** - 2D CATS descriptors + - Pharmacophore point pair distributions + - Distance based on shortest path + - 21 descriptors by default +- **cats3D** - 3D CATS descriptors + - Euclidean distance based + - Requires conformer generation +- **cats2D_pharm** / **cats3D_pharm** - Pharmacophore variants + +### Gobbi Pharmacophores +- **gobbi2D** - 2D pharmacophore fingerprints + - 8 pharmacophore feature types: + - Hydrophobic + - Aromatic + - H-bond acceptor + - H-bond donor + - Positive ionizable + - Negative ionizable + - Lumped hydrophobe + - Good for virtual screening + +### Pmapper Pharmacophores +- **pmapper2D** - 2D pharmacophore signatures +- **pmapper3D** - 3D pharmacophore signatures + - High-dimensional pharmacophore descriptors + - Useful for QSAR and similarity searching + +## Shape Descriptors + +Descriptors capturing 3D molecular shape and electrostatic properties. + +### USR (Ultrafast Shape Recognition) +- **usr** - Basic USR descriptors + - 12 dimensions encoding shape distribution + - Extremely fast computation +- **usrcat** - USR with pharmacophoric constraints + - 60 dimensions (12 per feature type) + - Combines shape and pharmacophore information + +### Electrostatic Shape +- **electroshape** - ElectroShape descriptors + - Combines molecular shape, chirality, and electrostatics + - Useful for protein-ligand docking predictions + +## Scaffold-Based Descriptors + +Descriptors based on molecular scaffolds and core structures. + +### Scaffold Keys +- **scaffoldkeys** - Scaffold key calculator + - 40+ scaffold-based properties + - Bioisosteric scaffold representation + - Captures core structural features + +## Graph Featurizers for GNN Input + +Atom and bond-level features for constructing graph representations for Graph Neural Networks. 
+ +### Atom-Level Features +- **atom-onehot** - One-hot encoded atom features +- **atom-default** - Default atom featurization including: + - Atomic number + - Degree, formal charge + - Hybridization + - Aromaticity + - Number of hydrogen atoms + +### Bond-Level Features +- **bond-onehot** - One-hot encoded bond features +- **bond-default** - Default bond featurization including: + - Bond type (single, double, triple, aromatic) + - Conjugation + - Ring membership + - Stereochemistry + +## Integrated Pretrained Model Collections + +Molfeat integrates models from various sources: + +### HuggingFace Models +Access to transformer models through HuggingFace hub: +- ChemBERTa variants +- ChemGPT variants +- MolT5 +- Custom uploaded models + +### DGL-LifeSci Models +Pre-trained GNN models from DGL-Life: +- GIN variants with different pre-training tasks +- AttentiveFP models +- MPNN models + +### FCD (Fréchet ChemNet Distance) +- **fcd** - Pre-trained CNN for molecular generation evaluation + +### Graphormer Models +- Graph transformers from Microsoft Research +- Pre-trained on quantum chemistry datasets + +## Usage Notes + +### Choosing a Featurizer + +**For traditional ML (Random Forest, SVM, etc.):** +- Start with **ecfp** or **maccs** fingerprints +- Try **desc2D** for interpretable models +- Use **FeatConcat** to combine multiple fingerprints + +**For deep learning:** +- Use **ChemBERTa** or **ChemGPT** for transformer embeddings +- Use **gin-supervised-*** for graph neural network embeddings +- Consider **Graphormer** for quantum property predictions + +**For similarity searching:** +- **ecfp** - General purpose, most popular +- **maccs** - Fast, good for scaffold hopping +- **map4** - Efficient for large-scale searches +- **usr** / **usrcat** - 3D shape similarity + +**For pharmacophore-based approaches:** +- **fcfp** - Functional group based +- **cats2D/3D** - Pharmacophore pair distributions +- **gobbi2D** - Explicit pharmacophore features + +**For interpretability:** +- **desc2D** / **mordred** - Named descriptors +- **maccs** - Interpretable substructure keys +- **scaffoldkeys** - Scaffold-based features + +### Model Dependencies + +Some featurizers require optional dependencies: + +- **DGL models** (gin-*, jtvae): `pip install "molfeat[dgl]"` +- **Graphormer**: `pip install "molfeat[graphormer]"` +- **Transformers** (ChemBERTa, ChemGPT, MolT5): `pip install "molfeat[transformer]"` +- **FCD**: `pip install "molfeat[fcd]"` +- **MAP4**: `pip install "molfeat[map4]"` +- **All dependencies**: `pip install "molfeat[all]"` + +### Accessing All Available Models + +```python +from molfeat.store.modelstore import ModelStore + +store = ModelStore() +all_models = store.available_models + +# Print all available featurizers +for model in all_models: + print(f"{model.name}: {model.description}") + +# Search for specific types +transformers = [m for m in all_models if "transformer" in m.tags] +gnn_models = [m for m in all_models if "gnn" in m.tags] +fingerprints = [m for m in all_models if "fingerprint" in m.tags] +``` + +## Performance Characteristics + +### Computational Speed (relative) +**Fastest:** +- maccs +- ecfp +- rdkit fingerprints +- usr + +**Medium:** +- desc2D +- cats2D +- Most fingerprints + +**Slower:** +- mordred (1800+ descriptors) +- desc3D (requires conformer generation) +- 3D descriptors in general + +**Slowest (first run):** +- Pretrained models (ChemBERTa, ChemGPT, GIN) +- Note: Subsequent runs benefit from caching + +### Dimensionality + +**Low (< 200 dims):** +- maccs (167) +- 
usr (12) +- usrcat (60) + +**Medium (200-2000 dims):** +- desc2D (~200) +- ecfp (2048 default, configurable) +- map4 (1024 default) + +**High (> 2000 dims):** +- mordred (1800+) +- Concatenated fingerprints +- Some transformer embeddings + +**Variable:** +- Transformer models (typically 768-1024) +- GNN models (depends on architecture) diff --git a/scientific-packages/molfeat/references/examples.md b/scientific-packages/molfeat/references/examples.md new file mode 100644 index 0000000..16c937b --- /dev/null +++ b/scientific-packages/molfeat/references/examples.md @@ -0,0 +1,723 @@ +# Molfeat Usage Examples + +This document provides practical examples for common molfeat use cases. + +## Installation + +```bash +# Recommended: Using conda/mamba +mamba install -c conda-forge molfeat + +# Alternative: Using pip +pip install molfeat + +# With all optional dependencies +pip install "molfeat[all]" + +# With specific dependencies +pip install "molfeat[dgl]" # For GNN models +pip install "molfeat[graphormer]" # For Graphormer +pip install "molfeat[transformer]" # For ChemBERTa, ChemGPT +``` + +--- + +## Quick Start + +### Basic Featurization Workflow + +```python +import datamol as dm +from molfeat.calc import FPCalculator +from molfeat.trans import MoleculeTransformer + +# Load sample data +data = dm.data.freesolv().sample(100).smiles.values + +# Single molecule featurization +calc = FPCalculator("ecfp") +features_single = calc(data[0]) +print(f"Single molecule features shape: {features_single.shape}") +# Output: (2048,) + +# Batch featurization with parallelization +transformer = MoleculeTransformer(calc, n_jobs=-1) +features_batch = transformer(data) +print(f"Batch features shape: {features_batch.shape}") +# Output: (100, 2048) +``` + +--- + +## Calculator Examples + +### Fingerprint Calculators + +```python +from molfeat.calc import FPCalculator + +# ECFP (Extended-Connectivity Fingerprints) +ecfp = FPCalculator("ecfp", radius=3, fpSize=2048) +fp = ecfp("CCO") # Ethanol +print(f"ECFP shape: {fp.shape}") # (2048,) + +# MACCS keys +maccs = FPCalculator("maccs") +fp = maccs("c1ccccc1") # Benzene +print(f"MACCS shape: {fp.shape}") # (167,) + +# Count-based fingerprints +ecfp_count = FPCalculator("ecfp-count", radius=3) +fp_count = ecfp_count("CC(C)CC(C)C") # Non-binary counts + +# MAP4 fingerprints +map4 = FPCalculator("map4") +fp = map4("CC(=O)Oc1ccccc1C(=O)O") # Aspirin +``` + +### Descriptor Calculators + +```python +from molfeat.calc import RDKitDescriptors2D, MordredDescriptors + +# RDKit 2D descriptors (200+ properties) +desc2d = RDKitDescriptors2D() +descriptors = desc2d("CCO") +print(f"Number of 2D descriptors: {len(descriptors)}") + +# Get descriptor names +names = desc2d.columns +print(f"First 5 descriptors: {names[:5]}") + +# Mordred descriptors (1800+ properties) +mordred = MordredDescriptors() +descriptors = mordred("c1ccccc1O") # Phenol +print(f"Mordred descriptors: {len(descriptors)}") +``` + +### Pharmacophore Calculators + +```python +from molfeat.calc import CATSCalculator + +# 2D CATS descriptors +cats = CATSCalculator(mode="2D", scale="raw") +descriptors = cats("CC(C)Cc1ccc(C)cc1C") # Cymene +print(f"CATS descriptors: {descriptors.shape}") # (21,) + +# 3D CATS descriptors (requires conformer) +cats3d = CATSCalculator(mode="3D", scale="num") +``` + +--- + +## Transformer Examples + +### Basic Transformer Usage + +```python +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator +import datamol as dm + +# Prepare data +smiles_list = [ + "CCO", + 
"CC(=O)O", + "c1ccccc1", + "CC(C)O", + "CCCC" +] + +# Create transformer +calc = FPCalculator("ecfp") +transformer = MoleculeTransformer(calc, n_jobs=-1) + +# Transform molecules +features = transformer(smiles_list) +print(f"Features shape: {features.shape}") # (5, 2048) +``` + +### Error Handling + +```python +# Handle invalid SMILES gracefully +smiles_with_errors = [ + "CCO", # Valid + "invalid", # Invalid + "CC(=O)O", # Valid + "xyz123", # Invalid +] + +transformer = MoleculeTransformer( + FPCalculator("ecfp"), + n_jobs=-1, + verbose=True, # Log errors + ignore_errors=True # Continue on failure +) + +features = transformer(smiles_with_errors) +# Returns: array with None for failed molecules +print(features) # [array(...), None, array(...), None] +``` + +### Concatenating Multiple Featurizers + +```python +from molfeat.trans import FeatConcat, MoleculeTransformer +from molfeat.calc import FPCalculator + +# Combine MACCS (167) + ECFP (2048) = 2215 dimensions +concat_calc = FeatConcat([ + FPCalculator("maccs"), + FPCalculator("ecfp", radius=3, fpSize=2048) +]) + +transformer = MoleculeTransformer(concat_calc, n_jobs=-1) +features = transformer(smiles_list) +print(f"Combined features shape: {features.shape}") # (n, 2215) + +# Triple combination +triple_concat = FeatConcat([ + FPCalculator("maccs"), + FPCalculator("ecfp"), + FPCalculator("rdkit") +]) +``` + +### Saving and Loading Configurations + +```python +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator + +# Create and save transformer +transformer = MoleculeTransformer( + FPCalculator("ecfp", radius=3, fpSize=2048), + n_jobs=-1 +) + +# Save to YAML +transformer.to_state_yaml_file("my_featurizer.yml") + +# Save to JSON +transformer.to_state_json_file("my_featurizer.json") + +# Load from saved state +loaded_transformer = MoleculeTransformer.from_state_yaml_file("my_featurizer.yml") + +# Use loaded transformer +features = loaded_transformer(smiles_list) +``` + +--- + +## Pretrained Model Examples + +### Using the ModelStore + +```python +from molfeat.store.modelstore import ModelStore + +# Initialize model store +store = ModelStore() + +# List all available models +print(f"Total available models: {len(store.available_models)}") + +# Search for specific models +chemberta_models = store.search(name="ChemBERTa") +for model in chemberta_models: + print(f"- {model.name}: {model.description}") + +# Get model information +model_card = store.search(name="ChemBERTa-77M-MLM")[0] +print(f"Model: {model_card.name}") +print(f"Version: {model_card.version}") +print(f"Authors: {model_card.authors}") + +# View usage instructions +model_card.usage() + +# Load model directly +transformer = store.load("ChemBERTa-77M-MLM") +``` + +### ChemBERTa Embeddings + +```python +from molfeat.trans.pretrained import PretrainedMolTransformer + +# Load ChemBERTa model +chemberta = PretrainedMolTransformer("ChemBERTa-77M-MLM", n_jobs=-1) + +# Generate embeddings +smiles = ["CCO", "CC(=O)O", "c1ccccc1"] +embeddings = chemberta(smiles) +print(f"ChemBERTa embeddings shape: {embeddings.shape}") +# Output: (3, 768) - 768-dimensional embeddings + +# Use in ML pipeline +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split + +X_train, X_test, y_train, y_test = train_test_split( + embeddings, labels, test_size=0.2 +) + +clf = RandomForestClassifier() +clf.fit(X_train, y_train) +predictions = clf.predict(X_test) +``` + +### ChemGPT Models + +```python +# Small model (4.7M parameters) 
+chemgpt_small = PretrainedMolTransformer("ChemGPT-4.7M", n_jobs=-1) + +# Medium model (19M parameters) +chemgpt_medium = PretrainedMolTransformer("ChemGPT-19M", n_jobs=-1) + +# Large model (1.2B parameters) +chemgpt_large = PretrainedMolTransformer("ChemGPT-1.2B", n_jobs=-1) + +# Generate embeddings +embeddings = chemgpt_small(smiles) +``` + +### Graph Neural Network Models + +```python +# GIN models with different pre-training objectives +gin_masking = PretrainedMolTransformer("gin-supervised-masking", n_jobs=-1) +gin_infomax = PretrainedMolTransformer("gin-supervised-infomax", n_jobs=-1) +gin_edgepred = PretrainedMolTransformer("gin-supervised-edgepred", n_jobs=-1) + +# Generate graph embeddings +embeddings = gin_masking(smiles) +print(f"GIN embeddings shape: {embeddings.shape}") + +# Graphormer (for quantum chemistry) +graphormer = PretrainedMolTransformer("Graphormer-pcqm4mv2", n_jobs=-1) +embeddings = graphormer(smiles) +``` + +--- + +## Machine Learning Integration + +### Scikit-learn Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import cross_val_score +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator + +# Create ML pipeline +pipeline = Pipeline([ + ('featurizer', MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1)), + ('classifier', RandomForestClassifier(n_estimators=100)) +]) + +# Train and evaluate +pipeline.fit(smiles_train, y_train) +predictions = pipeline.predict(smiles_test) + +# Cross-validation +scores = cross_val_score(pipeline, smiles_all, y_all, cv=5) +print(f"CV scores: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +### Grid Search for Hyperparameter Tuning + +```python +from sklearn.model_selection import GridSearchCV +from sklearn.svm import SVC + +# Define pipeline +pipeline = Pipeline([ + ('featurizer', MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1)), + ('classifier', SVC()) +]) + +# Define parameter grid +param_grid = { + 'classifier__C': [0.1, 1, 10], + 'classifier__kernel': ['rbf', 'linear'], + 'classifier__gamma': ['scale', 'auto'] +} + +# Grid search +grid_search = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1) +grid_search.fit(smiles_train, y_train) + +print(f"Best parameters: {grid_search.best_params_}") +print(f"Best score: {grid_search.best_score_:.3f}") +``` + +### Multiple Featurizer Comparison + +```python +from sklearn.metrics import roc_auc_score + +# Test different featurizers +featurizers = { + 'ECFP': FPCalculator("ecfp"), + 'MACCS': FPCalculator("maccs"), + 'RDKit': FPCalculator("rdkit"), + 'Descriptors': RDKitDescriptors2D(), + 'Combined': FeatConcat([ + FPCalculator("maccs"), + FPCalculator("ecfp") + ]) +} + +results = {} +for name, calc in featurizers.items(): + transformer = MoleculeTransformer(calc, n_jobs=-1) + X_train = transformer(smiles_train) + X_test = transformer(smiles_test) + + clf = RandomForestClassifier(n_estimators=100) + clf.fit(X_train, y_train) + + y_pred = clf.predict_proba(X_test)[:, 1] + auc = roc_auc_score(y_test, y_pred) + results[name] = auc + + print(f"{name}: AUC = {auc:.3f}") +``` + +### PyTorch Deep Learning + +```python +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator + +# Custom dataset +class MoleculeDataset(Dataset): + def __init__(self, smiles, labels, transformer): + self.features = transformer(smiles) + self.labels = torch.tensor(labels, 
dtype=torch.float32) + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return ( + torch.tensor(self.features[idx], dtype=torch.float32), + self.labels[idx] + ) + +# Prepare data +transformer = MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1) +train_dataset = MoleculeDataset(smiles_train, y_train, transformer) +train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + +# Simple neural network +class MoleculeClassifier(nn.Module): + def __init__(self, input_dim): + super().__init__() + self.network = nn.Sequential( + nn.Linear(input_dim, 512), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(512, 256), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.network(x) + +# Train model +model = MoleculeClassifier(input_dim=2048) +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) +criterion = nn.BCELoss() + +for epoch in range(10): + for batch_features, batch_labels in train_loader: + optimizer.zero_grad() + outputs = model(batch_features).squeeze() + loss = criterion(outputs, batch_labels) + loss.backward() + optimizer.step() +``` + +--- + +## Advanced Usage Patterns + +### Custom Preprocessing + +```python +from molfeat.trans import MoleculeTransformer +import datamol as dm + +class CustomTransformer(MoleculeTransformer): + def preprocess(self, mol): + """Custom preprocessing: standardize molecule""" + if isinstance(mol, str): + mol = dm.to_mol(mol) + + # Standardize + mol = dm.standardize_mol(mol) + + # Remove salts + mol = dm.remove_salts(mol) + + return mol + +# Use custom transformer +transformer = CustomTransformer(FPCalculator("ecfp"), n_jobs=-1) +features = transformer(smiles_list) +``` + +### Featurization with Conformers + +```python +import datamol as dm +from molfeat.calc import RDKitDescriptors3D + +# Generate conformers +def prepare_3d_mol(smiles): + mol = dm.to_mol(smiles) + mol = dm.add_hs(mol) + mol = dm.conform.generate_conformers(mol, n_confs=1) + return mol + +# 3D descriptors +calc_3d = RDKitDescriptors3D() + +smiles = "CC(C)Cc1ccc(C)cc1C" +mol_3d = prepare_3d_mol(smiles) +descriptors_3d = calc_3d(mol_3d) +``` + +### Parallel Batch Processing + +```python +from molfeat.trans import MoleculeTransformer +from molfeat.calc import FPCalculator +import time + +# Large dataset +smiles_large = load_large_dataset() # e.g., 100,000 molecules + +# Test different parallelization levels +for n_jobs in [1, 2, 4, -1]: + transformer = MoleculeTransformer( + FPCalculator("ecfp"), + n_jobs=n_jobs + ) + + start = time.time() + features = transformer(smiles_large) + elapsed = time.time() - start + + print(f"n_jobs={n_jobs}: {elapsed:.2f}s") +``` + +### Caching for Expensive Operations + +```python +from molfeat.trans.pretrained import PretrainedMolTransformer +import pickle + +# Load expensive pretrained model +transformer = PretrainedMolTransformer("ChemBERTa-77M-MLM", n_jobs=-1) + +# Cache embeddings for reuse +cache_file = "embeddings_cache.pkl" + +try: + # Try loading cached embeddings + with open(cache_file, "rb") as f: + embeddings = pickle.load(f) + print("Loaded cached embeddings") +except FileNotFoundError: + # Compute and cache + embeddings = transformer(smiles_list) + with open(cache_file, "wb") as f: + pickle.dump(embeddings, f) + print("Computed and cached embeddings") +``` + +--- + +## Common Workflows + +### Virtual Screening Workflow + +```python +from molfeat.calc import FPCalculator +from sklearn.ensemble import RandomForestClassifier +import datamol as dm + +# 1. 
Prepare training data (known actives/inactives) +train_smiles = load_training_data() +train_labels = load_training_labels() # 1=active, 0=inactive + +# 2. Featurize training set +transformer = MoleculeTransformer(FPCalculator("ecfp"), n_jobs=-1) +X_train = transformer(train_smiles) + +# 3. Train classifier +clf = RandomForestClassifier(n_estimators=500, n_jobs=-1) +clf.fit(X_train, train_labels) + +# 4. Featurize screening library +screening_smiles = load_screening_library() # e.g., 1M compounds +X_screen = transformer(screening_smiles) + +# 5. Predict and rank +predictions = clf.predict_proba(X_screen)[:, 1] +ranked_indices = predictions.argsort()[::-1] + +# 6. Get top hits +top_n = 1000 +top_hits = [screening_smiles[i] for i in ranked_indices[:top_n]] +``` + +### QSAR Model Building + +```python +from molfeat.calc import RDKitDescriptors2D +from sklearn.linear_model import Ridge +from sklearn.preprocessing import StandardScaler +from sklearn.model_selection import cross_val_score +import numpy as np + +# Load QSAR dataset +smiles = load_molecules() +y = load_activity_values() # e.g., IC50, logP + +# Featurize with interpretable descriptors +transformer = MoleculeTransformer(RDKitDescriptors2D(), n_jobs=-1) +X = transformer(smiles) + +# Standardize features +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# Build linear model +model = Ridge(alpha=1.0) +scores = cross_val_score(model, X_scaled, y, cv=5, scoring='r2') +print(f"R² = {scores.mean():.3f} (+/- {scores.std():.3f})") + +# Fit final model +model.fit(X_scaled, y) + +# Interpret feature importance +feature_names = transformer.featurizer.columns +importance = np.abs(model.coef_) +top_features_idx = importance.argsort()[-10:][::-1] + +print("Top 10 important features:") +for idx in top_features_idx: + print(f" {feature_names[idx]}: {model.coef_[idx]:.3f}") +``` + +### Similarity Search + +```python +from molfeat.calc import FPCalculator +from sklearn.metrics.pairwise import cosine_similarity +import numpy as np + +# Query molecule +query_smiles = "CC(=O)Oc1ccccc1C(=O)O" # Aspirin + +# Database of molecules +database_smiles = load_molecule_database() # Large collection + +# Compute fingerprints +calc = FPCalculator("ecfp") +query_fp = calc(query_smiles).reshape(1, -1) + +transformer = MoleculeTransformer(calc, n_jobs=-1) +database_fps = transformer(database_smiles) + +# Compute similarity +similarities = cosine_similarity(query_fp, database_fps)[0] + +# Find most similar +top_k = 10 +top_indices = similarities.argsort()[-top_k:][::-1] + +print(f"Top {top_k} similar molecules:") +for i, idx in enumerate(top_indices, 1): + print(f"{i}. 
{database_smiles[idx]} (similarity: {similarities[idx]:.3f})") +``` + +--- + +## Troubleshooting + +### Handling Invalid Molecules + +```python +# Use ignore_errors to skip invalid molecules +transformer = MoleculeTransformer( + FPCalculator("ecfp"), + ignore_errors=True, + verbose=True +) + +# Filter out None values after transformation +features = transformer(smiles_list) +valid_mask = [f is not None for f in features] +valid_features = [f for f in features if f is not None] +valid_smiles = [s for s, m in zip(smiles_list, valid_mask) if m] +``` + +### Memory Management for Large Datasets + +```python +# Process in chunks for very large datasets +def featurize_in_chunks(smiles_list, transformer, chunk_size=10000): + all_features = [] + + for i in range(0, len(smiles_list), chunk_size): + chunk = smiles_list[i:i+chunk_size] + features = transformer(chunk) + all_features.append(features) + print(f"Processed {i+len(chunk)}/{len(smiles_list)}") + + return np.vstack(all_features) + +# Use with large dataset +features = featurize_in_chunks(large_smiles_list, transformer) +``` + +### Reproducibility + +```python +import random +import numpy as np +import torch + +# Set all random seeds +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +set_seed(42) + +# Save exact configuration +transformer.to_state_yaml_file("config.yml") + +# Document version +import molfeat +print(f"molfeat version: {molfeat.__version__}") +``` diff --git a/scientific-packages/polars/SKILL.md b/scientific-packages/polars/SKILL.md new file mode 100644 index 0000000..f69834e --- /dev/null +++ b/scientific-packages/polars/SKILL.md @@ -0,0 +1,381 @@ +--- +name: polars +description: This skill should be used when working with the Polars DataFrame library for high-performance data manipulation in Python. Use when users ask about Polars operations, migrating from pandas, optimizing data processing pipelines, or working with large datasets that benefit from lazy evaluation and parallel processing. +--- + +# Polars + +## Overview + +Polars is a lightning-fast DataFrame library for Python (and Rust) built on Apache Arrow. This skill provides guidance for working with Polars, including its expression-based API, lazy evaluation framework, and high-performance data manipulation capabilities. Use this skill when helping users write efficient data processing code, migrate from pandas, or optimize data pipelines. + +## Quick Start + +### Installation and Basic Usage + +Install Polars: +```python +pip install polars +``` + +Basic DataFrame creation and operations: +```python +import polars as pl + +# Create DataFrame +df = pl.DataFrame({ + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + "city": ["NY", "LA", "SF"] +}) + +# Select columns +df.select("name", "age") + +# Filter rows +df.filter(pl.col("age") > 25) + +# Add computed columns +df.with_columns( + age_plus_10=pl.col("age") + 10 +) +``` + +## Core Concepts + +### Expressions + +Expressions are the fundamental building blocks of Polars operations. They describe transformations on data and can be composed, reused, and optimized. 
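+
+Because an expression is just a Python object until it is evaluated, it can be assigned to a variable and reused across contexts. A minimal sketch with made-up data:
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"name": ["Alice", "Bob"], "age": [25, 35]})
+
+# Define the expression once, reuse it in several contexts
+is_senior = pl.col("age") >= 30
+
+df.filter(is_senior)                      # as a predicate
+df.select(is_senior.alias("is_senior"))   # as a computed boolean column
+```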
+ +**Key principles:** +- Use `pl.col("column_name")` to reference columns +- Chain methods to build complex transformations +- Expressions are lazy and only execute within contexts (select, with_columns, filter, group_by) + +**Example:** +```python +# Expression-based computation +df.select( + pl.col("name"), + (pl.col("age") * 12).alias("age_in_months") +) +``` + +### Lazy vs Eager Evaluation + +**Eager (DataFrame):** Operations execute immediately +```python +df = pl.read_csv("file.csv") # Reads immediately +result = df.filter(pl.col("age") > 25) # Executes immediately +``` + +**Lazy (LazyFrame):** Operations build a query plan, optimized before execution +```python +lf = pl.scan_csv("file.csv") # Doesn't read yet +result = lf.filter(pl.col("age") > 25).select("name", "age") +df = result.collect() # Now executes optimized query +``` + +**When to use lazy:** +- Working with large datasets +- Complex query pipelines +- When only some columns/rows are needed +- Performance is critical + +**Benefits of lazy evaluation:** +- Automatic query optimization +- Predicate pushdown +- Projection pushdown +- Parallel execution + +For detailed concepts, load `references/core_concepts.md`. + +## Common Operations + +### Select +Select and manipulate columns: +```python +# Select specific columns +df.select("name", "age") + +# Select with expressions +df.select( + pl.col("name"), + (pl.col("age") * 2).alias("double_age") +) + +# Select all columns matching a pattern +df.select(pl.col("^.*_id$")) +``` + +### Filter +Filter rows by conditions: +```python +# Single condition +df.filter(pl.col("age") > 25) + +# Multiple conditions (cleaner than using &) +df.filter( + pl.col("age") > 25, + pl.col("city") == "NY" +) + +# Complex conditions +df.filter( + (pl.col("age") > 25) | (pl.col("city") == "LA") +) +``` + +### With Columns +Add or modify columns while preserving existing ones: +```python +# Add new columns +df.with_columns( + age_plus_10=pl.col("age") + 10, + name_upper=pl.col("name").str.to_uppercase() +) + +# Parallel computation (all columns computed in parallel) +df.with_columns( + pl.col("value") * 10, + pl.col("value") * 100, +) +``` + +### Group By and Aggregations +Group data and compute aggregations: +```python +# Basic grouping +df.group_by("city").agg( + pl.col("age").mean().alias("avg_age"), + pl.len().alias("count") +) + +# Multiple group keys +df.group_by("city", "department").agg( + pl.col("salary").sum() +) + +# Conditional aggregations +df.group_by("city").agg( + (pl.col("age") > 30).sum().alias("over_30") +) +``` + +For detailed operation patterns, load `references/operations.md`. 
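+
+These pieces compose directly in lazy mode as well. A minimal sketch (the file name and column names here are illustrative only) that chains filter, select, and group_by on a LazyFrame and inspects the optimized plan before collecting:
+
+```python
+import polars as pl
+
+lf = (
+    pl.scan_csv("sales.csv")                       # hypothetical file
+    .filter(pl.col("amount") > 0)
+    .select("region", "amount")
+    .group_by("region")
+    .agg(pl.col("amount").sum().alias("total_amount"))
+)
+
+print(lf.explain())  # optimized plan, showing predicate/projection pushdown
+df = lf.collect()    # execute the plan
+```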
+ +## Aggregations and Window Functions + +### Aggregation Functions +Common aggregations within `group_by` context: +- `pl.len()` - count rows +- `pl.col("x").sum()` - sum values +- `pl.col("x").mean()` - average +- `pl.col("x").min()` / `pl.col("x").max()` - extremes +- `pl.first()` / `pl.last()` - first/last values + +### Window Functions with `over()` +Apply aggregations while preserving row count: +```python +# Add group statistics to each row +df.with_columns( + avg_age_by_city=pl.col("age").mean().over("city"), + rank_in_city=pl.col("salary").rank().over("city") +) + +# Multiple grouping columns +df.with_columns( + group_avg=pl.col("value").mean().over("category", "region") +) +``` + +**Mapping strategies:** +- `group_to_rows` (default): Preserves original row order +- `explode`: Faster but groups rows together +- `join`: Creates list columns + +## Data I/O + +### Supported Formats +Polars supports reading and writing: +- CSV, Parquet, JSON, Excel +- Databases (via connectors) +- Cloud storage (S3, Azure, GCS) +- Google BigQuery +- Multiple/partitioned files + +### Common I/O Operations + +**CSV:** +```python +# Eager +df = pl.read_csv("file.csv") +df.write_csv("output.csv") + +# Lazy (preferred for large files) +lf = pl.scan_csv("file.csv") +result = lf.filter(...).select(...).collect() +``` + +**Parquet (recommended for performance):** +```python +df = pl.read_parquet("file.parquet") +df.write_parquet("output.parquet") +``` + +**JSON:** +```python +df = pl.read_json("file.json") +df.write_json("output.json") +``` + +For comprehensive I/O documentation, load `references/io_guide.md`. + +## Transformations + +### Joins +Combine DataFrames: +```python +# Inner join +df1.join(df2, on="id", how="inner") + +# Left join +df1.join(df2, on="id", how="left") + +# Join on different column names +df1.join(df2, left_on="user_id", right_on="id") +``` + +### Concatenation +Stack DataFrames: +```python +# Vertical (stack rows) +pl.concat([df1, df2], how="vertical") + +# Horizontal (add columns) +pl.concat([df1, df2], how="horizontal") + +# Diagonal (union with different schemas) +pl.concat([df1, df2], how="diagonal") +``` + +### Pivot and Unpivot +Reshape data: +```python +# Pivot (wide format) +df.pivot(values="sales", index="date", columns="product") + +# Unpivot (long format) +df.unpivot(index="id", on=["col1", "col2"]) +``` + +For detailed transformation examples, load `references/transformations.md`. + +## Pandas Migration + +Polars offers significant performance improvements over pandas with a cleaner API. 
Key differences: + +### Conceptual Differences +- **No index**: Polars uses integer positions only +- **Strict typing**: No silent type conversions +- **Lazy evaluation**: Available via LazyFrame +- **Parallel by default**: Operations parallelized automatically + +### Common Operation Mappings + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Select column | `df["col"]` | `df.select("col")` | +| Filter | `df[df["col"] > 10]` | `df.filter(pl.col("col") > 10)` | +| Add column | `df.assign(x=...)` | `df.with_columns(x=...)` | +| Group by | `df.groupby("col").agg(...)` | `df.group_by("col").agg(...)` | +| Window | `df.groupby("col").transform(...)` | `df.with_columns(...).over("col")` | + +### Key Syntax Patterns + +**Pandas sequential (slow):** +```python +df.assign( + col_a=lambda df_: df_.value * 10, + col_b=lambda df_: df_.value * 100 +) +``` + +**Polars parallel (fast):** +```python +df.with_columns( + col_a=pl.col("value") * 10, + col_b=pl.col("value") * 100, +) +``` + +For comprehensive migration guide, load `references/pandas_migration.md`. + +## Best Practices + +### Performance Optimization + +1. **Use lazy evaluation for large datasets:** + ```python + lf = pl.scan_csv("large.csv") # Don't use read_csv + result = lf.filter(...).select(...).collect() + ``` + +2. **Avoid Python functions in hot paths:** + - Stay within expression API for parallelization + - Use `.map_elements()` only when necessary + - Prefer native Polars operations + +3. **Use streaming for very large data:** + ```python + lf.collect(streaming=True) + ``` + +4. **Select only needed columns early:** + ```python + # Good: Select columns early + lf.select("col1", "col2").filter(...) + + # Bad: Filter on all columns first + lf.filter(...).select("col1", "col2") + ``` + +5. **Use appropriate data types:** + - Categorical for low-cardinality strings + - Appropriate integer sizes (i32 vs i64) + - Date types for temporal data + +### Expression Patterns + +**Conditional operations:** +```python +pl.when(condition).then(value).otherwise(other_value) +``` + +**Column operations across multiple columns:** +```python +df.select(pl.col("^.*_value$") * 2) # Regex pattern +``` + +**Null handling:** +```python +pl.col("x").fill_null(0) +pl.col("x").is_null() +pl.col("x").drop_nulls() +``` + +For additional best practices and patterns, load `references/best_practices.md`. + +## Resources + +This skill includes comprehensive reference documentation: + +### references/ +- `core_concepts.md` - Detailed explanations of expressions, lazy evaluation, and type system +- `operations.md` - Comprehensive guide to all common operations with examples +- `pandas_migration.md` - Complete migration guide from pandas to Polars +- `io_guide.md` - Data I/O operations for all supported formats +- `transformations.md` - Joins, concatenation, pivots, and reshaping operations +- `best_practices.md` - Performance optimization tips and common patterns + +Load these references as needed when users require detailed information about specific topics. diff --git a/scientific-packages/polars/references/best_practices.md b/scientific-packages/polars/references/best_practices.md new file mode 100644 index 0000000..0585ed6 --- /dev/null +++ b/scientific-packages/polars/references/best_practices.md @@ -0,0 +1,649 @@ +# Polars Best Practices and Performance Guide + +Comprehensive guide to writing efficient Polars code and avoiding common pitfalls. + +## Performance Optimization + +### 1. 
Use Lazy Evaluation + +**Always prefer lazy mode for large datasets:** + +```python +# Bad: Eager mode loads everything immediately +df = pl.read_csv("large_file.csv") +result = df.filter(pl.col("age") > 25).select("name", "age") + +# Good: Lazy mode optimizes before execution +lf = pl.scan_csv("large_file.csv") +result = lf.filter(pl.col("age") > 25).select("name", "age").collect() +``` + +**Benefits of lazy evaluation:** +- Predicate pushdown (filter at source) +- Projection pushdown (read only needed columns) +- Query optimization +- Parallel execution planning + +### 2. Filter and Select Early + +Push filters and column selection as early as possible in the pipeline: + +```python +# Bad: Process all data, then filter and select +result = ( + lf.group_by("category") + .agg(pl.col("value").mean()) + .join(other, on="category") + .filter(pl.col("value") > 100) + .select("category", "value") +) + +# Good: Filter and select early +result = ( + lf.select("category", "value") # Only needed columns + .filter(pl.col("value") > 100) # Filter early + .group_by("category") + .agg(pl.col("value").mean()) + .join(other.select("category", "other_col"), on="category") +) +``` + +### 3. Avoid Python Functions + +Stay within the expression API to maintain parallelization: + +```python +# Bad: Python function disables parallelization +df = df.with_columns( + result=pl.col("value").map_elements(lambda x: x * 2, return_dtype=pl.Float64) +) + +# Good: Use native expressions (parallelized) +df = df.with_columns(result=pl.col("value") * 2) +``` + +**When you must use custom functions:** +```python +# If truly needed, be explicit +df = df.with_columns( + result=pl.col("value").map_elements( + custom_function, + return_dtype=pl.Float64, + skip_nulls=True # Optimize null handling + ) +) +``` + +### 4. Use Streaming for Very Large Data + +Enable streaming for datasets larger than RAM: + +```python +# Streaming mode processes data in chunks +lf = pl.scan_parquet("very_large.parquet") +result = lf.filter(pl.col("value") > 100).collect(streaming=True) + +# Or use sink for direct streaming writes +lf.filter(pl.col("value") > 100).sink_parquet("output.parquet") +``` + +### 5. Optimize Data Types + +Choose appropriate data types to reduce memory and improve performance: + +```python +# Bad: Default types may be wasteful +df = pl.read_csv("data.csv") + +# Good: Specify optimal types +df = pl.read_csv( + "data.csv", + dtypes={ + "id": pl.UInt32, # Instead of Int64 if values fit + "category": pl.Categorical, # For low-cardinality strings + "date": pl.Date, # Instead of String + "small_int": pl.Int16, # Instead of Int64 + } +) +``` + +**Type optimization guidelines:** +- Use smallest integer type that fits your data +- Use `Categorical` for strings with low cardinality (<50% unique) +- Use `Date` instead of `Datetime` when time isn't needed +- Use `Boolean` instead of integers for binary flags + +### 6. Parallel Operations + +Structure code to maximize parallelization: + +```python +# Bad: Sequential pipe operations disable parallelization +df = ( + df.pipe(operation1) + .pipe(operation2) + .pipe(operation3) +) + +# Good: Combined operations enable parallelization +df = df.with_columns( + result1=operation1_expr(), + result2=operation2_expr(), + result3=operation3_expr() +) +``` + +### 7. 
Rechunk After Concatenation
+
+```python
+# Concatenation can fragment data
+combined = pl.concat([df1, df2, df3])
+
+# Rechunk for better performance in subsequent operations
+combined = pl.concat([df1, df2, df3], rechunk=True)
+```
+
+## Expression Patterns
+
+### Conditional Logic
+
+**Simple conditions:**
+```python
+df.with_columns(
+    status=pl.when(pl.col("age") >= 18)
+    .then("adult")
+    .otherwise("minor")
+)
+```
+
+**Multiple conditions:**
+```python
+df.with_columns(
+    grade=pl.when(pl.col("score") >= 90)
+    .then("A")
+    .when(pl.col("score") >= 80)
+    .then("B")
+    .when(pl.col("score") >= 70)
+    .then("C")
+    .when(pl.col("score") >= 60)
+    .then("D")
+    .otherwise("F")
+)
+```
+
+**Complex conditions:**
+```python
+df.with_columns(
+    category=pl.when(
+        (pl.col("revenue") > 1000000) & (pl.col("customers") > 100)
+    )
+    .then("enterprise")
+    .when(
+        (pl.col("revenue") > 100000) | (pl.col("customers") > 50)
+    )
+    .then("business")
+    .otherwise("starter")
+)
+```
+
+### Null Handling
+
+**Check for nulls:**
+```python
+df.filter(pl.col("value").is_null())
+df.filter(pl.col("value").is_not_null())
+```
+
+**Fill nulls:**
+```python
+# Constant value
+df.with_columns(pl.col("value").fill_null(0))
+
+# Forward fill
+df.with_columns(pl.col("value").fill_null(strategy="forward"))
+
+# Backward fill
+df.with_columns(pl.col("value").fill_null(strategy="backward"))
+
+# Mean
+df.with_columns(pl.col("value").fill_null(strategy="mean"))
+
+# Per-group fill
+df.with_columns(
+    pl.col("value").fill_null(pl.col("value").mean()).over("group")
+)
+```
+
+**Coalesce (first non-null):**
+```python
+df.with_columns(
+    combined=pl.coalesce(["col1", "col2", "col3"])
+)
+```
+
+### Column Selection Patterns
+
+**By name:**
+```python
+df.select("col1", "col2", "col3")
+```
+
+**By pattern (regex patterns must start with `^` and end with `$`):**
+```python
+# Regex
+df.select(pl.col("^sales_.*$"))
+
+# Starts with
+df.select(pl.col("^sales.*$"))
+
+# Ends with
+df.select(pl.col("^.*_total$"))
+
+# Contains
+df.select(pl.col("^.*revenue.*$"))
+```
+
+**By type:**
+```python
+# All numeric columns
+df.select(pl.col(pl.NUMERIC_DTYPES))
+
+# All string columns
+df.select(pl.col(pl.Utf8))
+
+# Multiple types
+df.select(pl.col(pl.NUMERIC_DTYPES, pl.Boolean))
+```
+
+**Exclude columns:**
+```python
+df.select(pl.all().exclude("id", "timestamp"))
+```
+
+**Transform multiple columns:**
+```python
+# Apply same operation to multiple columns
+df.select(
+    pl.col("^sales_.*$") * 1.1  # 10% increase to all sales columns
+)
+```
+
+### Aggregation Patterns
+
+**Multiple aggregations:**
+```python
+df.group_by("category").agg(
+    pl.col("value").sum().alias("total"),
+    pl.col("value").mean().alias("average"),
+    pl.col("value").std().alias("std_dev"),
+    pl.col("id").count().alias("count"),
+    pl.col("id").n_unique().alias("unique_count"),
+    pl.col("value").min().alias("minimum"),
+    pl.col("value").max().alias("maximum"),
+    pl.col("value").quantile(0.5).alias("median"),
+    pl.col("value").quantile(0.95).alias("p95")
+)
+```
+
+**Conditional aggregations:**
+```python
+df.group_by("category").agg(
+    # Count high values
+    (pl.col("value") > 100).sum().alias("high_count"),
+
+    # Average of filtered values
+    pl.col("value").filter(pl.col("active")).mean().alias("active_avg"),
+
+    # Conditional sum
+    pl.when(pl.col("status") == "completed")
+    .then(pl.col("amount"))
+    .otherwise(0)
+    .sum()
+    .alias("completed_total")
+)
+```
+
+**Grouped transformations:**
+```python
+df.with_columns(
+    # Group statistics
+    group_mean=pl.col("value").mean().over("category"),
+    
group_std=pl.col("value").std().over("category"), + + # Rank within groups + rank=pl.col("value").rank().over("category"), + + # Percentage of group total + pct_of_group=(pl.col("value") / pl.col("value").sum().over("category")) * 100 +) +``` + +## Common Pitfalls and Anti-Patterns + +### Pitfall 1: Row Iteration + +```python +# Bad: Never iterate rows +for row in df.iter_rows(): + # Process row + result = row[0] * 2 + +# Good: Use vectorized operations +df = df.with_columns(result=pl.col("value") * 2) +``` + +### Pitfall 2: Modifying in Place + +```python +# Bad: Polars is immutable, this doesn't work as expected +df["new_col"] = df["old_col"] * 2 # May work but not recommended + +# Good: Functional style +df = df.with_columns(new_col=pl.col("old_col") * 2) +``` + +### Pitfall 3: Not Using Expressions + +```python +# Bad: String-based operations +df.select("value * 2") # Won't work + +# Good: Expression-based +df.select(pl.col("value") * 2) +``` + +### Pitfall 4: Inefficient Joins + +```python +# Bad: Join large tables without filtering +result = large_df1.join(large_df2, on="id") + +# Good: Filter before joining +result = ( + large_df1.filter(pl.col("active")) + .join( + large_df2.filter(pl.col("status") == "valid"), + on="id" + ) +) +``` + +### Pitfall 5: Not Specifying Types + +```python +# Bad: Let Polars infer everything +df = pl.read_csv("data.csv") + +# Good: Specify types for correctness and performance +df = pl.read_csv( + "data.csv", + dtypes={"id": pl.Int64, "date": pl.Date, "category": pl.Categorical} +) +``` + +### Pitfall 6: Creating Many Small DataFrames + +```python +# Bad: Many operations creating intermediate DataFrames +df1 = df.filter(pl.col("age") > 25) +df2 = df1.select("name", "age") +df3 = df2.sort("age") +result = df3.head(10) + +# Good: Chain operations +result = ( + df.filter(pl.col("age") > 25) + .select("name", "age") + .sort("age") + .head(10) +) + +# Better: Use lazy mode +result = ( + df.lazy() + .filter(pl.col("age") > 25) + .select("name", "age") + .sort("age") + .head(10) + .collect() +) +``` + +## Memory Management + +### Monitor Memory Usage + +```python +# Check DataFrame size +print(f"Estimated size: {df.estimated_size('mb'):.2f} MB") + +# Profile memory during operations +lf = pl.scan_csv("large.csv") +print(lf.explain()) # See query plan +``` + +### Reduce Memory Footprint + +```python +# 1. Use lazy mode +lf = pl.scan_parquet("data.parquet") + +# 2. Stream results +result = lf.collect(streaming=True) + +# 3. Select only needed columns +lf = lf.select("col1", "col2") + +# 4. Optimize data types +df = df.with_columns( + pl.col("int_col").cast(pl.Int32), # Downcast if possible + pl.col("category").cast(pl.Categorical) # For low cardinality +) + +# 5. 
Drop columns not needed +df = df.drop("large_text_col", "unused_col") +``` + +## Testing and Debugging + +### Inspect Query Plans + +```python +lf = pl.scan_csv("data.csv") +query = lf.filter(pl.col("age") > 25).select("name", "age") + +# View the optimized query plan +print(query.explain()) + +# View detailed query plan +print(query.explain(optimized=True)) +``` + +### Sample Data for Development + +```python +# Use n_rows for testing +df = pl.read_csv("large.csv", n_rows=1000) + +# Or sample after reading +df_sample = df.sample(n=1000, seed=42) +``` + +### Validate Schemas + +```python +# Check schema +print(df.schema) + +# Ensure schema matches expectation +expected_schema = { + "id": pl.Int64, + "name": pl.Utf8, + "date": pl.Date +} + +assert df.schema == expected_schema +``` + +### Profile Performance + +```python +import time + +# Time operations +start = time.time() +result = lf.collect() +print(f"Execution time: {time.time() - start:.2f}s") + +# Compare eager vs lazy +start = time.time() +df_eager = pl.read_csv("data.csv").filter(pl.col("age") > 25) +eager_time = time.time() - start + +start = time.time() +df_lazy = pl.scan_csv("data.csv").filter(pl.col("age") > 25).collect() +lazy_time = time.time() - start + +print(f"Eager: {eager_time:.2f}s, Lazy: {lazy_time:.2f}s") +``` + +## File Format Best Practices + +### Choose the Right Format + +**Parquet:** +- Best for: Large datasets, archival, data lakes +- Pros: Excellent compression, columnar, fast reads +- Cons: Not human-readable + +**CSV:** +- Best for: Small datasets, human inspection, legacy systems +- Pros: Universal, human-readable +- Cons: Slow, large file size, no type preservation + +**Arrow IPC:** +- Best for: Inter-process communication, temporary storage +- Pros: Fastest, zero-copy, preserves all types +- Cons: Less compression than Parquet + +### File Reading Best Practices + +```python +# 1. Use lazy reading +lf = pl.scan_parquet("data.parquet") # Not read_parquet + +# 2. Read multiple files efficiently +lf = pl.scan_parquet("data/*.parquet") # Parallel reading + +# 3. Specify schema when known +lf = pl.scan_csv( + "data.csv", + dtypes={"id": pl.Int64, "date": pl.Date} +) + +# 4. Use predicate pushdown +result = lf.filter(pl.col("date") >= "2023-01-01").collect() +``` + +### File Writing Best Practices + +```python +# 1. Use Parquet for large data +df.write_parquet("output.parquet", compression="zstd") + +# 2. Partition large datasets +df.write_parquet("output", partition_by=["year", "month"]) + +# 3. Use streaming for very large writes +lf.sink_parquet("output.parquet") # Streaming write + +# 4. 
Optimize compression +df.write_parquet( + "output.parquet", + compression="snappy", # Fast compression + statistics=True # Enable predicate pushdown on read +) +``` + +## Code Organization + +### Reusable Expressions + +```python +# Define reusable expressions +age_group = ( + pl.when(pl.col("age") < 18) + .then("minor") + .when(pl.col("age") < 65) + .then("adult") + .otherwise("senior") +) + +revenue_per_customer = pl.col("revenue") / pl.col("customer_count") + +# Use in multiple contexts +df = df.with_columns( + age_group=age_group, + rpc=revenue_per_customer +) + +# Reuse in filtering +df = df.filter(revenue_per_customer > 100) +``` + +### Pipeline Functions + +```python +def clean_data(lf: pl.LazyFrame) -> pl.LazyFrame: + """Clean and standardize data.""" + return lf.with_columns( + pl.col("name").str.to_uppercase(), + pl.col("date").str.strptime(pl.Date, "%Y-%m-%d"), + pl.col("amount").fill_null(0) + ) + +def add_features(lf: pl.LazyFrame) -> pl.LazyFrame: + """Add computed features.""" + return lf.with_columns( + month=pl.col("date").dt.month(), + year=pl.col("date").dt.year(), + amount_log=pl.col("amount").log() + ) + +# Compose pipeline +result = ( + pl.scan_csv("data.csv") + .pipe(clean_data) + .pipe(add_features) + .filter(pl.col("year") == 2023) + .collect() +) +``` + +## Documentation + +Always document complex expressions and transformations: + +```python +# Good: Document intent +df = df.with_columns( + # Calculate customer lifetime value as sum of purchases + # divided by months since first purchase + clv=( + pl.col("total_purchases") / + ((pl.col("last_purchase_date") - pl.col("first_purchase_date")) + .dt.total_days() / 30) + ) +) +``` + +## Version Compatibility + +```python +# Check Polars version +import polars as pl +print(pl.__version__) + +# Feature availability varies by version +# Document version requirements for production code +``` diff --git a/scientific-packages/polars/references/core_concepts.md b/scientific-packages/polars/references/core_concepts.md new file mode 100644 index 0000000..e3a0e56 --- /dev/null +++ b/scientific-packages/polars/references/core_concepts.md @@ -0,0 +1,378 @@ +# Polars Core Concepts + +## Expressions + +Expressions are the foundation of Polars' API. They are composable units that describe data transformations without executing them immediately. + +### What are Expressions? + +An expression describes a transformation on data. 
It only materializes (executes) within specific contexts: +- `select()` - Select and transform columns +- `with_columns()` - Add or modify columns +- `filter()` - Filter rows +- `group_by().agg()` - Aggregate data + +### Expression Syntax + +**Basic column reference:** +```python +pl.col("column_name") +``` + +**Computed expressions:** +```python +# Arithmetic +pl.col("height") * 2 +pl.col("price") + pl.col("tax") + +# With alias +(pl.col("weight") / (pl.col("height") ** 2)).alias("bmi") + +# Method chaining +pl.col("name").str.to_uppercase().str.slice(0, 3) +``` + +### Expression Contexts + +**Select context:** +```python +df.select( + "name", # Simple column name + pl.col("age"), # Expression + (pl.col("age") * 12).alias("age_in_months") # Computed expression +) +``` + +**With_columns context:** +```python +df.with_columns( + age_doubled=pl.col("age") * 2, + name_upper=pl.col("name").str.to_uppercase() +) +``` + +**Filter context:** +```python +df.filter( + pl.col("age") > 25, + pl.col("city").is_in(["NY", "LA", "SF"]) +) +``` + +**Group_by context:** +```python +df.group_by("department").agg( + pl.col("salary").mean(), + pl.col("employee_id").count() +) +``` + +### Expression Expansion + +Apply operations to multiple columns at once: + +**All columns:** +```python +df.select(pl.all() * 2) +``` + +**Pattern matching:** +```python +# All columns ending with "_value" +df.select(pl.col("^.*_value$") * 100) + +# All numeric columns +df.select(pl.col(pl.NUMERIC_DTYPES) + 1) +``` + +**Exclude patterns:** +```python +df.select(pl.all().exclude("id", "name")) +``` + +### Expression Composition + +Expressions can be stored and reused: + +```python +# Define reusable expressions +age_expression = pl.col("age") * 12 +name_expression = pl.col("name").str.to_uppercase() + +# Use in multiple contexts +df.select(age_expression, name_expression) +df.with_columns(age_months=age_expression) +``` + +## Data Types + +Polars has a strict type system based on Apache Arrow. 
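+
+As a small illustration (the values here are only an example): the schema is fixed when the frame is built, and nulls do not silently change a column's dtype:
+
+```python
+import polars as pl
+
+df = pl.DataFrame(
+    {"id": [1, 2, None], "price": [9.99, 19.5, 5.0]},
+    schema={"id": pl.Int32, "price": pl.Float64},
+)
+print(df.schema)  # id: Int32, price: Float64 -- the null does not force id to Float64
+```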
+ +### Core Data Types + +**Numeric:** +- `Int8`, `Int16`, `Int32`, `Int64` - Signed integers +- `UInt8`, `UInt16`, `UInt32`, `UInt64` - Unsigned integers +- `Float32`, `Float64` - Floating point numbers + +**Text:** +- `Utf8` / `String` - UTF-8 encoded strings +- `Categorical` - Categorized strings (low cardinality) +- `Enum` - Fixed set of string values + +**Temporal:** +- `Date` - Calendar date (no time) +- `Datetime` - Date and time with optional timezone +- `Time` - Time of day +- `Duration` - Time duration/difference + +**Boolean:** +- `Boolean` - True/False values + +**Nested:** +- `List` - Variable-length lists +- `Array` - Fixed-length arrays +- `Struct` - Nested record structures + +**Other:** +- `Binary` - Binary data +- `Object` - Python objects (avoid in production) +- `Null` - Null type + +### Type Casting + +Convert between types explicitly: + +```python +# Cast to different type +df.select( + pl.col("age").cast(pl.Float64), + pl.col("date_string").str.strptime(pl.Date, "%Y-%m-%d"), + pl.col("id").cast(pl.Utf8) +) +``` + +### Null Handling + +Polars uses consistent null handling across all types: + +**Check for nulls:** +```python +df.filter(pl.col("value").is_null()) +df.filter(pl.col("value").is_not_null()) +``` + +**Fill nulls:** +```python +pl.col("value").fill_null(0) +pl.col("value").fill_null(strategy="forward") +pl.col("value").fill_null(strategy="backward") +pl.col("value").fill_null(strategy="mean") +``` + +**Drop nulls:** +```python +df.drop_nulls() # Drop any row with nulls +df.drop_nulls(subset=["col1", "col2"]) # Drop rows with nulls in specific columns +``` + +### Categorical Data + +Use categorical types for string columns with low cardinality (repeated values): + +```python +# Cast to categorical +df.with_columns( + pl.col("category").cast(pl.Categorical) +) + +# Benefits: +# - Reduced memory usage +# - Faster grouping and joining +# - Maintains order information +``` + +## Lazy vs Eager Evaluation + +Polars supports two execution modes: eager (DataFrame) and lazy (LazyFrame). 
+ +### Eager Evaluation (DataFrame) + +Operations execute immediately: + +```python +import polars as pl + +# DataFrame operations execute right away +df = pl.read_csv("data.csv") # Reads file immediately +result = df.filter(pl.col("age") > 25) # Filters immediately +final = result.select("name", "age") # Selects immediately +``` + +**When to use eager:** +- Small datasets that fit in memory +- Interactive exploration in notebooks +- Simple one-off operations +- Immediate feedback needed + +### Lazy Evaluation (LazyFrame) + +Operations build a query plan, optimized before execution: + +```python +import polars as pl + +# LazyFrame operations build a query plan +lf = pl.scan_csv("data.csv") # Doesn't read yet +lf2 = lf.filter(pl.col("age") > 25) # Adds to plan +lf3 = lf2.select("name", "age") # Adds to plan +df = lf3.collect() # NOW executes optimized plan +``` + +**When to use lazy:** +- Large datasets +- Complex query pipelines +- Only need subset of data +- Performance is critical +- Streaming required + +### Query Optimization + +Polars automatically optimizes lazy queries: + +**Predicate Pushdown:** +Filter operations pushed to data source when possible: +```python +# Only reads rows where age > 25 from CSV +lf = pl.scan_csv("data.csv") +result = lf.filter(pl.col("age") > 25).collect() +``` + +**Projection Pushdown:** +Only read needed columns from data source: +```python +# Only reads "name" and "age" columns from CSV +lf = pl.scan_csv("data.csv") +result = lf.select("name", "age").collect() +``` + +**Query Plan Inspection:** +```python +# View the optimized query plan +lf = pl.scan_csv("data.csv") +result = lf.filter(pl.col("age") > 25).select("name", "age") +print(result.explain()) # Shows optimized plan +``` + +### Streaming Mode + +Process data larger than memory: + +```python +# Enable streaming for very large datasets +lf = pl.scan_csv("very_large.csv") +result = lf.filter(pl.col("age") > 25).collect(streaming=True) +``` + +**Streaming benefits:** +- Process data larger than RAM +- Lower peak memory usage +- Chunk-based processing +- Automatic memory management + +**Streaming limitations:** +- Not all operations support streaming +- May be slower for small data +- Some operations require materializing entire dataset + +### Converting Between Eager and Lazy + +**Eager to Lazy:** +```python +df = pl.read_csv("data.csv") +lf = df.lazy() # Convert to LazyFrame +``` + +**Lazy to Eager:** +```python +lf = pl.scan_csv("data.csv") +df = lf.collect() # Execute and return DataFrame +``` + +## Memory Format + +Polars uses Apache Arrow columnar memory format: + +**Benefits:** +- Zero-copy data sharing with other Arrow libraries +- Efficient columnar operations +- SIMD vectorization +- Reduced memory overhead +- Fast serialization + +**Implications:** +- Data stored column-wise, not row-wise +- Column operations very fast +- Random row access slower than pandas +- Best for analytical workloads + +## Parallelization + +Polars parallelizes operations automatically using Rust's concurrency: + +**What gets parallelized:** +- Aggregations within groups +- Window functions +- Most expression evaluations +- File reading (multiple files) +- Join operations + +**What to avoid for parallelization:** +- Python user-defined functions (UDFs) +- Lambda functions in `.map_elements()` +- Sequential `.pipe()` chains + +**Best practice:** +```python +# Good: Stays in expression API (parallelized) +df.with_columns( + pl.col("value") * 10, + pl.col("value").log(), + pl.col("value").sqrt() +) + +# Bad: Uses 
Python function (sequential) +df.with_columns( + pl.col("value").map_elements(lambda x: x * 10) +) +``` + +## Strict Type System + +Polars enforces strict typing: + +**No silent conversions:** +```python +# This will error - can't mix types +# df.with_columns(pl.col("int_col") + "string") + +# Must cast explicitly +df.with_columns( + pl.col("int_col").cast(pl.Utf8) + "_suffix" +) +``` + +**Benefits:** +- Prevents silent bugs +- Predictable behavior +- Better performance +- Clearer code intent + +**Integer nulls:** +Unlike pandas, integer columns can have nulls without converting to float: +```python +# In pandas: Int column with null becomes Float +# In polars: Int column with null stays Int (with null values) +df = pl.DataFrame({"int_col": [1, 2, None, 4]}) +# dtype: Int64 (not Float64) +``` diff --git a/scientific-packages/polars/references/io_guide.md b/scientific-packages/polars/references/io_guide.md new file mode 100644 index 0000000..bbb9dc9 --- /dev/null +++ b/scientific-packages/polars/references/io_guide.md @@ -0,0 +1,557 @@ +# Polars Data I/O Guide + +Comprehensive guide to reading and writing data in various formats with Polars. + +## CSV Files + +### Reading CSV + +**Eager mode (loads into memory):** +```python +import polars as pl + +# Basic read +df = pl.read_csv("data.csv") + +# With options +df = pl.read_csv( + "data.csv", + separator=",", + has_header=True, + columns=["col1", "col2"], # Select specific columns + n_rows=1000, # Read only first 1000 rows + skip_rows=10, # Skip first 10 rows + dtypes={"col1": pl.Int64, "col2": pl.Utf8}, # Specify types + null_values=["NA", "null", ""], # Define null values + encoding="utf-8", + ignore_errors=False +) +``` + +**Lazy mode (scans without loading - recommended for large files):** +```python +# Scan CSV (builds query plan) +lf = pl.scan_csv("data.csv") + +# Apply operations +result = lf.filter(pl.col("age") > 25).select("name", "age") + +# Execute and load +df = result.collect() +``` + +### Writing CSV + +```python +# Basic write +df.write_csv("output.csv") + +# With options +df.write_csv( + "output.csv", + separator=",", + include_header=True, + null_value="", # How to represent nulls + quote_char='"', + line_terminator="\n" +) +``` + +### Multiple CSV Files + +**Read multiple files:** +```python +# Read all CSVs in directory +lf = pl.scan_csv("data/*.csv") + +# Read specific files +lf = pl.scan_csv(["file1.csv", "file2.csv", "file3.csv"]) +``` + +## Parquet Files + +Parquet is the recommended format for performance and compression. 
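+
+As a quick illustration of why (file names and data are examples only): a Parquet round trip preserves the schema, while CSV stores everything as text and needs re-parsing:
+
+```python
+import polars as pl
+from datetime import date
+
+df = pl.DataFrame({"day": [date(2024, 1, 1), date(2024, 1, 2)], "value": [1.5, 2.5]})
+
+df.write_parquet("example.parquet")
+print(pl.read_parquet("example.parquet").schema)  # day: Date, value: Float64
+
+df.write_csv("example.csv")
+print(pl.read_csv("example.csv").schema)  # day comes back as a string unless dates are parsed
+```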
+
+### Reading Parquet
+
+**Eager:**
+```python
+df = pl.read_parquet("data.parquet")
+
+# With options
+df = pl.read_parquet(
+    "data.parquet",
+    columns=["col1", "col2"],  # Select specific columns
+    n_rows=1000,               # Read first N rows
+    parallel="auto"            # Control parallelization
+)
+```
+
+**Lazy (recommended):**
+```python
+lf = pl.scan_parquet("data.parquet")
+
+# Automatic predicate and projection pushdown
+result = lf.filter(pl.col("age") > 25).select("name", "age").collect()
+```
+
+### Writing Parquet
+
+```python
+# Basic write
+df.write_parquet("output.parquet")
+
+# With compression
+df.write_parquet(
+    "output.parquet",
+    compression="snappy",  # Options: "snappy", "gzip", "brotli", "lz4", "zstd"
+    statistics=True,       # Write statistics (enables predicate pushdown)
+    use_pyarrow=False      # Use Rust writer (faster)
+)
+```
+
+### Partitioned Parquet (Hive-style)
+
+**Write partitioned:**
+```python
+# Write with partitioning
+df.write_parquet(
+    "output_dir",
+    partition_by=["year", "month"]  # Creates directory structure
+)
+# Creates: output_dir/year=2023/month=01/data.parquet
+```
+
+**Read partitioned:**
+```python
+lf = pl.scan_parquet("output_dir/**/*.parquet")
+
+# Hive partitioning columns are automatically added
+result = lf.filter(pl.col("year") == 2023).collect()
+```
+
+## JSON Files
+
+### Reading JSON
+
+**NDJSON (newline-delimited JSON) - recommended:**
+```python
+df = pl.read_ndjson("data.ndjson")
+
+# Lazy
+lf = pl.scan_ndjson("data.ndjson")
+```
+
+**Standard JSON:**
+```python
+df = pl.read_json("data.json")
+
+# From a JSON string (wrap it in a file-like object)
+import io
+df = pl.read_json(io.StringIO('{"col1": [1, 2], "col2": ["a", "b"]}'))
+```
+
+### Writing JSON
+
+```python
+# Write NDJSON
+df.write_ndjson("output.ndjson")
+
+# Write standard JSON
+df.write_json("output.json")
+
+# Pretty printed
+df.write_json("output.json", pretty=True, row_oriented=False)
+```
+
+## Excel Files
+
+### Reading Excel
+
+```python
+# Read first sheet
+df = pl.read_excel("data.xlsx")
+
+# Specific sheet
+df = pl.read_excel("data.xlsx", sheet_name="Sheet1")
+# Or by index (sheet_id is 1-based; sheet_id=0 loads all sheets)
+df = pl.read_excel("data.xlsx", sheet_id=1)
+
+# With options
+df = pl.read_excel(
+    "data.xlsx",
+    sheet_name="Sheet1",
+    columns=["A", "B", "C"],  # Excel columns
+    n_rows=100,
+    skip_rows=5,
+    has_header=True
+)
+```
+
+### Writing Excel

+```python
+# Write to Excel
+df.write_excel("output.xlsx")
+
+# Multiple sheets via an xlsxwriter Workbook
+import xlsxwriter
+
+with xlsxwriter.Workbook("output.xlsx") as workbook:
+    df1.write_excel(workbook, worksheet="Sheet1")
+    df2.write_excel(workbook, worksheet="Sheet2")
+```
+
+## Database Connectivity
+
+### Read from Database
+
+```python
+import polars as pl
+
+# Read entire table
+df = pl.read_database("SELECT * FROM users", connection_uri="postgresql://...")
+
+# Using connectorx for better performance
+df = pl.read_database_uri(
+    "SELECT * FROM users WHERE age > 25",
+    uri="postgresql://user:pass@localhost/db"
+)
+```
+
+### Write to Database
+
+```python
+# Using SQLAlchemy
+from sqlalchemy import create_engine
+
+engine = create_engine("postgresql://user:pass@localhost/db")
+df.write_database("table_name", connection=engine)
+
+# With options
+df.write_database(
+    "table_name",
+    connection=engine,
+    if_exists="replace",  # or "append", "fail"
+)
+```
+
+### Common Database Connectors
+
+**PostgreSQL:**
+```python
+uri = "postgresql://username:password@localhost:5432/database"
+df = pl.read_database_uri("SELECT * FROM table", uri=uri)
+```
+
+**MySQL:**
+```python
+uri = "mysql://username:password@localhost:3306/database"
+df = pl.read_database_uri("SELECT * FROM table", 
uri=uri) +``` + +**SQLite:** +```python +uri = "sqlite:///path/to/database.db" +df = pl.read_database_uri("SELECT * FROM table", uri=uri) +``` + +## Cloud Storage + +### AWS S3 + +```python +# Read from S3 +df = pl.read_parquet("s3://bucket/path/to/file.parquet") +lf = pl.scan_parquet("s3://bucket/path/*.parquet") + +# Write to S3 +df.write_parquet("s3://bucket/path/output.parquet") + +# With credentials +import os +os.environ["AWS_ACCESS_KEY_ID"] = "your_key" +os.environ["AWS_SECRET_ACCESS_KEY"] = "your_secret" +os.environ["AWS_REGION"] = "us-west-2" + +df = pl.read_parquet("s3://bucket/file.parquet") +``` + +### Azure Blob Storage + +```python +# Read from Azure +df = pl.read_parquet("az://container/path/file.parquet") + +# Write to Azure +df.write_parquet("az://container/path/output.parquet") + +# With credentials +os.environ["AZURE_STORAGE_ACCOUNT_NAME"] = "account" +os.environ["AZURE_STORAGE_ACCOUNT_KEY"] = "key" +``` + +### Google Cloud Storage (GCS) + +```python +# Read from GCS +df = pl.read_parquet("gs://bucket/path/file.parquet") + +# Write to GCS +df.write_parquet("gs://bucket/path/output.parquet") + +# With credentials +os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/credentials.json" +``` + +## Google BigQuery + +```python +# Read from BigQuery +df = pl.read_database( + "SELECT * FROM project.dataset.table", + connection_uri="bigquery://project" +) + +# Or using Google Cloud SDK +from google.cloud import bigquery +client = bigquery.Client() + +query = "SELECT * FROM project.dataset.table WHERE date > '2023-01-01'" +df = pl.from_pandas(client.query(query).to_dataframe()) +``` + +## Apache Arrow + +### IPC/Feather Format + +**Read:** +```python +df = pl.read_ipc("data.arrow") +lf = pl.scan_ipc("data.arrow") +``` + +**Write:** +```python +df.write_ipc("output.arrow") + +# Compressed +df.write_ipc("output.arrow", compression="zstd") +``` + +### Arrow Streaming + +```python +# Write streaming format +df.write_ipc("output.arrows", compression="zstd") + +# Read streaming +df = pl.read_ipc("output.arrows") +``` + +### From/To Arrow + +```python +import pyarrow as pa + +# From Arrow Table +arrow_table = pa.table({"col": [1, 2, 3]}) +df = pl.from_arrow(arrow_table) + +# To Arrow Table +arrow_table = df.to_arrow() +``` + +## In-Memory Formats + +### Python Dictionaries + +```python +# From dict +df = pl.DataFrame({ + "col1": [1, 2, 3], + "col2": ["a", "b", "c"] +}) + +# To dict +data_dict = df.to_dict() # Column-oriented +data_dict = df.to_dict(as_series=False) # Lists instead of Series +``` + +### NumPy Arrays + +```python +import numpy as np + +# From NumPy +arr = np.array([[1, 2], [3, 4], [5, 6]]) +df = pl.DataFrame(arr, schema=["col1", "col2"]) + +# To NumPy +arr = df.to_numpy() +``` + +### Pandas DataFrames + +```python +import pandas as pd + +# From Pandas +pd_df = pd.DataFrame({"col": [1, 2, 3]}) +pl_df = pl.from_pandas(pd_df) + +# To Pandas +pd_df = pl_df.to_pandas() + +# Zero-copy when possible +pl_df = pl.from_arrow(pd_df) +``` + +### Lists of Rows + +```python +# From list of dicts +data = [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30} +] +df = pl.DataFrame(data) + +# To list of dicts +rows = df.to_dicts() + +# From list of tuples +data = [("Alice", 25), ("Bob", 30)] +df = pl.DataFrame(data, schema=["name", "age"]) +``` + +## Streaming Large Files + +For datasets larger than memory, use lazy mode with streaming: + +```python +# Streaming mode +lf = pl.scan_csv("very_large.csv") +result = lf.filter(pl.col("value") > 100).collect(streaming=True) + +# 
Streaming with multiple files +lf = pl.scan_parquet("data/*.parquet") +result = lf.group_by("category").agg(pl.col("value").sum()).collect(streaming=True) +``` + +## Best Practices + +### Format Selection + +**Use Parquet when:** +- Need compression (up to 10x smaller than CSV) +- Want fast reads/writes +- Need to preserve data types +- Working with large datasets +- Need predicate pushdown + +**Use CSV when:** +- Need human-readable format +- Interfacing with legacy systems +- Data is small +- Need universal compatibility + +**Use JSON when:** +- Working with nested/hierarchical data +- Need web API compatibility +- Data has flexible schema + +**Use Arrow IPC when:** +- Need zero-copy data sharing +- Fastest serialization required +- Working between Arrow-compatible systems + +### Reading Large Files + +```python +# 1. Always use lazy mode +lf = pl.scan_csv("large.csv") # NOT read_csv + +# 2. Filter and select early (pushdown optimization) +result = ( + lf + .select("col1", "col2", "col3") # Only needed columns + .filter(pl.col("date") > "2023-01-01") # Filter early + .collect() +) + +# 3. Use streaming for very large data +result = lf.filter(...).select(...).collect(streaming=True) + +# 4. Read only needed rows during development +df = pl.read_csv("large.csv", n_rows=10000) # Sample for testing +``` + +### Writing Large Files + +```python +# 1. Use Parquet with compression +df.write_parquet("output.parquet", compression="zstd") + +# 2. Use partitioning for very large datasets +df.write_parquet("output", partition_by=["year", "month"]) + +# 3. Write streaming +lf = pl.scan_csv("input.csv") +lf.sink_parquet("output.parquet") # Streaming write +``` + +### Performance Tips + +```python +# 1. Specify dtypes when reading CSV +df = pl.read_csv( + "data.csv", + dtypes={"id": pl.Int64, "name": pl.Utf8} # Avoids inference +) + +# 2. Use appropriate compression +df.write_parquet("output.parquet", compression="snappy") # Fast +df.write_parquet("output.parquet", compression="zstd") # Better compression + +# 3. Parallel reading +df = pl.read_csv("data.csv", parallel="auto") + +# 4. Read multiple files in parallel +lf = pl.scan_parquet("data/*.parquet") # Automatic parallel read +``` + +## Error Handling + +```python +try: + df = pl.read_csv("data.csv") +except pl.exceptions.ComputeError as e: + print(f"Error reading CSV: {e}") + +# Ignore errors during parsing +df = pl.read_csv("messy.csv", ignore_errors=True) + +# Handle missing files +from pathlib import Path +if Path("data.csv").exists(): + df = pl.read_csv("data.csv") +else: + print("File not found") +``` + +## Schema Management + +```python +# Infer schema from sample +schema = pl.read_csv("data.csv", n_rows=1000).schema + +# Use inferred schema for full read +df = pl.read_csv("data.csv", dtypes=schema) + +# Define schema explicitly +schema = { + "id": pl.Int64, + "name": pl.Utf8, + "date": pl.Date, + "value": pl.Float64 +} +df = pl.read_csv("data.csv", dtypes=schema) +``` diff --git a/scientific-packages/polars/references/operations.md b/scientific-packages/polars/references/operations.md new file mode 100644 index 0000000..40441f5 --- /dev/null +++ b/scientific-packages/polars/references/operations.md @@ -0,0 +1,602 @@ +# Polars Operations Reference + +This reference covers all common Polars operations with comprehensive examples. 
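+
+The snippets below operate on a generic `df`. A small frame along these lines (column names are assumptions chosen to match the examples) can be used to try them out:
+
+```python
+import polars as pl
+from datetime import date
+
+# Illustrative data only -- the columns mirror those used in the examples below
+df = pl.DataFrame({
+    "name": ["Alice", "Bob", "Carol", "Dan"],
+    "age": [28, 34, 41, 23],
+    "city": ["NY", "LA", "NY", "SF"],
+    "department": ["sales", "eng", "sales", "eng"],
+    "salary": [70000, 95000, 88000, 67000],
+    "value": [10.5, 3.2, 7.7, 1.1],
+    "date": [date(2023, 1, 5), date(2023, 2, 9), date(2023, 3, 1), date(2023, 4, 20)],
+})
+```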
+ +## Selection Operations + +### Select Columns + +**Basic selection:** +```python +# Select specific columns +df.select("name", "age", "city") + +# Using expressions +df.select(pl.col("name"), pl.col("age")) +``` + +**Pattern-based selection:** +```python +# All columns starting with "sales_" +df.select(pl.col("^sales_.*$")) + +# All numeric columns +df.select(pl.col(pl.NUMERIC_DTYPES)) + +# All columns except specific ones +df.select(pl.all().exclude("id", "timestamp")) +``` + +**Computed columns:** +```python +df.select( + "name", + (pl.col("age") * 12).alias("age_in_months"), + (pl.col("salary") * 1.1).alias("salary_after_raise") +) +``` + +### With Columns (Add/Modify) + +Add new columns or modify existing ones while preserving all other columns: + +```python +# Add new columns +df.with_columns( + age_doubled=pl.col("age") * 2, + full_name=pl.col("first_name") + " " + pl.col("last_name") +) + +# Modify existing columns +df.with_columns( + pl.col("name").str.to_uppercase().alias("name"), + pl.col("salary").cast(pl.Float64).alias("salary") +) + +# Multiple operations in parallel +df.with_columns( + pl.col("value") * 10, + pl.col("value") * 100, + pl.col("value") * 1000, +) +``` + +## Filtering Operations + +### Basic Filtering + +```python +# Single condition +df.filter(pl.col("age") > 25) + +# Multiple conditions (AND) +df.filter( + pl.col("age") > 25, + pl.col("city") == "NY" +) + +# OR conditions +df.filter( + (pl.col("age") > 30) | (pl.col("salary") > 100000) +) + +# NOT condition +df.filter(~pl.col("active")) +df.filter(pl.col("city") != "NY") +``` + +### Advanced Filtering + +**String operations:** +```python +# Contains substring +df.filter(pl.col("name").str.contains("John")) + +# Starts with +df.filter(pl.col("email").str.starts_with("admin")) + +# Regex match +df.filter(pl.col("phone").str.contains(r"^\d{3}-\d{3}-\d{4}$")) +``` + +**Membership checks:** +```python +# In list +df.filter(pl.col("city").is_in(["NY", "LA", "SF"])) + +# Not in list +df.filter(~pl.col("status").is_in(["inactive", "deleted"])) +``` + +**Range filters:** +```python +# Between values +df.filter(pl.col("age").is_between(25, 35)) + +# Date range +df.filter( + pl.col("date") >= pl.date(2023, 1, 1), + pl.col("date") <= pl.date(2023, 12, 31) +) +``` + +**Null filtering:** +```python +# Filter out nulls +df.filter(pl.col("value").is_not_null()) + +# Keep only nulls +df.filter(pl.col("value").is_null()) +``` + +## Grouping and Aggregation + +### Basic Group By + +```python +# Group by single column +df.group_by("department").agg( + pl.col("salary").mean().alias("avg_salary"), + pl.len().alias("employee_count") +) + +# Group by multiple columns +df.group_by("department", "location").agg( + pl.col("salary").sum() +) + +# Maintain order +df.group_by("category", maintain_order=True).agg( + pl.col("value").sum() +) +``` + +### Aggregation Functions + +**Count and length:** +```python +df.group_by("category").agg( + pl.len().alias("count"), + pl.col("id").count().alias("non_null_count"), + pl.col("id").n_unique().alias("unique_count") +) +``` + +**Statistical aggregations:** +```python +df.group_by("group").agg( + pl.col("value").sum().alias("total"), + pl.col("value").mean().alias("average"), + pl.col("value").median().alias("median"), + pl.col("value").std().alias("std_dev"), + pl.col("value").var().alias("variance"), + pl.col("value").min().alias("minimum"), + pl.col("value").max().alias("maximum"), + pl.col("value").quantile(0.95).alias("p95") +) +``` + +**First and last:** +```python 
+df.group_by("user_id").agg( + pl.col("timestamp").first().alias("first_seen"), + pl.col("timestamp").last().alias("last_seen"), + pl.col("event").first().alias("first_event") +) +``` + +**List aggregation:** +```python +# Collect values into lists +df.group_by("category").agg( + pl.col("item").alias("all_items") # Creates list column +) +``` + +### Conditional Aggregations + +Filter within aggregations: + +```python +df.group_by("department").agg( + # Count high earners + (pl.col("salary") > 100000).sum().alias("high_earners"), + + # Average of filtered values + pl.col("salary").filter(pl.col("bonus") > 0).mean().alias("avg_with_bonus"), + + # Conditional sum + pl.when(pl.col("active")) + .then(pl.col("sales")) + .otherwise(0) + .sum() + .alias("active_sales") +) +``` + +### Multiple Aggregations + +Combine multiple aggregations efficiently: + +```python +df.group_by("store_id").agg( + pl.col("transaction_id").count().alias("num_transactions"), + pl.col("amount").sum().alias("total_sales"), + pl.col("amount").mean().alias("avg_transaction"), + pl.col("customer_id").n_unique().alias("unique_customers"), + pl.col("amount").max().alias("largest_transaction"), + pl.col("timestamp").min().alias("first_transaction_date"), + pl.col("timestamp").max().alias("last_transaction_date") +) +``` + +## Window Functions + +Window functions apply aggregations while preserving the original row count. + +### Basic Window Operations + +**Group statistics:** +```python +# Add group mean to each row +df.with_columns( + avg_age_by_dept=pl.col("age").mean().over("department") +) + +# Multiple group columns +df.with_columns( + group_avg=pl.col("value").mean().over("category", "region") +) +``` + +**Ranking:** +```python +df.with_columns( + # Rank within groups + rank=pl.col("score").rank().over("team"), + + # Dense rank (no gaps) + dense_rank=pl.col("score").rank(method="dense").over("team"), + + # Row number + row_num=pl.col("timestamp").sort().rank(method="ordinal").over("user_id") +) +``` + +### Window Mapping Strategies + +**group_to_rows (default):** +Preserves original row order: +```python +df.with_columns( + group_mean=pl.col("value").mean().over("category", mapping_strategy="group_to_rows") +) +``` + +**explode:** +Faster, groups rows together: +```python +df.with_columns( + group_mean=pl.col("value").mean().over("category", mapping_strategy="explode") +) +``` + +**join:** +Creates list columns: +```python +df.with_columns( + group_values=pl.col("value").over("category", mapping_strategy="join") +) +``` + +### Rolling Windows + +**Time-based rolling:** +```python +df.with_columns( + rolling_avg=pl.col("value").rolling_mean( + window_size="7d", + by="date" + ) +) +``` + +**Row-based rolling:** +```python +df.with_columns( + rolling_sum=pl.col("value").rolling_sum(window_size=3), + rolling_max=pl.col("value").rolling_max(window_size=5) +) +``` + +### Cumulative Operations + +```python +df.with_columns( + cumsum=pl.col("value").cum_sum().over("group"), + cummax=pl.col("value").cum_max().over("group"), + cummin=pl.col("value").cum_min().over("group"), + cumprod=pl.col("value").cum_prod().over("group") +) +``` + +### Shift and Lag/Lead + +```python +df.with_columns( + # Previous value (lag) + prev_value=pl.col("value").shift(1).over("user_id"), + + # Next value (lead) + next_value=pl.col("value").shift(-1).over("user_id"), + + # Calculate difference from previous + diff=pl.col("value") - pl.col("value").shift(1).over("user_id") +) +``` + +## Sorting + +### Basic Sorting + +```python +# Sort by single column 
+df.sort("age") + +# Sort descending +df.sort("age", descending=True) + +# Sort by multiple columns +df.sort("department", "age") + +# Mixed sorting order +df.sort(["department", "salary"], descending=[False, True]) +``` + +### Advanced Sorting + +**Null handling:** +```python +# Nulls first +df.sort("value", nulls_last=False) + +# Nulls last +df.sort("value", nulls_last=True) +``` + +**Sort by expression:** +```python +# Sort by computed value +df.sort(pl.col("first_name").str.len()) + +# Sort by multiple expressions +df.sort( + pl.col("last_name").str.to_lowercase(), + pl.col("age").abs() +) +``` + +## Conditional Operations + +### When/Then/Otherwise + +```python +# Basic conditional +df.with_columns( + status=pl.when(pl.col("age") >= 18) + .then("adult") + .otherwise("minor") +) + +# Multiple conditions +df.with_columns( + category=pl.when(pl.col("score") >= 90) + .then("A") + .when(pl.col("score") >= 80) + .then("B") + .when(pl.col("score") >= 70) + .then("C") + .otherwise("F") +) + +# Conditional computation +df.with_columns( + adjusted_price=pl.when(pl.col("is_member")) + .then(pl.col("price") * 0.9) + .otherwise(pl.col("price")) +) +``` + +## String Operations + +### Common String Methods + +```python +df.with_columns( + # Case conversion + upper=pl.col("name").str.to_uppercase(), + lower=pl.col("name").str.to_lowercase(), + title=pl.col("name").str.to_titlecase(), + + # Trimming + trimmed=pl.col("text").str.strip_chars(), + + # Substring + first_3=pl.col("name").str.slice(0, 3), + + # Replace + cleaned=pl.col("text").str.replace("old", "new"), + cleaned_all=pl.col("text").str.replace_all("old", "new"), + + # Split + parts=pl.col("full_name").str.split(" "), + + # Length + name_length=pl.col("name").str.len_chars() +) +``` + +### String Filtering + +```python +# Contains +df.filter(pl.col("email").str.contains("@gmail.com")) + +# Starts/ends with +df.filter(pl.col("name").str.starts_with("A")) +df.filter(pl.col("file").str.ends_with(".csv")) + +# Regex matching +df.filter(pl.col("phone").str.contains(r"^\d{3}-\d{4}$")) +``` + +## Date and Time Operations + +### Date Parsing + +```python +# Parse strings to dates +df.with_columns( + date=pl.col("date_str").str.strptime(pl.Date, "%Y-%m-%d"), + datetime=pl.col("dt_str").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S") +) +``` + +### Date Components + +```python +df.with_columns( + year=pl.col("date").dt.year(), + month=pl.col("date").dt.month(), + day=pl.col("date").dt.day(), + weekday=pl.col("date").dt.weekday(), + hour=pl.col("datetime").dt.hour(), + minute=pl.col("datetime").dt.minute() +) +``` + +### Date Arithmetic + +```python +# Add duration +df.with_columns( + next_week=pl.col("date") + pl.duration(weeks=1), + next_month=pl.col("date") + pl.duration(months=1) +) + +# Difference between dates +df.with_columns( + days_diff=(pl.col("end_date") - pl.col("start_date")).dt.total_days() +) +``` + +### Date Filtering + +```python +# Filter by date range +df.filter( + pl.col("date").is_between(pl.date(2023, 1, 1), pl.date(2023, 12, 31)) +) + +# Filter by year +df.filter(pl.col("date").dt.year() == 2023) + +# Filter by month +df.filter(pl.col("date").dt.month().is_in([6, 7, 8])) # Summer months +``` + +## List Operations + +### Working with List Columns + +```python +# Create list column +df.with_columns( + items_list=pl.col("item1", "item2", "item3").to_list() +) + +# List operations +df.with_columns( + list_len=pl.col("items").list.len(), + first_item=pl.col("items").list.first(), + last_item=pl.col("items").list.last(), + 
unique_items=pl.col("items").list.unique(), + sorted_items=pl.col("items").list.sort() +) + +# Explode lists to rows +df.explode("items") + +# Filter list elements +df.with_columns( + filtered=pl.col("items").list.eval(pl.element() > 10) +) +``` + +## Struct Operations + +### Working with Nested Structures + +```python +# Create struct column +df.with_columns( + address=pl.struct(["street", "city", "zip"]) +) + +# Access struct fields +df.with_columns( + city=pl.col("address").struct.field("city") +) + +# Unnest struct to columns +df.unnest("address") +``` + +## Unique and Duplicate Operations + +```python +# Get unique rows +df.unique() + +# Unique on specific columns +df.unique(subset=["name", "email"]) + +# Keep first/last duplicate +df.unique(subset=["id"], keep="first") +df.unique(subset=["id"], keep="last") + +# Identify duplicates +df.with_columns( + is_duplicate=pl.col("id").is_duplicated() +) + +# Count duplicates +df.group_by("email").agg( + pl.len().alias("count") +).filter(pl.col("count") > 1) +``` + +## Sampling + +```python +# Random sample +df.sample(n=100) + +# Sample fraction +df.sample(fraction=0.1) + +# Sample with seed for reproducibility +df.sample(n=100, seed=42) +``` + +## Column Renaming + +```python +# Rename specific columns +df.rename({"old_name": "new_name", "age": "years"}) + +# Rename with expression +df.select(pl.col("*").name.suffix("_renamed")) +df.select(pl.col("*").name.prefix("data_")) +df.select(pl.col("*").name.to_uppercase()) +``` diff --git a/scientific-packages/polars/references/pandas_migration.md b/scientific-packages/polars/references/pandas_migration.md new file mode 100644 index 0000000..aa5fd24 --- /dev/null +++ b/scientific-packages/polars/references/pandas_migration.md @@ -0,0 +1,417 @@ +# Pandas to Polars Migration Guide + +This guide helps you migrate from pandas to Polars with comprehensive operation mappings and key differences. + +## Core Conceptual Differences + +### 1. No Index System + +**Pandas:** Uses row-based indexing system +```python +df.loc[0, "column"] +df.iloc[0:5] +df.set_index("id") +``` + +**Polars:** Uses integer positions only +```python +df[0, "column"] # Row position, column name +df[0:5] # Row slice +# No set_index equivalent - use group_by instead +``` + +### 2. Memory Format + +**Pandas:** Row-oriented NumPy arrays +**Polars:** Columnar Apache Arrow format + +**Implications:** +- Polars is faster for column operations +- Polars uses less memory +- Polars has better data sharing capabilities + +### 3. Parallelization + +**Pandas:** Primarily single-threaded (requires Dask for parallelism) +**Polars:** Parallel by default using Rust's concurrency + +### 4. Lazy Evaluation + +**Pandas:** Only eager evaluation +**Polars:** Both eager (DataFrame) and lazy (LazyFrame) with query optimization + +### 5. 
Type Strictness + +**Pandas:** Allows silent type conversions +**Polars:** Strict typing, explicit casts required + +**Example:** +```python +# Pandas: Silently converts to float +pd_df["int_col"] = [1, 2, None, 4] # dtype: float64 + +# Polars: Keeps as integer with null +pl_df = pl.DataFrame({"int_col": [1, 2, None, 4]}) # dtype: Int64 +``` + +## Operation Mappings + +### Data Selection + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Select column | `df["col"]` or `df.col` | `df.select("col")` or `df["col"]` | +| Select multiple | `df[["a", "b"]]` | `df.select("a", "b")` | +| Select by position | `df.iloc[:, 0:3]` | `df.select(pl.col(df.columns[0:3]))` | +| Select by condition | `df[df["age"] > 25]` | `df.filter(pl.col("age") > 25)` | + +### Data Filtering + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Single condition | `df[df["age"] > 25]` | `df.filter(pl.col("age") > 25)` | +| Multiple conditions | `df[(df["age"] > 25) & (df["city"] == "NY")]` | `df.filter(pl.col("age") > 25, pl.col("city") == "NY")` | +| Query method | `df.query("age > 25")` | `df.filter(pl.col("age") > 25)` | +| isin | `df[df["city"].isin(["NY", "LA"])]` | `df.filter(pl.col("city").is_in(["NY", "LA"]))` | +| isna | `df[df["value"].isna()]` | `df.filter(pl.col("value").is_null())` | +| notna | `df[df["value"].notna()]` | `df.filter(pl.col("value").is_not_null())` | + +### Adding/Modifying Columns + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Add column | `df["new"] = df["old"] * 2` | `df.with_columns(new=pl.col("old") * 2)` | +| Multiple columns | `df.assign(a=..., b=...)` | `df.with_columns(a=..., b=...)` | +| Conditional column | `np.where(condition, a, b)` | `pl.when(condition).then(a).otherwise(b)` | + +**Important difference - Parallel execution:** + +```python +# Pandas: Sequential (lambda sees previous results) +df.assign( + a=lambda df_: df_.value * 10, + b=lambda df_: df_.value * 100 +) + +# Polars: Parallel (all computed together) +df.with_columns( + a=pl.col("value") * 10, + b=pl.col("value") * 100 +) +``` + +### Grouping and Aggregation + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Group by | `df.groupby("col")` | `df.group_by("col")` | +| Agg single | `df.groupby("col")["val"].mean()` | `df.group_by("col").agg(pl.col("val").mean())` | +| Agg multiple | `df.groupby("col").agg({"val": ["mean", "sum"]})` | `df.group_by("col").agg(pl.col("val").mean(), pl.col("val").sum())` | +| Size | `df.groupby("col").size()` | `df.group_by("col").agg(pl.len())` | +| Count | `df.groupby("col").count()` | `df.group_by("col").agg(pl.col("*").count())` | + +### Window Functions + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Transform | `df.groupby("col").transform("mean")` | `df.with_columns(pl.col("val").mean().over("col"))` | +| Rank | `df.groupby("col")["val"].rank()` | `df.with_columns(pl.col("val").rank().over("col"))` | +| Shift | `df.groupby("col")["val"].shift(1)` | `df.with_columns(pl.col("val").shift(1).over("col"))` | +| Cumsum | `df.groupby("col")["val"].cumsum()` | `df.with_columns(pl.col("val").cum_sum().over("col"))` | + +### Joins + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Inner join | `df1.merge(df2, on="id")` | `df1.join(df2, on="id", how="inner")` | +| Left join | `df1.merge(df2, on="id", how="left")` | `df1.join(df2, on="id", how="left")` | +| Different keys | `df1.merge(df2, left_on="a", right_on="b")` | `df1.join(df2, left_on="a", right_on="b")` | + +### 
Concatenation + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Vertical | `pd.concat([df1, df2], axis=0)` | `pl.concat([df1, df2], how="vertical")` | +| Horizontal | `pd.concat([df1, df2], axis=1)` | `pl.concat([df1, df2], how="horizontal")` | + +### Sorting + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Sort by column | `df.sort_values("col")` | `df.sort("col")` | +| Descending | `df.sort_values("col", ascending=False)` | `df.sort("col", descending=True)` | +| Multiple columns | `df.sort_values(["a", "b"])` | `df.sort("a", "b")` | + +### Reshaping + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Pivot | `df.pivot(index="a", columns="b", values="c")` | `df.pivot(values="c", index="a", columns="b")` | +| Melt | `df.melt(id_vars="id")` | `df.unpivot(index="id")` | + +### I/O Operations + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Read CSV | `pd.read_csv("file.csv")` | `pl.read_csv("file.csv")` or `pl.scan_csv()` | +| Write CSV | `df.to_csv("file.csv")` | `df.write_csv("file.csv")` | +| Read Parquet | `pd.read_parquet("file.parquet")` | `pl.read_parquet("file.parquet")` | +| Write Parquet | `df.to_parquet("file.parquet")` | `df.write_parquet("file.parquet")` | +| Read Excel | `pd.read_excel("file.xlsx")` | `pl.read_excel("file.xlsx")` | + +### String Operations + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Upper | `df["col"].str.upper()` | `df.select(pl.col("col").str.to_uppercase())` | +| Lower | `df["col"].str.lower()` | `df.select(pl.col("col").str.to_lowercase())` | +| Contains | `df["col"].str.contains("pattern")` | `df.filter(pl.col("col").str.contains("pattern"))` | +| Replace | `df["col"].str.replace("old", "new")` | `df.select(pl.col("col").str.replace("old", "new"))` | +| Split | `df["col"].str.split(" ")` | `df.select(pl.col("col").str.split(" "))` | + +### Datetime Operations + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Parse dates | `pd.to_datetime(df["col"])` | `df.select(pl.col("col").str.strptime(pl.Date, "%Y-%m-%d"))` | +| Year | `df["date"].dt.year` | `df.select(pl.col("date").dt.year())` | +| Month | `df["date"].dt.month` | `df.select(pl.col("date").dt.month())` | +| Day | `df["date"].dt.day` | `df.select(pl.col("date").dt.day())` | + +### Missing Data + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Drop nulls | `df.dropna()` | `df.drop_nulls()` | +| Fill nulls | `df.fillna(0)` | `df.fill_null(0)` | +| Check null | `df["col"].isna()` | `df.select(pl.col("col").is_null())` | +| Forward fill | `df.fillna(method="ffill")` | `df.select(pl.col("col").fill_null(strategy="forward"))` | + +### Other Operations + +| Operation | Pandas | Polars | +|-----------|--------|--------| +| Unique values | `df["col"].unique()` | `df["col"].unique()` | +| Value counts | `df["col"].value_counts()` | `df["col"].value_counts()` | +| Describe | `df.describe()` | `df.describe()` | +| Sample | `df.sample(n=100)` | `df.sample(n=100)` | +| Head | `df.head()` | `df.head()` | +| Tail | `df.tail()` | `df.tail()` | + +## Common Migration Patterns + +### Pattern 1: Chained Operations + +**Pandas:** +```python +result = (df + .assign(new_col=lambda x: x["old_col"] * 2) + .query("new_col > 10") + .groupby("category") + .agg({"value": "sum"}) + .reset_index() +) +``` + +**Polars:** +```python +result = (df + .with_columns(new_col=pl.col("old_col") * 2) + .filter(pl.col("new_col") > 10) + .group_by("category") + .agg(pl.col("value").sum()) +) +# 
No reset_index needed - Polars doesn't have index +``` + +### Pattern 2: Apply Functions + +**Pandas:** +```python +# Avoid in Polars - breaks parallelization +df["result"] = df["value"].apply(lambda x: x * 2) +``` + +**Polars:** +```python +# Use expressions instead +df = df.with_columns(result=pl.col("value") * 2) + +# If custom function needed +df = df.with_columns( + result=pl.col("value").map_elements(lambda x: x * 2, return_dtype=pl.Float64) +) +``` + +### Pattern 3: Conditional Column Creation + +**Pandas:** +```python +df["category"] = np.where( + df["value"] > 100, + "high", + np.where(df["value"] > 50, "medium", "low") +) +``` + +**Polars:** +```python +df = df.with_columns( + category=pl.when(pl.col("value") > 100) + .then("high") + .when(pl.col("value") > 50) + .then("medium") + .otherwise("low") +) +``` + +### Pattern 4: Group Transform + +**Pandas:** +```python +df["group_mean"] = df.groupby("category")["value"].transform("mean") +``` + +**Polars:** +```python +df = df.with_columns( + group_mean=pl.col("value").mean().over("category") +) +``` + +### Pattern 5: Multiple Aggregations + +**Pandas:** +```python +result = df.groupby("category").agg({ + "value": ["mean", "sum", "count"], + "price": ["min", "max"] +}) +``` + +**Polars:** +```python +result = df.group_by("category").agg( + pl.col("value").mean().alias("value_mean"), + pl.col("value").sum().alias("value_sum"), + pl.col("value").count().alias("value_count"), + pl.col("price").min().alias("price_min"), + pl.col("price").max().alias("price_max") +) +``` + +## Performance Anti-Patterns to Avoid + +### Anti-Pattern 1: Sequential Pipe Operations + +**Bad (disables parallelization):** +```python +df = df.pipe(function1).pipe(function2).pipe(function3) +``` + +**Good (enables parallelization):** +```python +df = df.with_columns( + function1_result(), + function2_result(), + function3_result() +) +``` + +### Anti-Pattern 2: Python Functions in Hot Paths + +**Bad:** +```python +df = df.with_columns( + result=pl.col("value").map_elements(lambda x: x * 2) +) +``` + +**Good:** +```python +df = df.with_columns(result=pl.col("value") * 2) +``` + +### Anti-Pattern 3: Using Eager Reading for Large Files + +**Bad:** +```python +df = pl.read_csv("large_file.csv") +result = df.filter(pl.col("age") > 25).select("name", "age") +``` + +**Good:** +```python +lf = pl.scan_csv("large_file.csv") +result = lf.filter(pl.col("age") > 25).select("name", "age").collect() +``` + +### Anti-Pattern 4: Row Iteration + +**Bad:** +```python +for row in df.iter_rows(): + # Process row + pass +``` + +**Good:** +```python +# Use vectorized operations +df = df.with_columns( + # Vectorized computation +) +``` + +## Migration Checklist + +When migrating from pandas to Polars: + +1. **Remove index operations** - Use integer positions or group_by +2. **Replace apply/map with expressions** - Use Polars native operations +3. **Update column assignment** - Use `with_columns()` instead of direct assignment +4. **Change groupby.transform to .over()** - Window functions work differently +5. **Update string operations** - Use `.str.to_uppercase()` instead of `.str.upper()` +6. **Add explicit type casts** - Polars won't silently convert types +7. **Consider lazy evaluation** - Use `scan_*` instead of `read_*` for large data +8. **Update aggregation syntax** - More explicit in Polars +9. **Remove reset_index calls** - Not needed in Polars +10. 
**Update conditional logic** - Use `when().then().otherwise()` pattern + +## Compatibility Layer + +For gradual migration, you can use both libraries: + +```python +import pandas as pd +import polars as pl + +# Convert pandas to Polars +pl_df = pl.from_pandas(pd_df) + +# Convert Polars to pandas +pd_df = pl_df.to_pandas() + +# Use Arrow for zero-copy (when possible) +pl_df = pl.from_arrow(pd_df) +pd_df = pl_df.to_arrow().to_pandas() +``` + +## When to Stick with Pandas + +Consider staying with pandas when: +- Working with time series requiring complex index operations +- Need extensive ecosystem support (some libraries only support pandas) +- Team lacks Rust/Polars expertise +- Data is small and performance isn't critical +- Using advanced pandas features without Polars equivalents + +## When to Switch to Polars + +Switch to Polars when: +- Performance is critical +- Working with large datasets (>1GB) +- Need lazy evaluation and query optimization +- Want better type safety +- Need parallel execution by default +- Starting a new project diff --git a/scientific-packages/polars/references/transformations.md b/scientific-packages/polars/references/transformations.md new file mode 100644 index 0000000..af57f1c --- /dev/null +++ b/scientific-packages/polars/references/transformations.md @@ -0,0 +1,549 @@ +# Polars Data Transformations + +Comprehensive guide to joins, concatenation, and reshaping operations in Polars. + +## Joins + +Joins combine data from multiple DataFrames based on common columns. + +### Basic Join Types + +**Inner Join (intersection):** +```python +# Keep only matching rows from both DataFrames +result = df1.join(df2, on="id", how="inner") +``` + +**Left Join (all left + matches from right):** +```python +# Keep all rows from left, add matching rows from right +result = df1.join(df2, on="id", how="left") +``` + +**Outer Join (union):** +```python +# Keep all rows from both DataFrames +result = df1.join(df2, on="id", how="outer") +``` + +**Cross Join (Cartesian product):** +```python +# Every row from left with every row from right +result = df1.join(df2, how="cross") +``` + +**Semi Join (filtered left):** +```python +# Keep only left rows that have a match in right +result = df1.join(df2, on="id", how="semi") +``` + +**Anti Join (non-matching left):** +```python +# Keep only left rows that DON'T have a match in right +result = df1.join(df2, on="id", how="anti") +``` + +### Join Syntax Variations + +**Single column join:** +```python +df1.join(df2, on="id") +``` + +**Multiple columns join:** +```python +df1.join(df2, on=["id", "date"]) +``` + +**Different column names:** +```python +df1.join(df2, left_on="user_id", right_on="id") +``` + +**Multiple different columns:** +```python +df1.join( + df2, + left_on=["user_id", "date"], + right_on=["id", "timestamp"] +) +``` + +### Suffix Handling + +When both DataFrames have columns with the same name (other than join keys): + +```python +# Add suffixes to distinguish columns +result = df1.join(df2, on="id", suffix="_right") + +# Results in: value, value_right (if both had "value" column) +``` + +### Join Examples + +**Example 1: Customer Orders** +```python +customers = pl.DataFrame({ + "customer_id": [1, 2, 3, 4], + "name": ["Alice", "Bob", "Charlie", "David"] +}) + +orders = pl.DataFrame({ + "order_id": [101, 102, 103], + "customer_id": [1, 2, 1], + "amount": [100, 200, 150] +}) + +# Inner join - only customers with orders +result = customers.join(orders, on="customer_id", how="inner") + +# Left join - all customers, even without 
orders +result = customers.join(orders, on="customer_id", how="left") +``` + +**Example 2: Time-series data** +```python +prices = pl.DataFrame({ + "date": ["2023-01-01", "2023-01-02", "2023-01-03"], + "stock": ["AAPL", "AAPL", "AAPL"], + "price": [150, 152, 151] +}) + +volumes = pl.DataFrame({ + "date": ["2023-01-01", "2023-01-02"], + "stock": ["AAPL", "AAPL"], + "volume": [1000000, 1100000] +}) + +result = prices.join( + volumes, + on=["date", "stock"], + how="left" +) +``` + +### Asof Joins (Nearest Match) + +For time-series data, join to nearest timestamp: + +```python +# Join to nearest earlier timestamp +quotes = pl.DataFrame({ + "timestamp": [1, 2, 3, 4, 5], + "stock": ["A", "A", "A", "A", "A"], + "quote": [100, 101, 102, 103, 104] +}) + +trades = pl.DataFrame({ + "timestamp": [1.5, 3.5, 4.2], + "stock": ["A", "A", "A"], + "trade": [50, 75, 100] +}) + +result = trades.join_asof( + quotes, + on="timestamp", + by="stock", + strategy="backward" # or "forward", "nearest" +) +``` + +## Concatenation + +Concatenation stacks DataFrames together. + +### Vertical Concatenation (Stack Rows) + +```python +df1 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) +df2 = pl.DataFrame({"a": [5, 6], "b": [7, 8]}) + +# Stack rows +result = pl.concat([df1, df2], how="vertical") +# Result: 4 rows, same columns +``` + +**Handling mismatched schemas:** +```python +df1 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) +df2 = pl.DataFrame({"a": [5, 6], "c": [7, 8]}) + +# Diagonal concat - fills missing columns with nulls +result = pl.concat([df1, df2], how="diagonal") +# Result: columns a, b, c (with nulls where not present) +``` + +### Horizontal Concatenation (Stack Columns) + +```python +df1 = pl.DataFrame({"a": [1, 2, 3]}) +df2 = pl.DataFrame({"b": [4, 5, 6]}) + +# Stack columns +result = pl.concat([df1, df2], how="horizontal") +# Result: 3 rows, columns a and b +``` + +**Note:** Horizontal concat requires same number of rows. + +### Concatenation Options + +```python +# Rechunk after concatenation (better performance for subsequent operations) +result = pl.concat([df1, df2], rechunk=True) + +# Parallel execution +result = pl.concat([df1, df2], parallel=True) +``` + +### Use Cases + +**Combining data from multiple sources:** +```python +# Read multiple files and concatenate +files = ["data_2023.csv", "data_2024.csv", "data_2025.csv"] +dfs = [pl.read_csv(f) for f in files] +combined = pl.concat(dfs, how="vertical") +``` + +**Adding computed columns:** +```python +base = pl.DataFrame({"value": [1, 2, 3]}) +computed = pl.DataFrame({"doubled": [2, 4, 6]}) +result = pl.concat([base, computed], how="horizontal") +``` + +## Pivoting (Wide Format) + +Convert unique values from one column into multiple columns. + +### Basic Pivot + +```python +df = pl.DataFrame({ + "date": ["2023-01", "2023-01", "2023-02", "2023-02"], + "product": ["A", "B", "A", "B"], + "sales": [100, 150, 120, 160] +}) + +# Pivot: products become columns +pivoted = df.pivot( + values="sales", + index="date", + columns="product" +) +# Result: +# date | A | B +# 2023-01 | 100 | 150 +# 2023-02 | 120 | 160 +``` + +### Pivot with Aggregation + +When there are duplicate combinations, aggregate: + +```python +df = pl.DataFrame({ + "date": ["2023-01", "2023-01", "2023-01"], + "product": ["A", "A", "B"], + "sales": [100, 110, 150] +}) + +# Aggregate duplicates +pivoted = df.pivot( + values="sales", + index="date", + columns="product", + aggregate_function="sum" # or "mean", "max", "min", etc. 
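    # Note: depending on the Polars version, the `columns=` argument may instead
    # be named `on=`, and aggregate_function also accepts "first", "min", "max",
    # "median", or an expression such as pl.element().sum()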
+) +``` + +### Multiple Index Columns + +```python +df = pl.DataFrame({ + "region": ["North", "North", "South", "South"], + "date": ["2023-01", "2023-01", "2023-01", "2023-01"], + "product": ["A", "B", "A", "B"], + "sales": [100, 150, 120, 160] +}) + +pivoted = df.pivot( + values="sales", + index=["region", "date"], + columns="product" +) +``` + +## Unpivoting/Melting (Long Format) + +Convert multiple columns into rows (opposite of pivot). + +### Basic Unpivot + +```python +df = pl.DataFrame({ + "date": ["2023-01", "2023-02"], + "product_A": [100, 120], + "product_B": [150, 160] +}) + +# Unpivot: convert columns to rows +unpivoted = df.unpivot( + index="date", + on=["product_A", "product_B"] +) +# Result: +# date | variable | value +# 2023-01 | product_A | 100 +# 2023-01 | product_B | 150 +# 2023-02 | product_A | 120 +# 2023-02 | product_B | 160 +``` + +### Custom Column Names + +```python +unpivoted = df.unpivot( + index="date", + on=["product_A", "product_B"], + variable_name="product", + value_name="sales" +) +``` + +### Unpivot by Pattern + +```python +# Unpivot all columns matching pattern +df = pl.DataFrame({ + "id": [1, 2], + "sales_Q1": [100, 200], + "sales_Q2": [150, 250], + "sales_Q3": [120, 220], + "revenue_Q1": [1000, 2000] +}) + +# Unpivot all sales columns +unpivoted = df.unpivot( + index="id", + on=pl.col("^sales_.*$") +) +``` + +## Exploding (Unnesting Lists) + +Convert list columns into multiple rows. + +### Basic Explode + +```python +df = pl.DataFrame({ + "id": [1, 2], + "values": [[1, 2, 3], [4, 5]] +}) + +# Explode list into rows +exploded = df.explode("values") +# Result: +# id | values +# 1 | 1 +# 1 | 2 +# 1 | 3 +# 2 | 4 +# 2 | 5 +``` + +### Multiple Column Explode + +```python +df = pl.DataFrame({ + "id": [1, 2], + "letters": [["a", "b"], ["c", "d"]], + "numbers": [[1, 2], [3, 4]] +}) + +# Explode multiple columns (must be same length) +exploded = df.explode("letters", "numbers") +``` + +## Transposing + +Swap rows and columns: + +```python +df = pl.DataFrame({ + "metric": ["sales", "costs", "profit"], + "Q1": [100, 60, 40], + "Q2": [150, 80, 70] +}) + +# Transpose +transposed = df.transpose( + include_header=True, + header_name="quarter", + column_names="metric" +) +# Result: quarters as rows, metrics as columns +``` + +## Reshaping Patterns + +### Pattern 1: Wide to Long to Wide + +```python +# Start wide +wide = pl.DataFrame({ + "id": [1, 2], + "A": [10, 20], + "B": [30, 40] +}) + +# To long +long = wide.unpivot(index="id", on=["A", "B"]) + +# Back to wide (maybe with transformations) +wide_again = long.pivot(values="value", index="id", columns="variable") +``` + +### Pattern 2: Nested to Flat + +```python +# Nested data +df = pl.DataFrame({ + "user": [1, 2], + "purchases": [ + [{"item": "A", "qty": 2}, {"item": "B", "qty": 1}], + [{"item": "C", "qty": 3}] + ] +}) + +# Explode and unnest +flat = ( + df.explode("purchases") + .unnest("purchases") +) +``` + +### Pattern 3: Aggregation to Pivot + +```python +# Raw data +sales = pl.DataFrame({ + "date": ["2023-01", "2023-01", "2023-02"], + "product": ["A", "B", "A"], + "sales": [100, 150, 120] +}) + +# Aggregate then pivot +result = ( + sales + .group_by("date", "product") + .agg(pl.col("sales").sum()) + .pivot(values="sales", index="date", columns="product") +) +``` + +## Advanced Transformations + +### Conditional Reshaping + +```python +# Pivot only certain values +df.filter(pl.col("year") >= 2020).pivot(...) 
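
# For example, with hypothetical sales data the same pattern reads:
# df.filter(pl.col("year") >= 2020).pivot(values="sales", index="year", columns="product")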
+ +# Unpivot with filtering +df.unpivot(index="id", on=pl.col("^sales.*$")) +``` + +### Multi-level Transformations + +```python +# Complex reshaping pipeline +result = ( + df + .unpivot(index="id", on=pl.col("^Q[0-9]_.*$")) + .with_columns( + quarter=pl.col("variable").str.extract(r"Q([0-9])", 1), + metric=pl.col("variable").str.extract(r"Q[0-9]_(.*)", 1) + ) + .drop("variable") + .pivot(values="value", index=["id", "quarter"], columns="metric") +) +``` + +## Performance Considerations + +### Join Performance + +```python +# 1. Join on indexed/sorted columns when possible +df1_sorted = df1.sort("id") +df2_sorted = df2.sort("id") +result = df1_sorted.join(df2_sorted, on="id") + +# 2. Use appropriate join type +# semi/anti are faster than inner+filter +matches = df1.join(df2, on="id", how="semi") # Better than filtering after inner join + +# 3. Filter before joining +df1_filtered = df1.filter(pl.col("active")) +result = df1_filtered.join(df2, on="id") # Smaller join +``` + +### Concatenation Performance + +```python +# 1. Rechunk after concatenation +result = pl.concat(dfs, rechunk=True) + +# 2. Use lazy mode for large concatenations +lf1 = pl.scan_parquet("file1.parquet") +lf2 = pl.scan_parquet("file2.parquet") +result = pl.concat([lf1, lf2]).collect() +``` + +### Pivot Performance + +```python +# 1. Filter before pivoting +pivoted = df.filter(pl.col("year") == 2023).pivot(...) + +# 2. Specify aggregate function explicitly +pivoted = df.pivot(..., aggregate_function="first") # Faster than "sum" if only one value +``` + +## Common Use Cases + +### Time Series Alignment + +```python +# Align two time series with different timestamps +ts1.join_asof(ts2, on="timestamp", strategy="backward") +``` + +### Feature Engineering + +```python +# Create lag features +df.with_columns( + pl.col("value").shift(1).over("user_id").alias("prev_value"), + pl.col("value").shift(2).over("user_id").alias("prev_prev_value") +) +``` + +### Data Denormalization + +```python +# Combine normalized tables +orders.join(customers, on="customer_id").join(products, on="product_id") +``` + +### Report Generation + +```python +# Pivot for reporting +sales.pivot(values="amount", index="month", columns="product") +``` diff --git a/scientific-packages/pubchem-database/SKILL.md b/scientific-packages/pubchem-database/SKILL.md new file mode 100644 index 0000000..bd95b1c --- /dev/null +++ b/scientific-packages/pubchem-database/SKILL.md @@ -0,0 +1,557 @@ +--- +name: pubchem-database +description: Access chemical compound data from PubChem, the world's largest free chemical database. This skill should be used when retrieving compound properties, searching for chemicals by name/SMILES/InChI, performing similarity or substructure searches, accessing bioactivity data, converting between chemical formats, or generating chemical structure images. Works with over 110 million compounds and 270 million bioactivities through PUG-REST API and PubChemPy library. +--- + +# PubChem Database + +## Overview + +PubChem is the world's largest freely available chemical database maintained by the National Center for Biotechnology Information (NCBI). It contains over 110 million unique chemical structures and over 270 million bioactivities from more than 770 data sources. This skill provides guidance for programmatically accessing PubChem data using the PUG-REST API and PubChemPy Python library. + +## Core Capabilities + +### 1. 
Chemical Structure Search + +Search for compounds using multiple identifier types: + +**By Chemical Name**: +```python +import pubchempy as pcp +compounds = pcp.get_compounds('aspirin', 'name') +compound = compounds[0] +``` + +**By CID (Compound ID)**: +```python +compound = pcp.Compound.from_cid(2244) # Aspirin +``` + +**By SMILES**: +```python +compound = pcp.get_compounds('CC(=O)OC1=CC=CC=C1C(=O)O', 'smiles')[0] +``` + +**By InChI**: +```python +compound = pcp.get_compounds('InChI=1S/C9H8O4/...', 'inchi')[0] +``` + +**By Molecular Formula**: +```python +compounds = pcp.get_compounds('C9H8O4', 'formula') +# Returns all compounds matching this formula +``` + +### 2. Property Retrieval + +Retrieve molecular properties for compounds using either high-level or low-level approaches: + +**Using PubChemPy (Recommended)**: +```python +import pubchempy as pcp + +# Get compound object with all properties +compound = pcp.get_compounds('caffeine', 'name')[0] + +# Access individual properties +molecular_formula = compound.molecular_formula +molecular_weight = compound.molecular_weight +iupac_name = compound.iupac_name +smiles = compound.canonical_smiles +inchi = compound.inchi +xlogp = compound.xlogp # Partition coefficient +tpsa = compound.tpsa # Topological polar surface area +``` + +**Get Specific Properties**: +```python +# Request only specific properties +properties = pcp.get_properties( + ['MolecularFormula', 'MolecularWeight', 'CanonicalSMILES', 'XLogP'], + 'aspirin', + 'name' +) +# Returns list of dictionaries +``` + +**Batch Property Retrieval**: +```python +import pandas as pd + +compound_names = ['aspirin', 'ibuprofen', 'paracetamol'] +all_properties = [] + +for name in compound_names: + props = pcp.get_properties( + ['MolecularFormula', 'MolecularWeight', 'XLogP'], + name, + 'name' + ) + all_properties.extend(props) + +df = pd.DataFrame(all_properties) +``` + +**Available Properties**: MolecularFormula, MolecularWeight, CanonicalSMILES, IsomericSMILES, InChI, InChIKey, IUPACName, XLogP, TPSA, HBondDonorCount, HBondAcceptorCount, RotatableBondCount, Complexity, Charge, and many more (see `references/api_reference.md` for complete list). + +### 3. Similarity Search + +Find structurally similar compounds using Tanimoto similarity: + +```python +import pubchempy as pcp + +# Start with a query compound +query_compound = pcp.get_compounds('gefitinib', 'name')[0] +query_smiles = query_compound.canonical_smiles + +# Perform similarity search +similar_compounds = pcp.get_compounds( + query_smiles, + 'smiles', + searchtype='similarity', + Threshold=85, # Similarity threshold (0-100) + MaxRecords=50 +) + +# Process results +for compound in similar_compounds[:10]: + print(f"CID {compound.cid}: {compound.iupac_name}") + print(f" MW: {compound.molecular_weight}") +``` + +**Note**: Similarity searches are asynchronous for large queries and may take 15-30 seconds to complete. PubChemPy handles the asynchronous pattern automatically. + +### 4. Substructure Search + +Find compounds containing a specific structural motif: + +```python +import pubchempy as pcp + +# Search for compounds containing pyridine ring +pyridine_smiles = 'c1ccncc1' + +matches = pcp.get_compounds( + pyridine_smiles, + 'smiles', + searchtype='substructure', + MaxRecords=100 +) + +print(f"Found {len(matches)} compounds containing pyridine") +``` + +**Common Substructures**: +- Benzene ring: `c1ccccc1` +- Pyridine: `c1ccncc1` +- Phenol: `c1ccc(O)cc1` +- Carboxylic acid: `C(=O)O` + +### 5. 
Format Conversion + +Convert between different chemical structure formats: + +```python +import pubchempy as pcp + +compound = pcp.get_compounds('aspirin', 'name')[0] + +# Convert to different formats +smiles = compound.canonical_smiles +inchi = compound.inchi +inchikey = compound.inchikey +cid = compound.cid + +# Download structure files +pcp.download('SDF', 'aspirin', 'name', 'aspirin.sdf', overwrite=True) +pcp.download('JSON', '2244', 'cid', 'aspirin.json', overwrite=True) +``` + +### 6. Structure Visualization + +Generate 2D structure images: + +```python +import pubchempy as pcp + +# Download compound structure as PNG +pcp.download('PNG', 'caffeine', 'name', 'caffeine.png', overwrite=True) + +# Using direct URL (via requests) +import requests + +cid = 2244 # Aspirin +url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/PNG?image_size=large" +response = requests.get(url) + +with open('structure.png', 'wb') as f: + f.write(response.content) +``` + +### 7. Synonym Retrieval + +Get all known names and synonyms for a compound: + +```python +import pubchempy as pcp + +synonyms_data = pcp.get_synonyms('aspirin', 'name') + +if synonyms_data: + cid = synonyms_data[0]['CID'] + synonyms = synonyms_data[0]['Synonym'] + + print(f"CID {cid} has {len(synonyms)} synonyms:") + for syn in synonyms[:10]: # First 10 + print(f" - {syn}") +``` + +### 8. Bioactivity Data Access + +Retrieve biological activity data from assays: + +```python +import requests +import json + +# Get bioassay summary for a compound +cid = 2244 # Aspirin +url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/assaysummary/JSON" + +response = requests.get(url) +if response.status_code == 200: + data = response.json() + # Process bioassay information + table = data.get('Table', {}) + rows = table.get('Row', []) + print(f"Found {len(rows)} bioassay records") +``` + +**For more complex bioactivity queries**, use the `scripts/bioactivity_query.py` helper script which provides: +- Bioassay summaries with activity outcome filtering +- Assay target identification +- Search for compounds by biological target +- Active compound lists for specific assays + +### 9. 
Comprehensive Compound Annotations + +Access detailed compound information through PUG-View: + +```python +import requests + +cid = 2244 +url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON" + +response = requests.get(url) +if response.status_code == 200: + annotations = response.json() + # Contains extensive data including: + # - Chemical and Physical Properties + # - Drug and Medication Information + # - Pharmacology and Biochemistry + # - Safety and Hazards + # - Toxicity + # - Literature references + # - Patents +``` + +**Get Specific Section**: +```python +# Get only drug information +url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON?heading=Drug and Medication Information" +``` + +## Installation Requirements + +Install PubChemPy for Python-based access: + +```bash +pip install pubchempy +``` + +For direct API access and bioactivity queries: + +```bash +pip install requests +``` + +Optional for data analysis: + +```bash +pip install pandas +``` + +## Helper Scripts + +This skill includes Python scripts for common PubChem tasks: + +### scripts/compound_search.py + +Provides utility functions for searching and retrieving compound information: + +**Key Functions**: +- `search_by_name(name, max_results=10)`: Search compounds by name +- `search_by_smiles(smiles)`: Search by SMILES string +- `get_compound_by_cid(cid)`: Retrieve compound by CID +- `get_compound_properties(identifier, namespace, properties)`: Get specific properties +- `similarity_search(smiles, threshold, max_records)`: Perform similarity search +- `substructure_search(smiles, max_records)`: Perform substructure search +- `get_synonyms(identifier, namespace)`: Get all synonyms +- `batch_search(identifiers, namespace, properties)`: Batch search multiple compounds +- `download_structure(identifier, namespace, format, filename)`: Download structures +- `print_compound_info(compound)`: Print formatted compound information + +**Usage**: +```python +from scripts.compound_search import search_by_name, get_compound_properties + +# Search for a compound +compounds = search_by_name('ibuprofen') + +# Get specific properties +props = get_compound_properties('aspirin', 'name', ['MolecularWeight', 'XLogP']) +``` + +### scripts/bioactivity_query.py + +Provides functions for retrieving biological activity data: + +**Key Functions**: +- `get_bioassay_summary(cid)`: Get bioassay summary for compound +- `get_compound_bioactivities(cid, activity_outcome)`: Get filtered bioactivities +- `get_assay_description(aid)`: Get detailed assay information +- `get_assay_targets(aid)`: Get biological targets for assay +- `search_assays_by_target(target_name, max_results)`: Find assays by target +- `get_active_compounds_in_assay(aid, max_results)`: Get active compounds +- `get_compound_annotations(cid, section)`: Get PUG-View annotations +- `summarize_bioactivities(cid)`: Generate bioactivity summary statistics +- `find_compounds_by_bioactivity(target, threshold, max_compounds)`: Find compounds by target + +**Usage**: +```python +from scripts.bioactivity_query import get_bioassay_summary, summarize_bioactivities + +# Get bioactivity summary +summary = summarize_bioactivities(2244) # Aspirin +print(f"Total assays: {summary['total_assays']}") +print(f"Active: {summary['active']}, Inactive: {summary['inactive']}") +``` + +## API Rate Limits and Best Practices + +**Rate Limits**: +- Maximum 5 requests per second +- Maximum 400 requests per minute +- Maximum 300 seconds running time per minute + +**Best 
Practices**: +1. **Use CIDs for repeated queries**: CIDs are more efficient than names or structures +2. **Cache results locally**: Store frequently accessed data +3. **Batch requests**: Combine multiple queries when possible +4. **Implement delays**: Add 0.2-0.3 second delays between requests +5. **Handle errors gracefully**: Check for HTTP errors and missing data +6. **Use PubChemPy**: Higher-level abstraction handles many edge cases +7. **Leverage asynchronous pattern**: For large similarity/substructure searches +8. **Specify MaxRecords**: Limit results to avoid timeouts + +**Error Handling**: +```python +from pubchempy import BadRequestError, NotFoundError, TimeoutError + +try: + compound = pcp.get_compounds('query', 'name')[0] +except NotFoundError: + print("Compound not found") +except BadRequestError: + print("Invalid request format") +except TimeoutError: + print("Request timed out - try reducing scope") +except IndexError: + print("No results returned") +``` + +## Common Workflows + +### Workflow 1: Chemical Identifier Conversion Pipeline + +Convert between different chemical identifiers: + +```python +import pubchempy as pcp + +# Start with any identifier type +compound = pcp.get_compounds('caffeine', 'name')[0] + +# Extract all identifier formats +identifiers = { + 'CID': compound.cid, + 'Name': compound.iupac_name, + 'SMILES': compound.canonical_smiles, + 'InChI': compound.inchi, + 'InChIKey': compound.inchikey, + 'Formula': compound.molecular_formula +} +``` + +### Workflow 2: Drug-Like Property Screening + +Screen compounds using Lipinski's Rule of Five: + +```python +import pubchempy as pcp + +def check_drug_likeness(compound_name): + compound = pcp.get_compounds(compound_name, 'name')[0] + + # Lipinski's Rule of Five + rules = { + 'MW <= 500': compound.molecular_weight <= 500, + 'LogP <= 5': compound.xlogp <= 5 if compound.xlogp else None, + 'HBD <= 5': compound.h_bond_donor_count <= 5, + 'HBA <= 10': compound.h_bond_acceptor_count <= 10 + } + + violations = sum(1 for v in rules.values() if v is False) + return rules, violations + +rules, violations = check_drug_likeness('aspirin') +print(f"Lipinski violations: {violations}") +``` + +### Workflow 3: Finding Similar Drug Candidates + +Identify structurally similar compounds to a known drug: + +```python +import pubchempy as pcp + +# Start with known drug +reference_drug = pcp.get_compounds('imatinib', 'name')[0] +reference_smiles = reference_drug.canonical_smiles + +# Find similar compounds +similar = pcp.get_compounds( + reference_smiles, + 'smiles', + searchtype='similarity', + Threshold=85, + MaxRecords=20 +) + +# Filter by drug-like properties +candidates = [] +for comp in similar: + if comp.molecular_weight and 200 <= comp.molecular_weight <= 600: + if comp.xlogp and -1 <= comp.xlogp <= 5: + candidates.append(comp) + +print(f"Found {len(candidates)} drug-like candidates") +``` + +### Workflow 4: Batch Compound Property Comparison + +Compare properties across multiple compounds: + +```python +import pubchempy as pcp +import pandas as pd + +compound_list = ['aspirin', 'ibuprofen', 'naproxen', 'celecoxib'] + +properties_list = [] +for name in compound_list: + try: + compound = pcp.get_compounds(name, 'name')[0] + properties_list.append({ + 'Name': name, + 'CID': compound.cid, + 'Formula': compound.molecular_formula, + 'MW': compound.molecular_weight, + 'LogP': compound.xlogp, + 'TPSA': compound.tpsa, + 'HBD': compound.h_bond_donor_count, + 'HBA': compound.h_bond_acceptor_count + }) + except Exception as e: + print(f"Error 
processing {name}: {e}") + +df = pd.DataFrame(properties_list) +print(df.to_string(index=False)) +``` + +### Workflow 5: Substructure-Based Virtual Screening + +Screen for compounds containing specific pharmacophores: + +```python +import pubchempy as pcp + +# Define pharmacophore (e.g., sulfonamide group) +pharmacophore_smiles = 'S(=O)(=O)N' + +# Search for compounds containing this substructure +hits = pcp.get_compounds( + pharmacophore_smiles, + 'smiles', + searchtype='substructure', + MaxRecords=100 +) + +# Further filter by properties +filtered_hits = [ + comp for comp in hits + if comp.molecular_weight and comp.molecular_weight < 500 +] + +print(f"Found {len(filtered_hits)} compounds with desired substructure") +``` + +## Reference Documentation + +For detailed API documentation, including complete property lists, URL patterns, advanced query options, and more examples, consult `references/api_reference.md`. This comprehensive reference includes: + +- Complete PUG-REST API endpoint documentation +- Full list of available molecular properties +- Asynchronous request handling patterns +- PubChemPy API reference +- PUG-View API for annotations +- Common workflows and use cases +- Links to official PubChem documentation + +## Troubleshooting + +**Compound Not Found**: +- Try alternative names or synonyms +- Use CID if known +- Check spelling and chemical name format + +**Timeout Errors**: +- Reduce MaxRecords parameter +- Add delays between requests +- Use CIDs instead of names for faster queries + +**Empty Property Values**: +- Not all properties are available for all compounds +- Check if property exists before accessing: `if compound.xlogp:` +- Some properties only available for certain compound types + +**Rate Limit Exceeded**: +- Implement delays (0.2-0.3 seconds) between requests +- Use batch operations where possible +- Consider caching results locally + +**Similarity/Substructure Search Hangs**: +- These are asynchronous operations that may take 15-30 seconds +- PubChemPy handles polling automatically +- Reduce MaxRecords if timing out + +## Additional Resources + +- PubChem Home: https://pubchem.ncbi.nlm.nih.gov/ +- PUG-REST Documentation: https://pubchem.ncbi.nlm.nih.gov/docs/pug-rest +- PUG-REST Tutorial: https://pubchem.ncbi.nlm.nih.gov/docs/pug-rest-tutorial +- PubChemPy Documentation: https://pubchempy.readthedocs.io/ +- PubChemPy GitHub: https://github.com/mcs07/PubChemPy diff --git a/scientific-packages/pubchem-database/references/api_reference.md b/scientific-packages/pubchem-database/references/api_reference.md new file mode 100644 index 0000000..1653107 --- /dev/null +++ b/scientific-packages/pubchem-database/references/api_reference.md @@ -0,0 +1,440 @@ +# PubChem API Reference + +## Overview + +PubChem is the world's largest freely available chemical database maintained by the National Center for Biotechnology Information (NCBI). It contains over 110 million unique chemical structures and over 270 million bioactivities from more than 770 data sources. + +## Database Structure + +PubChem consists of three primary subdatabases: + +1. **Compound Database**: Unique validated chemical structures with computed properties +2. **Substance Database**: Deposited chemical substance records from data sources +3. 
**BioAssay Database**: Biological activity test results for chemical compounds + +## PubChem PUG-REST API + +### Base URL Structure + +``` +https://pubchem.ncbi.nlm.nih.gov/rest/pug/// +``` + +Components: +- ``: compound/cid, substance/sid, assay/aid, or search specifications +- ``: Optional operations like property, synonyms, classification, etc. +- ``: Format such as JSON, XML, CSV, PNG, SDF, etc. + +### Common Request Patterns + +#### 1. Retrieve by Identifier + +Get compound by CID (Compound ID): +``` +GET /rest/pug/compound/cid/{cid}/property/{properties}/JSON +``` + +Get compound by name: +``` +GET /rest/pug/compound/name/{name}/property/{properties}/JSON +``` + +Get compound by SMILES: +``` +GET /rest/pug/compound/smiles/{smiles}/property/{properties}/JSON +``` + +Get compound by InChI: +``` +GET /rest/pug/compound/inchi/{inchi}/property/{properties}/JSON +``` + +#### 2. Available Properties + +Common molecular properties that can be retrieved: +- `MolecularFormula` +- `MolecularWeight` +- `CanonicalSMILES` +- `IsomericSMILES` +- `InChI` +- `InChIKey` +- `IUPACName` +- `XLogP` +- `ExactMass` +- `MonoisotopicMass` +- `TPSA` (Topological Polar Surface Area) +- `Complexity` +- `Charge` +- `HBondDonorCount` +- `HBondAcceptorCount` +- `RotatableBondCount` +- `HeavyAtomCount` +- `IsotopeAtomCount` +- `AtomStereoCount` +- `BondStereoCount` +- `CovalentUnitCount` +- `Volume3D` +- `XStericQuadrupole3D` +- `YStericQuadrupole3D` +- `ZStericQuadrupole3D` +- `FeatureCount3D` + +To retrieve multiple properties, separate them with commas: +``` +/property/MolecularFormula,MolecularWeight,CanonicalSMILES/JSON +``` + +#### 3. Structure Search Operations + +**Similarity Search**: +``` +POST /rest/pug/compound/similarity/smiles/{smiles}/JSON +Parameters: Threshold (default 90%) +``` + +**Substructure Search**: +``` +POST /rest/pug/compound/substructure/smiles/{smiles}/cids/JSON +``` + +**Superstructure Search**: +``` +POST /rest/pug/compound/superstructure/smiles/{smiles}/cids/JSON +``` + +#### 4. Image Generation + +Get 2D structure image: +``` +GET /rest/pug/compound/cid/{cid}/PNG +Optional parameters: image_size=small|large +``` + +#### 5. Format Conversion + +Get compound as SDF (Structure-Data File): +``` +GET /rest/pug/compound/cid/{cid}/SDF +``` + +Get compound as MOL: +``` +GET /rest/pug/compound/cid/{cid}/record/SDF +``` + +#### 6. Synonym Retrieval + +Get all synonyms for a compound: +``` +GET /rest/pug/compound/cid/{cid}/synonyms/JSON +``` + +#### 7. Bioassay Data + +Get bioassay data for a compound: +``` +GET /rest/pug/compound/cid/{cid}/assaysummary/JSON +``` + +Get specific assay information: +``` +GET /rest/pug/assay/aid/{aid}/description/JSON +``` + +### Asynchronous Requests + +For large queries (similarity/substructure searches), PUG-REST uses an asynchronous pattern: + +1. Submit the query (returns ListKey) +2. Check status using the ListKey +3. 
Retrieve results when ready + +Example workflow: +```python +# Step 1: Submit similarity search +response = requests.post( + "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/similarity/smiles/{smiles}/cids/JSON", + data={"Threshold": 90} +) +listkey = response.json()["Waiting"]["ListKey"] + +# Step 2: Check status +status_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/listkey/{listkey}/cids/JSON" + +# Step 3: Poll until ready (with timeout) +# Step 4: Retrieve results from the same URL +``` + +### Usage Limits + +**Rate Limits**: +- Maximum 5 requests per second +- Maximum 400 requests per minute +- Maximum 300 seconds running time per minute + +**Best Practices**: +- Use batch requests when possible +- Implement exponential backoff for retries +- Cache results when appropriate +- Use asynchronous pattern for large queries + +## PubChemPy Python Library + +PubChemPy is a Python wrapper that simplifies PUG-REST API access. + +### Installation + +```bash +pip install pubchempy +``` + +### Key Classes + +#### Compound Class + +Main class for representing chemical compounds: + +```python +import pubchempy as pcp + +# Get by CID +compound = pcp.Compound.from_cid(2244) + +# Access properties +compound.molecular_formula # 'C9H8O4' +compound.molecular_weight # 180.16 +compound.iupac_name # '2-acetyloxybenzoic acid' +compound.canonical_smiles # 'CC(=O)OC1=CC=CC=C1C(=O)O' +compound.isomeric_smiles # Same as canonical for non-stereoisomers +compound.inchi # InChI string +compound.inchikey # InChI Key +compound.xlogp # Partition coefficient +compound.tpsa # Topological polar surface area +``` + +#### Search Methods + +**By Name**: +```python +compounds = pcp.get_compounds('aspirin', 'name') +# Returns list of Compound objects +``` + +**By SMILES**: +```python +compound = pcp.get_compounds('CC(=O)OC1=CC=CC=C1C(=O)O', 'smiles')[0] +``` + +**By InChI**: +```python +compound = pcp.get_compounds('InChI=1S/C9H8O4/c1-6(10)13-8-5-3-2-4-7(8)9(11)12/h2-5H,1H3,(H,11,12)', 'inchi')[0] +``` + +**By Formula**: +```python +compounds = pcp.get_compounds('C9H8O4', 'formula') +# Returns all compounds with this formula +``` + +**Similarity Search**: +```python +results = pcp.get_compounds('CC(=O)OC1=CC=CC=C1C(=O)O', 'smiles', + searchtype='similarity', + Threshold=90) +``` + +**Substructure Search**: +```python +results = pcp.get_compounds('c1ccccc1', 'smiles', + searchtype='substructure') +# Returns all compounds containing benzene ring +``` + +#### Property Retrieval + +Get specific properties for multiple compounds: +```python +properties = pcp.get_properties( + ['MolecularFormula', 'MolecularWeight', 'CanonicalSMILES'], + 'aspirin', + 'name' +) +# Returns list of dictionaries +``` + +Get properties as pandas DataFrame: +```python +import pandas as pd +df = pd.DataFrame(properties) +``` + +#### Synonyms + +Get all synonyms for a compound: +```python +synonyms = pcp.get_synonyms('aspirin', 'name') +# Returns list of dictionaries with CID and synonym lists +``` + +#### Download Formats + +Download compound in various formats: +```python +# Get as SDF +sdf_data = pcp.download('SDF', 'aspirin', 'name', overwrite=True) + +# Get as JSON +json_data = pcp.download('JSON', '2244', 'cid') + +# Get as PNG image +pcp.download('PNG', '2244', 'cid', 'aspirin.png', overwrite=True) +``` + +### Error Handling + +```python +from pubchempy import BadRequestError, NotFoundError, TimeoutError + +try: + compound = pcp.get_compounds('nonexistent', 'name') +except NotFoundError: + print("Compound not found") +except 
BadRequestError: + print("Invalid request") +except TimeoutError: + print("Request timed out") +``` + +## PUG-View API + +PUG-View provides access to full textual annotations and specialized reports. + +### Key Endpoints + +Get compound annotations: +``` +GET /rest/pug_view/data/compound/{cid}/JSON +``` + +Get specific annotation sections: +``` +GET /rest/pug_view/data/compound/{cid}/JSON?heading={section_name} +``` + +Available sections include: +- Chemical and Physical Properties +- Drug and Medication Information +- Pharmacology and Biochemistry +- Safety and Hazards +- Toxicity +- Literature +- Patents +- Biomolecular Interactions and Pathways + +## Common Workflows + +### 1. Chemical Identifier Conversion + +Convert from name to SMILES to InChI: +```python +import pubchempy as pcp + +compound = pcp.get_compounds('caffeine', 'name')[0] +smiles = compound.canonical_smiles +inchi = compound.inchi +inchikey = compound.inchikey +cid = compound.cid +``` + +### 2. Batch Property Retrieval + +Get properties for multiple compounds: +```python +compound_names = ['aspirin', 'ibuprofen', 'paracetamol'] +properties = [] + +for name in compound_names: + props = pcp.get_properties( + ['MolecularFormula', 'MolecularWeight', 'XLogP'], + name, + 'name' + ) + properties.extend(props) + +import pandas as pd +df = pd.DataFrame(properties) +``` + +### 3. Finding Similar Compounds + +Find structurally similar compounds to a query: +```python +# Start with a known compound +query_compound = pcp.get_compounds('gefitinib', 'name')[0] +query_smiles = query_compound.canonical_smiles + +# Perform similarity search +similar = pcp.get_compounds( + query_smiles, + 'smiles', + searchtype='similarity', + Threshold=85 +) + +# Get properties for similar compounds +for compound in similar[:10]: # First 10 results + print(f"{compound.cid}: {compound.iupac_name}, MW: {compound.molecular_weight}") +``` + +### 4. Substructure Screening + +Find all compounds containing a specific substructure: +```python +# Search for compounds containing pyridine ring +pyridine_smiles = 'c1ccncc1' + +matches = pcp.get_compounds( + pyridine_smiles, + 'smiles', + searchtype='substructure', + MaxRecords=100 +) + +print(f"Found {len(matches)} compounds containing pyridine") +``` + +### 5. Bioactivity Data Retrieval + +```python +import requests + +cid = 2244 # Aspirin +url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/assaysummary/JSON" + +response = requests.get(url) +if response.status_code == 200: + bioassay_data = response.json() + # Process bioassay information +``` + +## Tips and Best Practices + +1. **Use CIDs for repeated queries**: CIDs are more efficient than names or structures +2. **Cache results**: Store frequently accessed data locally +3. **Batch requests**: Combine multiple queries when possible +4. **Handle rate limits**: Implement delays between requests +5. **Use appropriate search types**: Similarity for related compounds, substructure for motif finding +6. **Leverage PubChemPy**: Higher-level abstraction simplifies common tasks +7. **Handle missing data**: Not all properties are available for all compounds +8. **Use asynchronous pattern**: For large similarity/substructure searches +9. **Specify output format**: Choose JSON for programmatic access, SDF for cheminformatics tools +10. 
**Read documentation**: Full PUG-REST documentation available at https://pubchem.ncbi.nlm.nih.gov/docs/pug-rest + +## Additional Resources + +- PubChem Home: https://pubchem.ncbi.nlm.nih.gov/ +- PUG-REST Documentation: https://pubchem.ncbi.nlm.nih.gov/docs/pug-rest +- PUG-REST Tutorial: https://pubchem.ncbi.nlm.nih.gov/docs/pug-rest-tutorial +- PubChemPy Documentation: https://pubchempy.readthedocs.io/ +- PubChemPy GitHub: https://github.com/mcs07/PubChemPy +- IUPAC Tutorial: https://iupac.github.io/WFChemCookbook/datasources/pubchem_pugrest.html diff --git a/scientific-packages/pubchem-database/scripts/bioactivity_query.py b/scientific-packages/pubchem-database/scripts/bioactivity_query.py new file mode 100644 index 0000000..6fcc8d9 --- /dev/null +++ b/scientific-packages/pubchem-database/scripts/bioactivity_query.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +""" +PubChem Bioactivity Data Retrieval + +This script provides functions for retrieving biological activity data +from PubChem for compounds and assays. +""" + +import sys +import json +import time +from typing import Dict, List, Optional + +try: + import requests +except ImportError: + print("Error: requests is not installed. Install it with: pip install requests") + sys.exit(1) + + +BASE_URL = "https://pubchem.ncbi.nlm.nih.gov/rest/pug" +PUG_VIEW_URL = "https://pubchem.ncbi.nlm.nih.gov/rest/pug_view" + +# Rate limiting: 5 requests per second maximum +REQUEST_DELAY = 0.21 # seconds between requests + + +def rate_limited_request(url: str, method: str = 'GET', **kwargs) -> Optional[requests.Response]: + """ + Make a rate-limited request to PubChem API. + + Args: + url: Request URL + method: HTTP method ('GET' or 'POST') + **kwargs: Additional arguments for requests + + Returns: + Response object or None on error + """ + time.sleep(REQUEST_DELAY) + + try: + if method.upper() == 'GET': + response = requests.get(url, **kwargs) + else: + response = requests.post(url, **kwargs) + + response.raise_for_status() + return response + except requests.exceptions.RequestException as e: + print(f"Request error: {e}") + return None + + +def get_bioassay_summary(cid: int) -> Optional[Dict]: + """ + Get bioassay summary for a compound. + + Args: + cid: PubChem Compound ID + + Returns: + Dictionary containing bioassay summary data + """ + url = f"{BASE_URL}/compound/cid/{cid}/assaysummary/JSON" + response = rate_limited_request(url) + + if response and response.status_code == 200: + return response.json() + return None + + +def get_compound_bioactivities( + cid: int, + activity_outcome: Optional[str] = None +) -> List[Dict]: + """ + Get bioactivity data for a compound. + + Args: + cid: PubChem Compound ID + activity_outcome: Filter by activity ('active', 'inactive', 'inconclusive') + + Returns: + List of bioactivity records + """ + data = get_bioassay_summary(cid) + + if not data: + return [] + + activities = [] + table = data.get('Table', {}) + + for row in table.get('Row', []): + activity = {} + for i, cell in enumerate(row.get('Cell', [])): + column_name = table['Columns']['Column'][i] + activity[column_name] = cell + + if activity_outcome: + if activity.get('Activity Outcome', '').lower() == activity_outcome.lower(): + activities.append(activity) + else: + activities.append(activity) + + return activities + + +def get_assay_description(aid: int) -> Optional[Dict]: + """ + Get detailed description for a specific assay. 
+ + Args: + aid: PubChem Assay ID (AID) + + Returns: + Dictionary containing assay description + """ + url = f"{BASE_URL}/assay/aid/{aid}/description/JSON" + response = rate_limited_request(url) + + if response and response.status_code == 200: + return response.json() + return None + + +def get_assay_targets(aid: int) -> List[str]: + """ + Get biological targets for an assay. + + Args: + aid: PubChem Assay ID + + Returns: + List of target names + """ + description = get_assay_description(aid) + + if not description: + return [] + + targets = [] + assay_data = description.get('PC_AssayContainer', [{}])[0] + assay = assay_data.get('assay', {}) + + # Extract target information + descr = assay.get('descr', {}) + for target in descr.get('target', []): + mol_id = target.get('mol_id', '') + name = target.get('name', '') + if name: + targets.append(name) + elif mol_id: + targets.append(f"GI:{mol_id}") + + return targets + + +def search_assays_by_target( + target_name: str, + max_results: int = 100 +) -> List[int]: + """ + Search for assays targeting a specific protein or gene. + + Args: + target_name: Name of the target (e.g., 'EGFR', 'p53') + max_results: Maximum number of results + + Returns: + List of Assay IDs (AIDs) + """ + # Use PubChem's text search for assays + url = f"{BASE_URL}/assay/target/{target_name}/aids/JSON" + response = rate_limited_request(url) + + if response and response.status_code == 200: + data = response.json() + aids = data.get('IdentifierList', {}).get('AID', []) + return aids[:max_results] + return [] + + +def get_active_compounds_in_assay(aid: int, max_results: int = 1000) -> List[int]: + """ + Get list of active compounds in an assay. + + Args: + aid: PubChem Assay ID + max_results: Maximum number of results + + Returns: + List of Compound IDs (CIDs) that showed activity + """ + url = f"{BASE_URL}/assay/aid/{aid}/cids/JSON?cids_type=active" + response = rate_limited_request(url) + + if response and response.status_code == 200: + data = response.json() + cids = data.get('IdentifierList', {}).get('CID', []) + return cids[:max_results] + return [] + + +def get_compound_annotations(cid: int, section: Optional[str] = None) -> Optional[Dict]: + """ + Get comprehensive compound annotations from PUG-View. + + Args: + cid: PubChem Compound ID + section: Specific section to retrieve (e.g., 'Pharmacology and Biochemistry') + + Returns: + Dictionary containing annotation data + """ + url = f"{PUG_VIEW_URL}/data/compound/{cid}/JSON" + + if section: + url += f"?heading={section}" + + response = rate_limited_request(url) + + if response and response.status_code == 200: + return response.json() + return None + + +def get_drug_information(cid: int) -> Optional[Dict]: + """ + Get drug and medication information for a compound. + + Args: + cid: PubChem Compound ID + + Returns: + Dictionary containing drug information + """ + return get_compound_annotations(cid, section="Drug and Medication Information") + + +def get_safety_hazards(cid: int) -> Optional[Dict]: + """ + Get safety and hazard information for a compound. + + Args: + cid: PubChem Compound ID + + Returns: + Dictionary containing safety information + """ + return get_compound_annotations(cid, section="Safety and Hazards") + + +def summarize_bioactivities(cid: int) -> Dict: + """ + Generate a summary of bioactivity data for a compound. 
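
    Counts are tallied from the 'Activity Outcome' column of the compound's
    assay summary table returned by get_compound_bioactivities.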
+ + Args: + cid: PubChem Compound ID + + Returns: + Dictionary with bioactivity summary statistics + """ + activities = get_compound_bioactivities(cid) + + summary = { + 'total_assays': len(activities), + 'active': 0, + 'inactive': 0, + 'inconclusive': 0, + 'unspecified': 0, + 'assay_types': {} + } + + for activity in activities: + outcome = activity.get('Activity Outcome', '').lower() + + if 'active' in outcome: + summary['active'] += 1 + elif 'inactive' in outcome: + summary['inactive'] += 1 + elif 'inconclusive' in outcome: + summary['inconclusive'] += 1 + else: + summary['unspecified'] += 1 + + return summary + + +def find_compounds_by_bioactivity( + target: str, + threshold: Optional[float] = None, + max_compounds: int = 100 +) -> List[Dict]: + """ + Find compounds with bioactivity against a specific target. + + Args: + target: Target name (e.g., 'EGFR') + threshold: Activity threshold (if applicable) + max_compounds: Maximum number of compounds to return + + Returns: + List of dictionaries with compound information and activity data + """ + # Step 1: Find assays for the target + assay_ids = search_assays_by_target(target, max_results=10) + + if not assay_ids: + print(f"No assays found for target: {target}") + return [] + + # Step 2: Get active compounds from these assays + compound_set = set() + compound_data = [] + + for aid in assay_ids[:5]: # Limit to first 5 assays + active_cids = get_active_compounds_in_assay(aid, max_results=max_compounds) + + for cid in active_cids: + if cid not in compound_set and len(compound_data) < max_compounds: + compound_set.add(cid) + compound_data.append({ + 'cid': cid, + 'aid': aid, + 'target': target + }) + + if len(compound_data) >= max_compounds: + break + + return compound_data + + +def main(): + """Example usage of bioactivity query functions.""" + + # Example 1: Get bioassay summary for aspirin (CID 2244) + print("Example 1: Getting bioassay summary for aspirin (CID 2244)...") + summary = summarize_bioactivities(2244) + print(json.dumps(summary, indent=2)) + + # Example 2: Get active bioactivities for a compound + print("\nExample 2: Getting active bioactivities for aspirin...") + activities = get_compound_bioactivities(2244, activity_outcome='active') + print(f"Found {len(activities)} active bioactivities") + if activities: + print(f"First activity: {activities[0].get('Assay Name', 'N/A')}") + + # Example 3: Get assay information + print("\nExample 3: Getting assay description...") + if activities: + aid = activities[0].get('AID', 0) + targets = get_assay_targets(aid) + print(f"Assay {aid} targets: {', '.join(targets) if targets else 'N/A'}") + + # Example 4: Search for compounds targeting EGFR + print("\nExample 4: Searching for EGFR inhibitors...") + egfr_compounds = find_compounds_by_bioactivity('EGFR', max_compounds=5) + print(f"Found {len(egfr_compounds)} compounds with EGFR activity") + for comp in egfr_compounds[:5]: + print(f" CID {comp['cid']} (from AID {comp['aid']})") + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/pubchem-database/scripts/compound_search.py b/scientific-packages/pubchem-database/scripts/compound_search.py new file mode 100644 index 0000000..b6b8a8b --- /dev/null +++ b/scientific-packages/pubchem-database/scripts/compound_search.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +""" +PubChem Compound Search Utility + +This script provides functions for searching and retrieving compound information +from PubChem using the PubChemPy library. 
+""" + +import sys +import json +from typing import List, Dict, Optional, Union + +try: + import pubchempy as pcp +except ImportError: + print("Error: pubchempy is not installed. Install it with: pip install pubchempy") + sys.exit(1) + + +def search_by_name(name: str, max_results: int = 10) -> List[pcp.Compound]: + """ + Search for compounds by name. + + Args: + name: Chemical name to search for + max_results: Maximum number of results to return + + Returns: + List of Compound objects + """ + try: + compounds = pcp.get_compounds(name, 'name') + return compounds[:max_results] + except Exception as e: + print(f"Error searching for '{name}': {e}") + return [] + + +def search_by_smiles(smiles: str) -> Optional[pcp.Compound]: + """ + Search for a compound by SMILES string. + + Args: + smiles: SMILES string + + Returns: + Compound object or None if not found + """ + try: + compounds = pcp.get_compounds(smiles, 'smiles') + return compounds[0] if compounds else None + except Exception as e: + print(f"Error searching for SMILES '{smiles}': {e}") + return None + + +def get_compound_by_cid(cid: int) -> Optional[pcp.Compound]: + """ + Retrieve a compound by its CID (Compound ID). + + Args: + cid: PubChem Compound ID + + Returns: + Compound object or None if not found + """ + try: + return pcp.Compound.from_cid(cid) + except Exception as e: + print(f"Error retrieving CID {cid}: {e}") + return None + + +def get_compound_properties( + identifier: Union[str, int], + namespace: str = 'name', + properties: Optional[List[str]] = None +) -> Dict: + """ + Get specific properties for a compound. + + Args: + identifier: Compound identifier (name, SMILES, CID, etc.) + namespace: Type of identifier ('name', 'smiles', 'cid', 'inchi', etc.) + properties: List of properties to retrieve. If None, returns common properties. + + Returns: + Dictionary of properties + """ + if properties is None: + properties = [ + 'MolecularFormula', + 'MolecularWeight', + 'CanonicalSMILES', + 'IUPACName', + 'XLogP', + 'TPSA', + 'HBondDonorCount', + 'HBondAcceptorCount' + ] + + try: + result = pcp.get_properties(properties, identifier, namespace) + return result[0] if result else {} + except Exception as e: + print(f"Error getting properties for '{identifier}': {e}") + return {} + + +def similarity_search( + smiles: str, + threshold: int = 90, + max_records: int = 10 +) -> List[pcp.Compound]: + """ + Perform similarity search for compounds similar to the query structure. + + Args: + smiles: Query SMILES string + threshold: Similarity threshold (0-100) + max_records: Maximum number of results + + Returns: + List of similar Compound objects + """ + try: + compounds = pcp.get_compounds( + smiles, + 'smiles', + searchtype='similarity', + Threshold=threshold, + MaxRecords=max_records + ) + return compounds + except Exception as e: + print(f"Error in similarity search: {e}") + return [] + + +def substructure_search( + smiles: str, + max_records: int = 100 +) -> List[pcp.Compound]: + """ + Perform substructure search for compounds containing the query structure. 
+ + Args: + smiles: Query SMILES string (substructure) + max_records: Maximum number of results + + Returns: + List of Compound objects containing the substructure + """ + try: + compounds = pcp.get_compounds( + smiles, + 'smiles', + searchtype='substructure', + MaxRecords=max_records + ) + return compounds + except Exception as e: + print(f"Error in substructure search: {e}") + return [] + + +def get_synonyms(identifier: Union[str, int], namespace: str = 'name') -> List[str]: + """ + Get all synonyms for a compound. + + Args: + identifier: Compound identifier + namespace: Type of identifier + + Returns: + List of synonym strings + """ + try: + results = pcp.get_synonyms(identifier, namespace) + if results: + return results[0].get('Synonym', []) + return [] + except Exception as e: + print(f"Error getting synonyms: {e}") + return [] + + +def batch_search( + identifiers: List[str], + namespace: str = 'name', + properties: Optional[List[str]] = None +) -> List[Dict]: + """ + Batch search for multiple compounds. + + Args: + identifiers: List of compound identifiers + namespace: Type of identifiers + properties: List of properties to retrieve + + Returns: + List of dictionaries containing properties for each compound + """ + results = [] + for identifier in identifiers: + props = get_compound_properties(identifier, namespace, properties) + if props: + props['query'] = identifier + results.append(props) + return results + + +def download_structure( + identifier: Union[str, int], + namespace: str = 'name', + format: str = 'SDF', + filename: Optional[str] = None +) -> Optional[str]: + """ + Download compound structure in specified format. + + Args: + identifier: Compound identifier + namespace: Type of identifier + format: Output format ('SDF', 'JSON', 'PNG', etc.) + filename: Output filename (if None, returns data as string) + + Returns: + Data string if filename is None, else None + """ + try: + if filename: + pcp.download(format, identifier, namespace, filename, overwrite=True) + return None + else: + return pcp.download(format, identifier, namespace) + except Exception as e: + print(f"Error downloading structure: {e}") + return None + + +def print_compound_info(compound: pcp.Compound) -> None: + """ + Print formatted compound information. 
+ + Args: + compound: PubChemPy Compound object + """ + print(f"\n{'='*60}") + print(f"Compound CID: {compound.cid}") + print(f"{'='*60}") + print(f"IUPAC Name: {compound.iupac_name or 'N/A'}") + print(f"Molecular Formula: {compound.molecular_formula or 'N/A'}") + print(f"Molecular Weight: {compound.molecular_weight or 'N/A'} g/mol") + print(f"Canonical SMILES: {compound.canonical_smiles or 'N/A'}") + print(f"InChI: {compound.inchi or 'N/A'}") + print(f"InChI Key: {compound.inchikey or 'N/A'}") + print(f"XLogP: {compound.xlogp or 'N/A'}") + print(f"TPSA: {compound.tpsa or 'N/A'} Ų") + print(f"H-Bond Donors: {compound.h_bond_donor_count or 'N/A'}") + print(f"H-Bond Acceptors: {compound.h_bond_acceptor_count or 'N/A'}") + print(f"{'='*60}\n") + + +def main(): + """Example usage of PubChem search functions.""" + + # Example 1: Search by name + print("Example 1: Searching for 'aspirin'...") + compounds = search_by_name('aspirin', max_results=1) + if compounds: + print_compound_info(compounds[0]) + + # Example 2: Get properties + print("\nExample 2: Getting properties for caffeine...") + props = get_compound_properties('caffeine', 'name') + print(json.dumps(props, indent=2)) + + # Example 3: Similarity search + print("\nExample 3: Finding compounds similar to benzene...") + benzene_smiles = 'c1ccccc1' + similar = similarity_search(benzene_smiles, threshold=95, max_records=5) + print(f"Found {len(similar)} similar compounds:") + for comp in similar: + print(f" CID {comp.cid}: {comp.iupac_name or 'N/A'}") + + # Example 4: Batch search + print("\nExample 4: Batch search for multiple compounds...") + names = ['aspirin', 'ibuprofen', 'paracetamol'] + results = batch_search(names, properties=['MolecularFormula', 'MolecularWeight']) + for result in results: + print(f" {result.get('query')}: {result.get('MolecularFormula')} " + f"({result.get('MolecularWeight')} g/mol)") + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/pydeseq2/SKILL.md b/scientific-packages/pydeseq2/SKILL.md new file mode 100644 index 0000000..2674a75 --- /dev/null +++ b/scientific-packages/pydeseq2/SKILL.md @@ -0,0 +1,567 @@ +--- +name: pydeseq2 +description: Toolkit for differential gene expression analysis using PyDESeq2, a Python implementation of the DESeq2 method for bulk RNA-seq data. Use when analyzing RNA-seq count data to identify differentially expressed genes between conditions, performing single-factor or multi-factor experimental designs with Wald tests, or when users request DESeq2 analysis in Python. Supports data loading from CSV/TSV/pickle/AnnData formats, complete statistical workflows, result visualization, and integration with pandas-based data science pipelines. +--- + +# PyDESeq2 + +## Overview + +PyDESeq2 is a Python implementation of the DESeq2 method for differential expression analysis (DEA) with bulk RNA-seq data. This skill provides comprehensive support for designing and executing PyDESeq2 workflows, from data loading through result interpretation. 
+ +**Key capabilities:** +- Single-factor and multi-factor experimental designs +- Statistical testing using Wald tests with multiple testing correction +- Optional apeGLM log-fold-change shrinkage +- Data preprocessing and quality control +- Result export and visualization +- Integration with pandas, AnnData, and the Python data science ecosystem + +## When to Use This Skill + +Invoke this skill when: +- Analyzing bulk RNA-seq count data for differential expression +- Comparing gene expression between experimental conditions (e.g., treated vs control) +- Performing multi-factor designs accounting for batch effects or covariates +- Converting R-based DESeq2 workflows to Python +- Integrating differential expression analysis into Python-based pipelines +- Users mention "DESeq2", "differential expression", "RNA-seq analysis", or "PyDESeq2" + +## Quick Start Workflow + +For users who want to perform a standard differential expression analysis: + +```python +import pandas as pd +from pydeseq2.dds import DeseqDataSet +from pydeseq2.ds import DeseqStats + +# 1. Load data +counts_df = pd.read_csv("counts.csv", index_col=0).T # Transpose to samples × genes +metadata = pd.read_csv("metadata.csv", index_col=0) + +# 2. Filter low-count genes +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 10] +counts_df = counts_df[genes_to_keep] + +# 3. Initialize and fit DESeq2 +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition", + refit_cooks=True +) +dds.deseq2() + +# 4. Perform statistical testing +ds = DeseqStats(dds, contrast=["condition", "treated", "control"]) +ds.summary() + +# 5. Access results +results = ds.results_df +significant = results[results.padj < 0.05] +print(f"Found {len(significant)} significant genes") +``` + +## Core Workflow Steps + +### Step 1: Data Preparation + +**Input requirements:** +- **Count matrix:** Samples × genes DataFrame with non-negative integer read counts +- **Metadata:** Samples × variables DataFrame with experimental factors + +**Common data loading patterns:** + +```python +# From CSV (typical format: genes × samples, needs transpose) +counts_df = pd.read_csv("counts.csv", index_col=0).T +metadata = pd.read_csv("metadata.csv", index_col=0) + +# From TSV +counts_df = pd.read_csv("counts.tsv", sep="\t", index_col=0).T + +# From AnnData +import anndata as ad +adata = ad.read_h5ad("data.h5ad") +counts_df = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names) +metadata = adata.obs +``` + +**Data filtering:** + +```python +# Remove low-count genes +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 10] +counts_df = counts_df[genes_to_keep] + +# Remove samples with missing metadata +samples_to_keep = ~metadata.condition.isna() +counts_df = counts_df.loc[samples_to_keep] +metadata = metadata.loc[samples_to_keep] +``` + +### Step 2: Design Specification + +The design formula specifies how gene expression is modeled. 
+ +**Single-factor designs:** +```python +design = "~condition" # Simple two-group comparison +``` + +**Multi-factor designs:** +```python +design = "~batch + condition" # Control for batch effects +design = "~age + condition" # Include continuous covariate +design = "~group + condition + group:condition" # Interaction effects +``` + +**Design formula guidelines:** +- Use Wilkinson formula notation (R-style) +- Put adjustment variables (e.g., batch) before the main variable of interest +- Ensure variables exist as columns in the metadata DataFrame +- Use appropriate data types (categorical for discrete variables) + +### Step 3: DESeq2 Fitting + +Initialize the DeseqDataSet and run the complete pipeline: + +```python +from pydeseq2.dds import DeseqDataSet + +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition", + refit_cooks=True, # Refit after removing outliers + n_cpus=1 # Parallel processing (adjust as needed) +) + +# Run the complete DESeq2 pipeline +dds.deseq2() +``` + +**What `deseq2()` does:** +1. Computes size factors (normalization) +2. Fits genewise dispersions +3. Fits dispersion trend curve +4. Computes dispersion priors +5. Fits MAP dispersions (shrinkage) +6. Fits log fold changes +7. Calculates Cook's distances (outlier detection) +8. Refits if outliers detected (optional) + +### Step 4: Statistical Testing + +Perform Wald tests to identify differentially expressed genes: + +```python +from pydeseq2.ds import DeseqStats + +ds = DeseqStats( + dds, + contrast=["condition", "treated", "control"], # Test treated vs control + alpha=0.05, # Significance threshold + cooks_filter=True, # Filter outliers + independent_filter=True # Filter low-power tests +) + +ds.summary() +``` + +**Contrast specification:** +- Format: `[variable, test_level, reference_level]` +- Example: `["condition", "treated", "control"]` tests treated vs control +- If `None`, uses the last coefficient in the design + +**Result DataFrame columns:** +- `baseMean`: Mean normalized count across samples +- `log2FoldChange`: Log2 fold change between conditions +- `lfcSE`: Standard error of LFC +- `stat`: Wald test statistic +- `pvalue`: Raw p-value +- `padj`: Adjusted p-value (FDR-corrected via Benjamini-Hochberg) + +### Step 5: Optional LFC Shrinkage + +Apply shrinkage to reduce noise in fold change estimates: + +```python +ds.lfc_shrink() # Applies apeGLM shrinkage +``` + +**When to use LFC shrinkage:** +- For visualization (volcano plots, heatmaps) +- For ranking genes by effect size +- When prioritizing genes for follow-up experiments + +**Important:** Shrinkage affects only the log2FoldChange values, not the statistical test results (p-values remain unchanged). Use shrunk values for visualization but report unshrunken p-values for significance. 
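A minimal sketch of keeping both versions, assuming `ds.summary()` has already been run: copy the unshrunken table before `lfc_shrink()` overwrites `results_df`.

```python
# Keep the unshrunken results before shrinkage overwrites results_df
unshrunk = ds.results_df.copy()

ds.lfc_shrink()
shrunk = ds.results_df

# Report significance from the unshrunken table; use shrunk LFCs for plots and ranking
significant = unshrunk[unshrunk.padj < 0.05]
plot_lfc = shrunk.loc[significant.index, "log2FoldChange"]
```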
+ +### Step 6: Result Export + +Save results and intermediate objects: + +```python +import pickle + +# Export results as CSV +ds.results_df.to_csv("deseq2_results.csv") + +# Save significant genes only +significant = ds.results_df[ds.results_df.padj < 0.05] +significant.to_csv("significant_genes.csv") + +# Save DeseqDataSet for later use +with open("dds_result.pkl", "wb") as f: + pickle.dump(dds.to_picklable_anndata(), f) +``` + +## Common Analysis Patterns + +### Two-Group Comparison + +Standard case-control comparison: + +```python +dds = DeseqDataSet(counts=counts_df, metadata=metadata, design="~condition") +dds.deseq2() + +ds = DeseqStats(dds, contrast=["condition", "treated", "control"]) +ds.summary() + +results = ds.results_df +significant = results[results.padj < 0.05] +``` + +### Multiple Comparisons + +Testing multiple treatment groups against control: + +```python +dds = DeseqDataSet(counts=counts_df, metadata=metadata, design="~condition") +dds.deseq2() + +treatments = ["treatment_A", "treatment_B", "treatment_C"] +all_results = {} + +for treatment in treatments: + ds = DeseqStats(dds, contrast=["condition", treatment, "control"]) + ds.summary() + all_results[treatment] = ds.results_df + + sig_count = len(ds.results_df[ds.results_df.padj < 0.05]) + print(f"{treatment}: {sig_count} significant genes") +``` + +### Accounting for Batch Effects + +Control for technical variation: + +```python +# Include batch in design +dds = DeseqDataSet(counts=counts_df, metadata=metadata, design="~batch + condition") +dds.deseq2() + +# Test condition while controlling for batch +ds = DeseqStats(dds, contrast=["condition", "treated", "control"]) +ds.summary() +``` + +### Continuous Covariates + +Include continuous variables like age or dosage: + +```python +# Ensure continuous variable is numeric +metadata["age"] = pd.to_numeric(metadata["age"]) + +dds = DeseqDataSet(counts=counts_df, metadata=metadata, design="~age + condition") +dds.deseq2() + +ds = DeseqStats(dds, contrast=["condition", "treated", "control"]) +ds.summary() +``` + +## Using the Analysis Script + +This skill includes a complete command-line script for standard analyses: + +```bash +# Basic usage +python scripts/run_deseq2_analysis.py \ + --counts counts.csv \ + --metadata metadata.csv \ + --design "~condition" \ + --contrast condition treated control \ + --output results/ + +# With additional options +python scripts/run_deseq2_analysis.py \ + --counts counts.csv \ + --metadata metadata.csv \ + --design "~batch + condition" \ + --contrast condition treated control \ + --output results/ \ + --min-counts 10 \ + --alpha 0.05 \ + --n-cpus 4 \ + --plots +``` + +**Script features:** +- Automatic data loading and validation +- Gene and sample filtering +- Complete DESeq2 pipeline execution +- Statistical testing with customizable parameters +- Result export (CSV, pickle) +- Optional visualization (volcano and MA plots) + +Refer users to `scripts/run_deseq2_analysis.py` when they need a standalone analysis tool or want to batch process multiple datasets. 
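One way to batch process several datasets is to call the script in a loop; the `datasets/<name>/counts.csv` and `metadata.csv` layout below is only an assumed example, and the design/contrast values are the ones used throughout this guide.

```python
import subprocess
from pathlib import Path

# Assumed layout: datasets/<name>/counts.csv and datasets/<name>/metadata.csv
for dataset_dir in sorted(Path("datasets").iterdir()):
    if not dataset_dir.is_dir():
        continue
    subprocess.run(
        [
            "python", "scripts/run_deseq2_analysis.py",
            "--counts", str(dataset_dir / "counts.csv"),
            "--metadata", str(dataset_dir / "metadata.csv"),
            "--design", "~condition",
            "--contrast", "condition", "treated", "control",
            "--output", str(Path("results") / dataset_dir.name),
        ],
        check=True,  # stop on the first failing dataset
    )
```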
+ +## Result Interpretation + +### Identifying Significant Genes + +```python +# Filter by adjusted p-value +significant = ds.results_df[ds.results_df.padj < 0.05] + +# Filter by both significance and effect size +sig_and_large = ds.results_df[ + (ds.results_df.padj < 0.05) & + (abs(ds.results_df.log2FoldChange) > 1) +] + +# Separate up- and down-regulated +upregulated = significant[significant.log2FoldChange > 0] +downregulated = significant[significant.log2FoldChange < 0] + +print(f"Upregulated: {len(upregulated)}") +print(f"Downregulated: {len(downregulated)}") +``` + +### Ranking and Sorting + +```python +# Sort by adjusted p-value +top_by_padj = ds.results_df.sort_values("padj").head(20) + +# Sort by absolute fold change (use shrunk values) +ds.lfc_shrink() +ds.results_df["abs_lfc"] = abs(ds.results_df.log2FoldChange) +top_by_lfc = ds.results_df.sort_values("abs_lfc", ascending=False).head(20) + +# Sort by a combined metric +ds.results_df["score"] = -np.log10(ds.results_df.padj) * abs(ds.results_df.log2FoldChange) +top_combined = ds.results_df.sort_values("score", ascending=False).head(20) +``` + +### Quality Metrics + +```python +# Check normalization (size factors should be close to 1) +print("Size factors:", dds.obsm["size_factors"]) + +# Examine dispersion estimates +import matplotlib.pyplot as plt +plt.hist(dds.varm["dispersions"], bins=50) +plt.xlabel("Dispersion") +plt.ylabel("Frequency") +plt.title("Dispersion Distribution") +plt.show() + +# Check p-value distribution (should be mostly flat with peak near 0) +plt.hist(ds.results_df.pvalue.dropna(), bins=50) +plt.xlabel("P-value") +plt.ylabel("Frequency") +plt.title("P-value Distribution") +plt.show() +``` + +## Visualization Guidelines + +### Volcano Plot + +Visualize significance vs effect size: + +```python +import matplotlib.pyplot as plt +import numpy as np + +results = ds.results_df.copy() +results["-log10(padj)"] = -np.log10(results.padj) + +plt.figure(figsize=(10, 6)) +significant = results.padj < 0.05 + +plt.scatter( + results.loc[~significant, "log2FoldChange"], + results.loc[~significant, "-log10(padj)"], + alpha=0.3, s=10, c='gray', label='Not significant' +) +plt.scatter( + results.loc[significant, "log2FoldChange"], + results.loc[significant, "-log10(padj)"], + alpha=0.6, s=10, c='red', label='padj < 0.05' +) + +plt.axhline(-np.log10(0.05), color='blue', linestyle='--', alpha=0.5) +plt.xlabel("Log2 Fold Change") +plt.ylabel("-Log10(Adjusted P-value)") +plt.title("Volcano Plot") +plt.legend() +plt.savefig("volcano_plot.png", dpi=300) +``` + +### MA Plot + +Show fold change vs mean expression: + +```python +plt.figure(figsize=(10, 6)) + +plt.scatter( + np.log10(results.loc[~significant, "baseMean"] + 1), + results.loc[~significant, "log2FoldChange"], + alpha=0.3, s=10, c='gray' +) +plt.scatter( + np.log10(results.loc[significant, "baseMean"] + 1), + results.loc[significant, "log2FoldChange"], + alpha=0.6, s=10, c='red' +) + +plt.axhline(0, color='blue', linestyle='--', alpha=0.5) +plt.xlabel("Log10(Base Mean + 1)") +plt.ylabel("Log2 Fold Change") +plt.title("MA Plot") +plt.savefig("ma_plot.png", dpi=300) +``` + +## Troubleshooting Common Issues + +### Data Format Problems + +**Issue:** "Index mismatch between counts and metadata" + +**Solution:** Ensure sample names match exactly +```python +print("Counts samples:", counts_df.index.tolist()) +print("Metadata samples:", metadata.index.tolist()) + +# Take intersection if needed +common = counts_df.index.intersection(metadata.index) +counts_df = counts_df.loc[common] 
+metadata = metadata.loc[common] +``` + +**Issue:** "All genes have zero counts" + +**Solution:** Check if data needs transposition +```python +print(f"Counts shape: {counts_df.shape}") +# If genes > samples, transpose is needed +if counts_df.shape[1] < counts_df.shape[0]: + counts_df = counts_df.T +``` + +### Design Matrix Issues + +**Issue:** "Design matrix is not full rank" + +**Cause:** Confounded variables (e.g., all treated samples in one batch) + +**Solution:** Remove confounded variable or add interaction term +```python +# Check confounding +print(pd.crosstab(metadata.condition, metadata.batch)) + +# Either simplify design or add interaction +design = "~condition" # Remove batch +# OR +design = "~condition + batch + condition:batch" # Model interaction +``` + +### No Significant Genes + +**Diagnostics:** +```python +# Check dispersion distribution +plt.hist(dds.varm["dispersions"], bins=50) +plt.show() + +# Check size factors +print(dds.obsm["size_factors"]) + +# Look at top genes by raw p-value +print(ds.results_df.nsmallest(20, "pvalue")) +``` + +**Possible causes:** +- Small effect sizes +- High biological variability +- Insufficient sample size +- Technical issues (batch effects, outliers) + +## Reference Documentation + +For comprehensive details beyond this workflow-oriented guide: + +- **API Reference** (`references/api_reference.md`): Complete documentation of PyDESeq2 classes, methods, and data structures. Use when needing detailed parameter information or understanding object attributes. + +- **Workflow Guide** (`references/workflow_guide.md`): In-depth guide covering complete analysis workflows, data loading patterns, multi-factor designs, troubleshooting, and best practices. Use when handling complex experimental designs or encountering issues. + +Load these references into context when users need: +- Detailed API documentation: `Read references/api_reference.md` +- Comprehensive workflow examples: `Read references/workflow_guide.md` +- Troubleshooting guidance: `Read references/workflow_guide.md` (see Troubleshooting section) + +## Key Reminders + +1. **Data orientation matters:** Count matrices typically load as genes × samples but need to be samples × genes. Always transpose with `.T` if needed. + +2. **Sample filtering:** Remove samples with missing metadata before analysis to avoid errors. + +3. **Gene filtering:** Filter low-count genes (e.g., < 10 total reads) to improve power and reduce computational time. + +4. **Design formula order:** Put adjustment variables before the variable of interest (e.g., `"~batch + condition"` not `"~condition + batch"`). + +5. **LFC shrinkage timing:** Apply shrinkage after statistical testing and only for visualization/ranking purposes. P-values remain based on unshrunken estimates. + +6. **Result interpretation:** Use `padj < 0.05` for significance, not raw p-values. The Benjamini-Hochberg procedure controls false discovery rate. + +7. **Contrast specification:** The format is `[variable, test_level, reference_level]` where test_level is compared against reference_level. + +8. **Save intermediate objects:** Use pickle to save DeseqDataSet objects for later use or additional analyses without re-running the expensive fitting step. 
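A short sketch of reminder 8, assuming `dds` has already been fit. Note that the reloaded object is an AnnData, not a DeseqDataSet, so it is best suited to inspecting stored results without refitting.

```python
import pickle

# Save the fitted dataset in a picklable AnnData form
with open("dds_result.pkl", "wb") as f:
    pickle.dump(dds.to_picklable_anndata(), f)

# Later: reload and inspect stored results without re-running deseq2()
with open("dds_result.pkl", "rb") as f:
    adata = pickle.load(f)

print(adata)             # counts in X, sample metadata in obs, gene metadata in var
print(list(adata.varm))  # gene-level results such as LFC estimates and dispersions
```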
+ +## Installation and Requirements + +PyDESeq2 can be installed via pip or conda: + +```bash +# Via pip +pip install pydeseq2 + +# Via conda +conda install -c bioconda pydeseq2 +``` + +**System requirements:** +- Python 3.10-3.11 +- pandas 1.4.3+ +- numpy 1.23.0+ +- scipy 1.11.0+ +- scikit-learn 1.1.1+ +- anndata 0.8.0+ + +**Optional for visualization:** +- matplotlib +- seaborn + +## Additional Resources + +- **Official Documentation:** https://pydeseq2.readthedocs.io +- **GitHub Repository:** https://github.com/owkin/PyDESeq2 +- **Publication:** Muzellec et al. (2023) Bioinformatics, DOI: 10.1093/bioinformatics/btad547 +- **Original DESeq2 (R):** Love et al. (2014) Genome Biology, DOI: 10.1186/s13059-014-0550-8 diff --git a/scientific-packages/pydeseq2/references/api_reference.md b/scientific-packages/pydeseq2/references/api_reference.md new file mode 100644 index 0000000..60d1aba --- /dev/null +++ b/scientific-packages/pydeseq2/references/api_reference.md @@ -0,0 +1,228 @@ +# PyDESeq2 API Reference + +This document provides comprehensive API reference for PyDESeq2 classes, methods, and utilities. + +## Core Classes + +### DeseqDataSet + +The main class for differential expression analysis that handles data processing from normalization through log-fold change fitting. + +**Purpose:** Implements dispersion and log fold-change (LFC) estimation for RNA-seq count data. + +**Initialization Parameters:** +- `counts`: pandas DataFrame of shape (samples × genes) containing non-negative integer read counts +- `metadata`: pandas DataFrame of shape (samples × variables) with sample annotations +- `design`: str, Wilkinson formula specifying the statistical model (e.g., "~condition", "~group + condition") +- `refit_cooks`: bool, whether to refit parameters after removing Cook's distance outliers (default: True) +- `n_cpus`: int, number of CPUs to use for parallel processing (optional) +- `quiet`: bool, suppress progress messages (default: False) + +**Key Methods:** + +#### `deseq2()` +Run the complete DESeq2 pipeline for normalization and dispersion/LFC fitting. + +**Steps performed:** +1. Compute normalization factors (size factors) +2. Fit genewise dispersions +3. Fit dispersion trend curve +4. Calculate dispersion priors +5. Fit MAP (maximum a posteriori) dispersions +6. Fit log fold changes +7. Calculate Cook's distances for outlier detection +8. Optionally refit if `refit_cooks=True` + +**Returns:** None (modifies object in-place) + +#### `to_picklable_anndata()` +Convert the DeseqDataSet to an AnnData object that can be saved with pickle. + +**Returns:** AnnData object with: +- `X`: count data matrix +- `obs`: sample-level metadata (1D) +- `var`: gene-level metadata (1D) +- `varm`: gene-level multi-dimensional data (e.g., LFC estimates) + +**Usage:** +```python +import pickle +with open("result_adata.pkl", "wb") as f: + pickle.dump(dds.to_picklable_anndata(), f) +``` + +**Attributes (after running deseq2()):** +- `layers`: dict containing various matrices (normalized counts, etc.) +- `varm`: dict containing gene-level results (log fold changes, dispersions, etc.) +- `obsm`: dict containing sample-level information +- `uns`: dict containing global parameters + +--- + +### DeseqStats + +Class for performing statistical tests and computing p-values for differential expression. + +**Purpose:** Facilitates PyDESeq2 statistical tests using Wald tests and optional LFC shrinkage. 
+ +**Initialization Parameters:** +- `dds`: DeseqDataSet object that has been processed with `deseq2()` +- `contrast`: list or None, specifies the contrast for testing + - Format: `[variable, test_level, reference_level]` + - Example: `["condition", "treated", "control"]` tests treated vs control + - If None, uses the last coefficient in the design formula +- `alpha`: float, significance threshold for independent filtering (default: 0.05) +- `cooks_filter`: bool, whether to filter outliers based on Cook's distance (default: True) +- `independent_filter`: bool, whether to perform independent filtering (default: True) +- `n_cpus`: int, number of CPUs for parallel processing (optional) +- `quiet`: bool, suppress progress messages (default: False) + +**Key Methods:** + +#### `summary()` +Run Wald tests and compute p-values and adjusted p-values. + +**Steps performed:** +1. Run Wald statistical tests for specified contrast +2. Optional Cook's distance filtering +3. Optional independent filtering to remove low-power tests +4. Multiple testing correction (Benjamini-Hochberg procedure) + +**Returns:** None (results stored in `results_df` attribute) + +**Result DataFrame columns:** +- `baseMean`: mean normalized count across all samples +- `log2FoldChange`: log2 fold change between conditions +- `lfcSE`: standard error of the log2 fold change +- `stat`: Wald test statistic +- `pvalue`: raw p-value +- `padj`: adjusted p-value (FDR-corrected) + +#### `lfc_shrink(coeff=None)` +Apply shrinkage to log fold changes using the apeGLM method. + +**Purpose:** Reduces noise in LFC estimates for better visualization and ranking, especially for genes with low counts or high variability. + +**Parameters:** +- `coeff`: str or None, coefficient name to shrink (if None, uses the coefficient from the contrast) + +**Important:** Shrinkage is applied only for visualization/ranking purposes. The statistical test results (p-values, adjusted p-values) remain unchanged. + +**Returns:** None (updates `results_df` with shrunk LFCs) + +**Attributes:** +- `results_df`: pandas DataFrame containing test results (available after `summary()`) + +--- + +## Utility Functions + +### `pydeseq2.utils.load_example_data(modality="single-factor")` + +Load synthetic example datasets for testing and tutorials. + +**Parameters:** +- `modality`: str, either "single-factor" or "multi-factor" + +**Returns:** tuple of (counts_df, metadata_df) +- `counts_df`: pandas DataFrame with synthetic count data +- `metadata_df`: pandas DataFrame with sample annotations + +--- + +## Preprocessing Module + +The `pydeseq2.preprocessing` module provides utilities for data preparation. + +**Common operations:** +- Gene filtering based on minimum read counts +- Sample filtering based on metadata criteria +- Data transformation and normalization + +--- + +## Inference Classes + +### Inference +Abstract base class defining the interface for DESeq2-related inference methods. + +### DefaultInference +Default implementation of inference methods using scipy, sklearn, and numpy. 
+ +**Purpose:** Provides the mathematical implementations for: +- GLM (Generalized Linear Model) fitting +- Dispersion estimation +- Trend curve fitting +- Statistical testing + +--- + +## Data Structure Requirements + +### Count Matrix +- **Shape:** (samples × genes) +- **Type:** pandas DataFrame +- **Values:** Non-negative integers (raw read counts) +- **Index:** Sample identifiers (must match metadata index) +- **Columns:** Gene identifiers + +### Metadata +- **Shape:** (samples × variables) +- **Type:** pandas DataFrame +- **Index:** Sample identifiers (must match count matrix index) +- **Columns:** Experimental factors (e.g., "condition", "batch", "group") +- **Values:** Categorical or continuous variables used in the design formula + +### Important Notes +- Sample order must match between counts and metadata +- Missing values in metadata should be handled before analysis +- Gene names should be unique +- Count files often need transposition: `counts_df = counts_df.T` + +--- + +## Common Workflow Pattern + +```python +from pydeseq2.dds import DeseqDataSet +from pydeseq2.ds import DeseqStats + +# 1. Initialize dataset +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition", + refit_cooks=True +) + +# 2. Fit dispersions and LFCs +dds.deseq2() + +# 3. Perform statistical testing +ds = DeseqStats( + dds, + contrast=["condition", "treated", "control"], + alpha=0.05 +) +ds.summary() + +# 4. Optional: Shrink LFCs for visualization +ds.lfc_shrink() + +# 5. Access results +results = ds.results_df +``` + +--- + +## Version Compatibility + +PyDESeq2 aims to match the default settings of DESeq2 v1.34.0. Some differences may exist as it is a from-scratch reimplementation in Python. + +**Tested with:** +- Python 3.10-3.11 +- anndata 0.8.0+ +- numpy 1.23.0+ +- pandas 1.4.3+ +- scikit-learn 1.1.1+ +- scipy 1.11.0+ diff --git a/scientific-packages/pydeseq2/references/workflow_guide.md b/scientific-packages/pydeseq2/references/workflow_guide.md new file mode 100644 index 0000000..128d625 --- /dev/null +++ b/scientific-packages/pydeseq2/references/workflow_guide.md @@ -0,0 +1,582 @@ +# PyDESeq2 Workflow Guide + +This document provides detailed step-by-step workflows for common PyDESeq2 analysis patterns. + +## Table of Contents +1. [Complete Differential Expression Analysis](#complete-differential-expression-analysis) +2. [Data Loading and Preparation](#data-loading-and-preparation) +3. [Single-Factor Analysis](#single-factor-analysis) +4. [Multi-Factor Analysis](#multi-factor-analysis) +5. [Result Export and Visualization](#result-export-and-visualization) +6. [Common Patterns and Best Practices](#common-patterns-and-best-practices) +7. 
[Troubleshooting](#troubleshooting) + +--- + +## Complete Differential Expression Analysis + +### Overview +A standard PyDESeq2 analysis consists of 12 main steps across two phases: + +**Phase 1: Read Counts Modeling (Steps 1-7)** +- Normalization and dispersion estimation +- Log fold-change fitting +- Outlier detection + +**Phase 2: Statistical Analysis (Steps 8-12)** +- Wald testing +- Multiple testing correction +- Optional LFC shrinkage + +### Full Workflow Code + +```python +import pandas as pd +from pydeseq2.dds import DeseqDataSet +from pydeseq2.ds import DeseqStats + +# Load data +counts_df = pd.read_csv("counts.csv", index_col=0).T # Transpose if needed +metadata = pd.read_csv("metadata.csv", index_col=0) + +# Filter low-count genes +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 10] +counts_df = counts_df[genes_to_keep] + +# Remove samples with missing metadata +samples_to_keep = ~metadata.condition.isna() +counts_df = counts_df.loc[samples_to_keep] +metadata = metadata.loc[samples_to_keep] + +# Initialize DeseqDataSet +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition", + refit_cooks=True +) + +# Run normalization and fitting +dds.deseq2() + +# Perform statistical testing +ds = DeseqStats( + dds, + contrast=["condition", "treated", "control"], + alpha=0.05, + cooks_filter=True, + independent_filter=True +) +ds.summary() + +# Optional: Apply LFC shrinkage for visualization +ds.lfc_shrink() + +# Access results +results = ds.results_df +print(results.head()) +``` + +--- + +## Data Loading and Preparation + +### Loading CSV Files + +Count data typically comes in genes × samples format but needs to be transposed: + +```python +import pandas as pd + +# Load count matrix (genes × samples) +counts_df = pd.read_csv("counts.csv", index_col=0) + +# Transpose to samples × genes +counts_df = counts_df.T + +# Load metadata (already in samples × variables format) +metadata = pd.read_csv("metadata.csv", index_col=0) +``` + +### Loading from Other Formats + +**From TSV:** +```python +counts_df = pd.read_csv("counts.tsv", sep="\t", index_col=0).T +metadata = pd.read_csv("metadata.tsv", sep="\t", index_col=0) +``` + +**From saved pickle:** +```python +import pickle + +with open("counts.pkl", "rb") as f: + counts_df = pickle.load(f) + +with open("metadata.pkl", "rb") as f: + metadata = pickle.load(f) +``` + +**From AnnData:** +```python +import anndata as ad + +adata = ad.read_h5ad("data.h5ad") +counts_df = pd.DataFrame( + adata.X, + index=adata.obs_names, + columns=adata.var_names +) +metadata = adata.obs +``` + +### Data Filtering + +**Filter genes with low counts:** +```python +# Remove genes with fewer than 10 total reads +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 10] +counts_df = counts_df[genes_to_keep] +``` + +**Filter samples with missing metadata:** +```python +# Remove samples where 'condition' column is NA +samples_to_keep = ~metadata.condition.isna() +counts_df = counts_df.loc[samples_to_keep] +metadata = metadata.loc[samples_to_keep] +``` + +**Filter by multiple criteria:** +```python +# Keep only samples that meet all criteria +mask = ( + ~metadata.condition.isna() & + (metadata.batch.isin(["batch1", "batch2"])) & + (metadata.age >= 18) +) +counts_df = counts_df.loc[mask] +metadata = metadata.loc[mask] +``` + +### Data Validation + +**Check data structure:** +```python +print(f"Counts shape: {counts_df.shape}") # Should be (samples, genes) +print(f"Metadata shape: {metadata.shape}") # Should be (samples, variables) 
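# Optional extra check: gene identifiers are expected to be unique
assert counts_df.columns.is_unique, "Gene names must be unique"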
+print(f"Indices match: {all(counts_df.index == metadata.index)}") + +# Check for negative values +assert (counts_df >= 0).all().all(), "Counts must be non-negative" + +# Check for non-integer values +assert counts_df.applymap(lambda x: x == int(x)).all().all(), "Counts must be integers" +``` + +--- + +## Single-Factor Analysis + +### Simple Two-Group Comparison + +Compare treated vs control samples: + +```python +from pydeseq2.dds import DeseqDataSet +from pydeseq2.ds import DeseqStats + +# Design: model expression as a function of condition +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition" +) + +dds.deseq2() + +# Test treated vs control +ds = DeseqStats( + dds, + contrast=["condition", "treated", "control"] +) +ds.summary() + +# Results +results = ds.results_df +significant = results[results.padj < 0.05] +print(f"Found {len(significant)} significant genes") +``` + +### Multiple Pairwise Comparisons + +When comparing multiple groups: + +```python +# Test each treatment vs control +treatments = ["treated_A", "treated_B", "treated_C"] +all_results = {} + +for treatment in treatments: + ds = DeseqStats( + dds, + contrast=["condition", treatment, "control"] + ) + ds.summary() + all_results[treatment] = ds.results_df + +# Compare results across treatments +for name, results in all_results.items(): + sig = results[results.padj < 0.05] + print(f"{name}: {len(sig)} significant genes") +``` + +--- + +## Multi-Factor Analysis + +### Two-Factor Design + +Account for batch effects while testing condition: + +```python +# Design includes both batch and condition +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~batch + condition" +) + +dds.deseq2() + +# Test condition effect while controlling for batch +ds = DeseqStats( + dds, + contrast=["condition", "treated", "control"] +) +ds.summary() +``` + +### Interaction Effects + +Test whether treatment effect differs between groups: + +```python +# Design includes interaction term +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~group + condition + group:condition" +) + +dds.deseq2() + +# Test the interaction term +ds = DeseqStats(dds, contrast=["group:condition", ...]) +ds.summary() +``` + +### Continuous Covariates + +Include continuous variables like age: + +```python +# Ensure age is numeric in metadata +metadata["age"] = pd.to_numeric(metadata["age"]) + +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~age + condition" +) + +dds.deseq2() +``` + +--- + +## Result Export and Visualization + +### Saving Results + +**Export as CSV:** +```python +# Save statistical results +ds.results_df.to_csv("deseq2_results.csv") + +# Save significant genes only +significant = ds.results_df[ds.results_df.padj < 0.05] +significant.to_csv("significant_genes.csv") + +# Save with sorted results +sorted_results = ds.results_df.sort_values("padj") +sorted_results.to_csv("sorted_results.csv") +``` + +**Save DeseqDataSet:** +```python +import pickle + +# Save as AnnData for later use +with open("dds_result.pkl", "wb") as f: + pickle.dump(dds.to_picklable_anndata(), f) +``` + +**Load saved results:** +```python +# Load results +results = pd.read_csv("deseq2_results.csv", index_col=0) + +# Load AnnData +with open("dds_result.pkl", "rb") as f: + adata = pickle.load(f) +``` + +### Basic Visualization + +**Volcano plot:** +```python +import matplotlib.pyplot as plt +import numpy as np + +results = ds.results_df.copy() +results["-log10(padj)"] = -np.log10(results.padj) + +# Plot 
+plt.figure(figsize=(10, 6)) +plt.scatter( + results.log2FoldChange, + results["-log10(padj)"], + alpha=0.5, + s=10 +) +plt.axhline(-np.log10(0.05), color='red', linestyle='--', label='padj=0.05') +plt.axvline(1, color='gray', linestyle='--') +plt.axvline(-1, color='gray', linestyle='--') +plt.xlabel("Log2 Fold Change") +plt.ylabel("-Log10(Adjusted P-value)") +plt.title("Volcano Plot") +plt.legend() +plt.savefig("volcano_plot.png", dpi=300) +``` + +**MA plot:** +```python +plt.figure(figsize=(10, 6)) +plt.scatter( + np.log10(results.baseMean + 1), + results.log2FoldChange, + alpha=0.5, + s=10, + c=(results.padj < 0.05), + cmap='bwr' +) +plt.xlabel("Log10(Base Mean + 1)") +plt.ylabel("Log2 Fold Change") +plt.title("MA Plot") +plt.savefig("ma_plot.png", dpi=300) +``` + +--- + +## Common Patterns and Best Practices + +### 1. Data Preprocessing Checklist + +Before running PyDESeq2: +- ✓ Ensure counts are non-negative integers +- ✓ Verify samples × genes orientation +- ✓ Check that sample names match between counts and metadata +- ✓ Remove or handle missing metadata values +- ✓ Filter low-count genes (typically < 10 total reads) +- ✓ Verify experimental factors are properly encoded + +### 2. Design Formula Best Practices + +**Order matters:** Put adjustment variables before the variable of interest +```python +# Correct: control for batch, test condition +design = "~batch + condition" + +# Less ideal: condition listed first +design = "~condition + batch" +``` + +**Use categorical for discrete variables:** +```python +# Ensure proper data types +metadata["condition"] = metadata["condition"].astype("category") +metadata["batch"] = metadata["batch"].astype("category") +``` + +### 3. Statistical Testing Guidelines + +**Set appropriate alpha:** +```python +# Standard significance threshold +ds = DeseqStats(dds, alpha=0.05) + +# More stringent for exploratory analysis +ds = DeseqStats(dds, alpha=0.01) +``` + +**Use independent filtering:** +```python +# Recommended: filter low-power tests +ds = DeseqStats(dds, independent_filter=True) + +# Only disable if you have specific reasons +ds = DeseqStats(dds, independent_filter=False) +``` + +### 4. LFC Shrinkage + +**When to use:** +- For visualization (volcano plots, heatmaps) +- For ranking genes by effect size +- When prioritizing genes for follow-up + +**When NOT to use:** +- For reporting statistical significance (use unshrunken p-values) +- For gene set enrichment analysis (typically uses unshrunken values) + +```python +# Save both versions +ds.results_df.to_csv("results_unshrunken.csv") +ds.lfc_shrink() +ds.results_df.to_csv("results_shrunken.csv") +``` + +### 5. 
Memory Management + +For large datasets: +```python +# Use parallel processing +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design="~condition", + n_cpus=4 # Adjust based on available cores +) + +# Process in batches if needed +# (split genes into chunks, analyze separately, combine results) +``` + +--- + +## Troubleshooting + +### Error: Index mismatch between counts and metadata + +**Problem:** Sample names don't match +``` +KeyError: Sample names in counts and metadata don't match +``` + +**Solution:** +```python +# Check indices +print("Counts samples:", counts_df.index.tolist()) +print("Metadata samples:", metadata.index.tolist()) + +# Align if needed +common_samples = counts_df.index.intersection(metadata.index) +counts_df = counts_df.loc[common_samples] +metadata = metadata.loc[common_samples] +``` + +### Error: All genes have zero counts + +**Problem:** Data might need transposition +``` +ValueError: All genes have zero total counts +``` + +**Solution:** +```python +# Check data orientation +print(f"Counts shape: {counts_df.shape}") + +# If genes > samples, likely needs transpose +if counts_df.shape[1] < counts_df.shape[0]: + counts_df = counts_df.T +``` + +### Warning: Many genes filtered out + +**Problem:** Too many low-count genes removed + +**Check:** +```python +# See distribution of gene counts +print(counts_df.sum(axis=0).describe()) + +# Visualize +import matplotlib.pyplot as plt +plt.hist(counts_df.sum(axis=0), bins=50, log=True) +plt.xlabel("Total counts per gene") +plt.ylabel("Frequency") +plt.show() +``` + +**Adjust filtering if needed:** +```python +# Try lower threshold +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 5] +``` + +### Error: Design matrix is not full rank + +**Problem:** Confounded design (e.g., all treated samples in one batch) + +**Solution:** +```python +# Check design confounding +print(pd.crosstab(metadata.condition, metadata.batch)) + +# Either remove confounded variable or add interaction term +design = "~condition" # Drop batch +# OR +design = "~condition + batch + condition:batch" # Add interaction +``` + +### Issue: No significant genes found + +**Possible causes:** +1. Small effect sizes +2. High biological variability +3. Insufficient sample size +4. Technical issues (batch effects, outliers) + +**Diagnostics:** +```python +# Check dispersion estimates +import matplotlib.pyplot as plt +dispersions = dds.varm["dispersions"] +plt.hist(dispersions, bins=50) +plt.xlabel("Dispersion") +plt.ylabel("Frequency") +plt.show() + +# Check size factors (should be close to 1) +print("Size factors:", dds.obsm["size_factors"]) + +# Look at top genes even if not significant +top_genes = ds.results_df.nsmallest(20, "pvalue") +print(top_genes) +``` + +### Memory errors on large datasets + +**Solutions:** +```python +# 1. Use fewer CPUs (paradoxically can help) +dds = DeseqDataSet(..., n_cpus=1) + +# 2. Filter more aggressively +genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 20] + +# 3. Process in batches +# Split analysis by gene subsets and combine results +``` diff --git a/scientific-packages/pydeseq2/scripts/run_deseq2_analysis.py b/scientific-packages/pydeseq2/scripts/run_deseq2_analysis.py new file mode 100644 index 0000000..db9df25 --- /dev/null +++ b/scientific-packages/pydeseq2/scripts/run_deseq2_analysis.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +""" +PyDESeq2 Analysis Script + +This script performs a complete differential expression analysis using PyDESeq2. 
+It can be used as a template for standard RNA-seq DEA workflows. + +Usage: + python run_deseq2_analysis.py --counts counts.csv --metadata metadata.csv \ + --design "~condition" --contrast condition treated control \ + --output results/ + +Requirements: + - pydeseq2 + - pandas + - matplotlib (optional, for plots) +""" + +import argparse +import os +import pickle +import sys +from pathlib import Path + +import pandas as pd + +try: + from pydeseq2.dds import DeseqDataSet + from pydeseq2.ds import DeseqStats +except ImportError: + print("Error: pydeseq2 not installed. Install with: pip install pydeseq2") + sys.exit(1) + + +def load_and_validate_data(counts_path, metadata_path, transpose_counts=True): + """Load count matrix and metadata, perform basic validation.""" + print(f"Loading count data from {counts_path}...") + counts_df = pd.read_csv(counts_path, index_col=0) + + if transpose_counts: + print("Transposing count matrix to samples × genes format...") + counts_df = counts_df.T + + print(f"Loading metadata from {metadata_path}...") + metadata = pd.read_csv(metadata_path, index_col=0) + + print(f"\nData loaded:") + print(f" Counts shape: {counts_df.shape} (samples × genes)") + print(f" Metadata shape: {metadata.shape} (samples × variables)") + + # Validate + if not all(counts_df.index == metadata.index): + print("\nWarning: Sample indices don't match perfectly. Taking intersection...") + common_samples = counts_df.index.intersection(metadata.index) + counts_df = counts_df.loc[common_samples] + metadata = metadata.loc[common_samples] + print(f" Using {len(common_samples)} common samples") + + # Check for negative or non-integer values + if (counts_df < 0).any().any(): + raise ValueError("Count matrix contains negative values") + + return counts_df, metadata + + +def filter_data(counts_df, metadata, min_counts=10, condition_col=None): + """Filter low-count genes and samples with missing data.""" + print(f"\nFiltering data...") + + initial_genes = counts_df.shape[1] + initial_samples = counts_df.shape[0] + + # Filter genes + genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= min_counts] + counts_df = counts_df[genes_to_keep] + genes_removed = initial_genes - counts_df.shape[1] + print(f" Removed {genes_removed} genes with < {min_counts} total counts") + + # Filter samples with missing condition data + if condition_col and condition_col in metadata.columns: + samples_to_keep = ~metadata[condition_col].isna() + counts_df = counts_df.loc[samples_to_keep] + metadata = metadata.loc[samples_to_keep] + samples_removed = initial_samples - counts_df.shape[0] + if samples_removed > 0: + print(f" Removed {samples_removed} samples with missing '{condition_col}' data") + + print(f" Final data shape: {counts_df.shape[0]} samples × {counts_df.shape[1]} genes") + + return counts_df, metadata + + +def run_deseq2(counts_df, metadata, design, n_cpus=1): + """Run DESeq2 normalization and fitting.""" + print(f"\nInitializing DeseqDataSet with design: {design}") + + dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design=design, + refit_cooks=True, + n_cpus=n_cpus, + quiet=False + ) + + print("\nRunning DESeq2 pipeline...") + print(" Step 1/7: Computing size factors...") + print(" Step 2/7: Fitting genewise dispersions...") + print(" Step 3/7: Fitting dispersion trend curve...") + print(" Step 4/7: Computing dispersion priors...") + print(" Step 5/7: Fitting MAP dispersions...") + print(" Step 6/7: Fitting log fold changes...") + print(" Step 7/7: Calculating Cook's distances...") + + 
dds.deseq2() + + print("\n✓ DESeq2 fitting complete") + + return dds + + +def run_statistical_tests(dds, contrast, alpha=0.05, shrink_lfc=True): + """Perform Wald tests and compute p-values.""" + print(f"\nPerforming statistical tests...") + print(f" Contrast: {contrast}") + print(f" Significance threshold: {alpha}") + + ds = DeseqStats( + dds, + contrast=contrast, + alpha=alpha, + cooks_filter=True, + independent_filter=True, + quiet=False + ) + + print("\n Running Wald tests...") + print(" Filtering outliers based on Cook's distance...") + print(" Applying independent filtering...") + print(" Adjusting p-values (Benjamini-Hochberg)...") + + ds.summary() + + print("\n✓ Statistical testing complete") + + # Optional LFC shrinkage + if shrink_lfc: + print("\nApplying LFC shrinkage for visualization...") + ds.lfc_shrink() + print("✓ LFC shrinkage complete") + + return ds + + +def save_results(ds, dds, output_dir, shrink_lfc=True): + """Save results and intermediate objects.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\nSaving results to {output_dir}/") + + # Save statistical results + results_path = output_dir / "deseq2_results.csv" + ds.results_df.to_csv(results_path) + print(f" Saved: {results_path}") + + # Save significant genes + significant = ds.results_df[ds.results_df.padj < 0.05] + sig_path = output_dir / "significant_genes.csv" + significant.to_csv(sig_path) + print(f" Saved: {sig_path} ({len(significant)} significant genes)") + + # Save sorted results + sorted_results = ds.results_df.sort_values("padj") + sorted_path = output_dir / "results_sorted_by_padj.csv" + sorted_results.to_csv(sorted_path) + print(f" Saved: {sorted_path}") + + # Save DeseqDataSet as pickle + dds_path = output_dir / "deseq_dataset.pkl" + with open(dds_path, "wb") as f: + pickle.dump(dds.to_picklable_anndata(), f) + print(f" Saved: {dds_path}") + + # Print summary + print(f"\n{'='*60}") + print("ANALYSIS SUMMARY") + print(f"{'='*60}") + print(f"Total genes tested: {len(ds.results_df)}") + print(f"Significant genes (padj < 0.05): {len(significant)}") + print(f"Upregulated: {len(significant[significant.log2FoldChange > 0])}") + print(f"Downregulated: {len(significant[significant.log2FoldChange < 0])}") + print(f"{'='*60}") + + # Show top genes + print("\nTop 10 most significant genes:") + print(sorted_results.head(10)[["baseMean", "log2FoldChange", "pvalue", "padj"]]) + + return results_path + + +def create_plots(ds, output_dir): + """Create basic visualization plots.""" + try: + import matplotlib.pyplot as plt + import numpy as np + except ImportError: + print("\nNote: matplotlib not installed. 
Skipping plot generation.") + return + + output_dir = Path(output_dir) + results = ds.results_df.copy() + + print("\nGenerating plots...") + + # Volcano plot + results["-log10(padj)"] = -np.log10(results.padj.fillna(1)) + + plt.figure(figsize=(10, 6)) + significant = results.padj < 0.05 + plt.scatter( + results.loc[~significant, "log2FoldChange"], + results.loc[~significant, "-log10(padj)"], + alpha=0.3, s=10, c='gray', label='Not significant' + ) + plt.scatter( + results.loc[significant, "log2FoldChange"], + results.loc[significant, "-log10(padj)"], + alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)' + ) + plt.axhline(-np.log10(0.05), color='blue', linestyle='--', linewidth=1, alpha=0.5) + plt.axvline(1, color='gray', linestyle='--', linewidth=1, alpha=0.5) + plt.axvline(-1, color='gray', linestyle='--', linewidth=1, alpha=0.5) + plt.xlabel("Log2 Fold Change", fontsize=12) + plt.ylabel("-Log10(Adjusted P-value)", fontsize=12) + plt.title("Volcano Plot", fontsize=14, fontweight='bold') + plt.legend() + plt.tight_layout() + volcano_path = output_dir / "volcano_plot.png" + plt.savefig(volcano_path, dpi=300) + plt.close() + print(f" Saved: {volcano_path}") + + # MA plot + plt.figure(figsize=(10, 6)) + plt.scatter( + np.log10(results.loc[~significant, "baseMean"] + 1), + results.loc[~significant, "log2FoldChange"], + alpha=0.3, s=10, c='gray', label='Not significant' + ) + plt.scatter( + np.log10(results.loc[significant, "baseMean"] + 1), + results.loc[significant, "log2FoldChange"], + alpha=0.6, s=10, c='red', label='Significant (padj < 0.05)' + ) + plt.axhline(0, color='blue', linestyle='--', linewidth=1, alpha=0.5) + plt.xlabel("Log10(Base Mean + 1)", fontsize=12) + plt.ylabel("Log2 Fold Change", fontsize=12) + plt.title("MA Plot", fontsize=14, fontweight='bold') + plt.legend() + plt.tight_layout() + ma_path = output_dir / "ma_plot.png" + plt.savefig(ma_path, dpi=300) + plt.close() + print(f" Saved: {ma_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Run PyDESeq2 differential expression analysis", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic analysis + python run_deseq2_analysis.py \\ + --counts counts.csv \\ + --metadata metadata.csv \\ + --design "~condition" \\ + --contrast condition treated control \\ + --output results/ + + # Multi-factor analysis + python run_deseq2_analysis.py \\ + --counts counts.csv \\ + --metadata metadata.csv \\ + --design "~batch + condition" \\ + --contrast condition treated control \\ + --output results/ \\ + --n-cpus 4 + """ + ) + + parser.add_argument("--counts", required=True, help="Path to count matrix CSV file") + parser.add_argument("--metadata", required=True, help="Path to metadata CSV file") + parser.add_argument("--design", required=True, help="Design formula (e.g., '~condition')") + parser.add_argument("--contrast", nargs=3, required=True, + metavar=("VARIABLE", "TEST", "REFERENCE"), + help="Contrast specification: variable test_level reference_level") + parser.add_argument("--output", default="results", help="Output directory (default: results)") + parser.add_argument("--min-counts", type=int, default=10, + help="Minimum total counts for gene filtering (default: 10)") + parser.add_argument("--alpha", type=float, default=0.05, + help="Significance threshold (default: 0.05)") + parser.add_argument("--no-transpose", action="store_true", + help="Don't transpose count matrix (use if already samples × genes)") + parser.add_argument("--no-shrink", action="store_true", + 
help="Skip LFC shrinkage") + parser.add_argument("--n-cpus", type=int, default=1, + help="Number of CPUs for parallel processing (default: 1)") + parser.add_argument("--plots", action="store_true", + help="Generate volcano and MA plots") + + args = parser.parse_args() + + # Load data + counts_df, metadata = load_and_validate_data( + args.counts, + args.metadata, + transpose_counts=not args.no_transpose + ) + + # Filter data + condition_col = args.contrast[0] + counts_df, metadata = filter_data( + counts_df, + metadata, + min_counts=args.min_counts, + condition_col=condition_col + ) + + # Run DESeq2 + dds = run_deseq2(counts_df, metadata, args.design, n_cpus=args.n_cpus) + + # Statistical testing + ds = run_statistical_tests( + dds, + contrast=args.contrast, + alpha=args.alpha, + shrink_lfc=not args.no_shrink + ) + + # Save results + save_results(ds, dds, args.output, shrink_lfc=not args.no_shrink) + + # Create plots if requested + if args.plots: + create_plots(ds, args.output) + + print(f"\n✓ Analysis complete! Results saved to {args.output}/") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pymatgen/SKILL.md b/scientific-packages/pymatgen/SKILL.md new file mode 100644 index 0000000..aaac90a --- /dev/null +++ b/scientific-packages/pymatgen/SKILL.md @@ -0,0 +1,693 @@ +--- +name: pymatgen +description: Comprehensive toolkit for materials science analysis using pymatgen (Python Materials Genomics). Use when working with crystal structures, materials properties, computational materials science, electronic structure analysis, phase diagrams, surface chemistry, or when integrating with Materials Project database. Appropriate for structure file conversion, symmetry analysis, thermodynamic calculations, band structure visualization, surface generation, diffusion analysis, and high-throughput materials screening. +--- + +# Pymatgen - Python Materials Genomics + +## Overview + +Pymatgen is a comprehensive Python library for materials analysis that powers the Materials Project. This skill provides guidance for using pymatgen's extensive capabilities in computational materials science, including: + +- **Structure manipulation**: Creating, reading, writing, and transforming crystal structures and molecules +- **Materials analysis**: Symmetry, coordination environments, bonding, and structure comparison +- **Thermodynamics**: Phase diagrams, Pourbaix diagrams, reaction energies, and stability analysis +- **Electronic structure**: Band structures, density of states, and Fermi surfaces +- **Surfaces and interfaces**: Slab generation, Wulff shapes, adsorption sites, and interface construction +- **Materials Project integration**: Programmatic access to hundreds of thousands of computed materials +- **File I/O**: Support for 100+ file formats from various computational codes + +## When to Use This Skill + +Use this skill when: +- Working with crystal structures or molecular systems in materials science +- Converting between structure file formats (CIF, POSCAR, XYZ, etc.) 
+- Analyzing symmetry, space groups, or coordination environments +- Computing phase diagrams or assessing thermodynamic stability +- Analyzing electronic structure data (band gaps, DOS, band structures) +- Generating surfaces, slabs, or studying interfaces +- Accessing the Materials Project database programmatically +- Setting up high-throughput computational workflows +- Analyzing diffusion, magnetism, or mechanical properties +- Working with VASP, Gaussian, Quantum ESPRESSO, or other computational codes + +## Quick Start Guide + +### Installation + +```bash +# Core pymatgen +pip install pymatgen + +# With Materials Project API access +pip install pymatgen mp-api + +# Optional dependencies for extended functionality +pip install pymatgen[analysis] # Additional analysis tools +pip install pymatgen[vis] # Visualization tools +``` + +### Basic Structure Operations + +```python +from pymatgen.core import Structure, Lattice + +# Read structure from file (automatic format detection) +struct = Structure.from_file("POSCAR") + +# Create structure from scratch +lattice = Lattice.cubic(3.84) +struct = Structure(lattice, ["Si", "Si"], [[0,0,0], [0.25,0.25,0.25]]) + +# Write to different format +struct.to(filename="structure.cif") + +# Basic properties +print(f"Formula: {struct.composition.reduced_formula}") +print(f"Space group: {struct.get_space_group_info()}") +print(f"Density: {struct.density:.2f} g/cm³") +``` + +### Materials Project Integration + +```bash +# Set up API key +export MP_API_KEY="your_api_key_here" +``` + +```python +from mp_api.client import MPRester + +with MPRester() as mpr: + # Get structure by material ID + struct = mpr.get_structure_by_material_id("mp-149") + + # Search for materials + materials = mpr.materials.summary.search( + formula="Fe2O3", + energy_above_hull=(0, 0.05) + ) +``` + +## Core Capabilities + +### 1. Structure Creation and Manipulation + +Create structures using various methods and perform transformations. + +**From files:** +```python +# Automatic format detection +struct = Structure.from_file("structure.cif") +struct = Structure.from_file("POSCAR") +mol = Molecule.from_file("molecule.xyz") +``` + +**From scratch:** +```python +from pymatgen.core import Structure, Lattice + +# Using lattice parameters +lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, + alpha=120, beta=90, gamma=60) +coords = [[0, 0, 0], [0.75, 0.5, 0.75]] +struct = Structure(lattice, ["Si", "Si"], coords) + +# From space group +struct = Structure.from_spacegroup( + "Fm-3m", + Lattice.cubic(3.5), + ["Si"], + [[0, 0, 0]] +) +``` + +**Transformations:** +```python +from pymatgen.transformations.standard_transformations import ( + SupercellTransformation, + SubstitutionTransformation, + PrimitiveCellTransformation +) + +# Create supercell +trans = SupercellTransformation([[2,0,0],[0,2,0],[0,0,2]]) +supercell = trans.apply_transformation(struct) + +# Substitute elements +trans = SubstitutionTransformation({"Fe": "Mn"}) +new_struct = trans.apply_transformation(struct) + +# Get primitive cell +trans = PrimitiveCellTransformation() +primitive = trans.apply_transformation(struct) +``` + +**Reference:** See `references/core_classes.md` for comprehensive documentation of Structure, Lattice, Molecule, and related classes. + +### 2. File Format Conversion + +Convert between 100+ file formats with automatic format detection. 
+ +**Using convenience methods:** +```python +# Read any format +struct = Structure.from_file("input_file") + +# Write to any format +struct.to(filename="output.cif") +struct.to(filename="POSCAR") +struct.to(filename="output.xyz") +``` + +**Using the conversion script:** +```bash +# Single file conversion +python scripts/structure_converter.py POSCAR structure.cif + +# Batch conversion +python scripts/structure_converter.py *.cif --output-dir ./poscar_files --format poscar +``` + +**Reference:** See `references/io_formats.md` for detailed documentation of all supported formats and code integrations. + +### 3. Structure Analysis and Symmetry + +Analyze structures for symmetry, coordination, and other properties. + +**Symmetry analysis:** +```python +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +sga = SpacegroupAnalyzer(struct) + +# Get space group information +print(f"Space group: {sga.get_space_group_symbol()}") +print(f"Number: {sga.get_space_group_number()}") +print(f"Crystal system: {sga.get_crystal_system()}") + +# Get conventional/primitive cells +conventional = sga.get_conventional_standard_structure() +primitive = sga.get_primitive_standard_structure() +``` + +**Coordination environment:** +```python +from pymatgen.analysis.local_env import CrystalNN + +cnn = CrystalNN() +neighbors = cnn.get_nn_info(struct, n=0) # Neighbors of site 0 + +print(f"Coordination number: {len(neighbors)}") +for neighbor in neighbors: + site = struct[neighbor['site_index']] + print(f" {site.species_string} at {neighbor['weight']:.3f} Å") +``` + +**Using the analysis script:** +```bash +# Comprehensive analysis +python scripts/structure_analyzer.py POSCAR --symmetry --neighbors + +# Export results +python scripts/structure_analyzer.py structure.cif --symmetry --export json +``` + +**Reference:** See `references/analysis_modules.md` for detailed documentation of all analysis capabilities. + +### 4. Phase Diagrams and Thermodynamics + +Construct phase diagrams and analyze thermodynamic stability. + +**Phase diagram construction:** +```python +from mp_api.client import MPRester +from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter + +# Get entries from Materials Project +with MPRester() as mpr: + entries = mpr.get_entries_in_chemsys("Li-Fe-O") + +# Build phase diagram +pd = PhaseDiagram(entries) + +# Check stability +from pymatgen.core import Composition +comp = Composition("LiFeO2") + +# Find entry for composition +for entry in entries: + if entry.composition.reduced_formula == comp.reduced_formula: + e_above_hull = pd.get_e_above_hull(entry) + print(f"Energy above hull: {e_above_hull:.4f} eV/atom") + + if e_above_hull > 0.001: + # Get decomposition + decomp = pd.get_decomposition(comp) + print("Decomposes to:", decomp) + +# Plot +plotter = PDPlotter(pd) +plotter.show() +``` + +**Using the phase diagram script:** +```bash +# Generate phase diagram +python scripts/phase_diagram_generator.py Li-Fe-O --output li_fe_o.png + +# Analyze specific composition +python scripts/phase_diagram_generator.py Li-Fe-O --analyze "LiFeO2" --show +``` + +**Reference:** See `references/analysis_modules.md` (Phase Diagrams section) and `references/transformations_workflows.md` (Workflow 2) for detailed examples. + +### 5. Electronic Structure Analysis + +Analyze band structures, density of states, and electronic properties. 
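The examples below parse a finished VASP run with `Vasprun`. For a band structure computed along a high-symmetry k-path, it usually helps to parse in line mode so that a `BandStructureSymmLine` (the object `BSPlotter` expects) is returned. A minimal sketch, assuming `vasprun.xml` and the line-mode `KPOINTS` file sit in the working directory:

```python
from pymatgen.io.vasp import Vasprun

# Parse a line-mode (band structure) calculation; line_mode=True returns a
# BandStructureSymmLine with labeled high-symmetry points for plotting.
vasprun = Vasprun("vasprun.xml", parse_projected_eigen=False)
bs = vasprun.get_band_structure(kpoints_filename="KPOINTS", line_mode=True)
print(bs.get_band_gap())
```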
+ +**Band structure:** +```python +from pymatgen.io.vasp import Vasprun +from pymatgen.electronic_structure.plotter import BSPlotter + +# Read from VASP calculation +vasprun = Vasprun("vasprun.xml") +bs = vasprun.get_band_structure() + +# Analyze +band_gap = bs.get_band_gap() +print(f"Band gap: {band_gap['energy']:.3f} eV") +print(f"Direct: {band_gap['direct']}") +print(f"Is metal: {bs.is_metal()}") + +# Plot +plotter = BSPlotter(bs) +plotter.save_plot("band_structure.png") +``` + +**Density of states:** +```python +from pymatgen.electronic_structure.plotter import DosPlotter + +dos = vasprun.complete_dos + +# Get element-projected DOS +element_dos = dos.get_element_dos() +for element, element_dos_obj in element_dos.items(): + print(f"{element}: {element_dos_obj.get_gap():.3f} eV") + +# Plot +plotter = DosPlotter() +plotter.add_dos("Total DOS", dos) +plotter.show() +``` + +**Reference:** See `references/analysis_modules.md` (Electronic Structure section) and `references/io_formats.md` (VASP section). + +### 6. Surface and Interface Analysis + +Generate slabs, analyze surfaces, and study interfaces. + +**Slab generation:** +```python +from pymatgen.core.surface import SlabGenerator + +# Generate slabs for specific Miller index +slabgen = SlabGenerator( + struct, + miller_index=(1, 1, 1), + min_slab_size=10.0, # Å + min_vacuum_size=10.0, # Å + center_slab=True +) + +slabs = slabgen.get_slabs() + +# Write slabs +for i, slab in enumerate(slabs): + slab.to(filename=f"slab_{i}.cif") +``` + +**Wulff shape construction:** +```python +from pymatgen.analysis.wulff import WulffShape + +# Define surface energies +surface_energies = { + (1, 0, 0): 1.0, + (1, 1, 0): 1.1, + (1, 1, 1): 0.9, +} + +wulff = WulffShape(struct.lattice, surface_energies) +print(f"Surface area: {wulff.surface_area:.2f} Ų") +print(f"Volume: {wulff.volume:.2f} ų") + +wulff.show() +``` + +**Adsorption site finding:** +```python +from pymatgen.analysis.adsorption import AdsorbateSiteFinder +from pymatgen.core import Molecule + +asf = AdsorbateSiteFinder(slab) + +# Find sites +ads_sites = asf.find_adsorption_sites() +print(f"On-top sites: {len(ads_sites['ontop'])}") +print(f"Bridge sites: {len(ads_sites['bridge'])}") +print(f"Hollow sites: {len(ads_sites['hollow'])}") + +# Add adsorbate +adsorbate = Molecule("O", [[0, 0, 0]]) +ads_struct = asf.add_adsorbate(adsorbate, ads_sites["ontop"][0]) +``` + +**Reference:** See `references/analysis_modules.md` (Surface and Interface section) and `references/transformations_workflows.md` (Workflows 3 and 9). + +### 7. Materials Project Database Access + +Programmatically access the Materials Project database. + +**Setup:** +1. Get API key from https://next-gen.materialsproject.org/ +2. 
Set environment variable: `export MP_API_KEY="your_key_here"` + +**Search and retrieve:** +```python +from mp_api.client import MPRester + +with MPRester() as mpr: + # Search by formula + materials = mpr.materials.summary.search(formula="Fe2O3") + + # Search by chemical system + materials = mpr.materials.summary.search(chemsys="Li-Fe-O") + + # Filter by properties + materials = mpr.materials.summary.search( + chemsys="Li-Fe-O", + energy_above_hull=(0, 0.05), # Stable/metastable + band_gap=(1.0, 3.0) # Semiconducting + ) + + # Get structure + struct = mpr.get_structure_by_material_id("mp-149") + + # Get band structure + bs = mpr.get_bandstructure_by_material_id("mp-149") + + # Get entries for phase diagram + entries = mpr.get_entries_in_chemsys("Li-Fe-O") +``` + +**Reference:** See `references/materials_project_api.md` for comprehensive API documentation and examples. + +### 8. Computational Workflow Setup + +Set up calculations for various electronic structure codes. + +**VASP input generation:** +```python +from pymatgen.io.vasp.sets import MPRelaxSet, MPStaticSet, MPNonSCFSet + +# Relaxation +relax = MPRelaxSet(struct) +relax.write_input("./relax_calc") + +# Static calculation +static = MPStaticSet(struct) +static.write_input("./static_calc") + +# Band structure (non-self-consistent) +nscf = MPNonSCFSet(struct, mode="line") +nscf.write_input("./bandstructure_calc") + +# Custom parameters +custom = MPRelaxSet(struct, user_incar_settings={"ENCUT": 600}) +custom.write_input("./custom_calc") +``` + +**Other codes:** +```python +# Gaussian +from pymatgen.io.gaussian import GaussianInput + +gin = GaussianInput( + mol, + functional="B3LYP", + basis_set="6-31G(d)", + route_parameters={"Opt": None} +) +gin.write_file("input.gjf") + +# Quantum ESPRESSO +from pymatgen.io.pwscf import PWInput + +pwin = PWInput(struct, control={"calculation": "scf"}) +pwin.write_file("pw.in") +``` + +**Reference:** See `references/io_formats.md` (Electronic Structure Code I/O section) and `references/transformations_workflows.md` for workflow examples. + +### 9. Advanced Analysis + +**Diffraction patterns:** +```python +from pymatgen.analysis.diffraction.xrd import XRDCalculator + +xrd = XRDCalculator() +pattern = xrd.get_pattern(struct) + +# Get peaks +for peak in pattern.hkls: + print(f"2θ = {peak['2theta']:.2f}°, hkl = {peak['hkl']}") + +pattern.plot() +``` + +**Elastic properties:** +```python +from pymatgen.analysis.elasticity import ElasticTensor + +# From elastic tensor matrix +elastic_tensor = ElasticTensor.from_voigt(matrix) + +print(f"Bulk modulus: {elastic_tensor.k_voigt:.1f} GPa") +print(f"Shear modulus: {elastic_tensor.g_voigt:.1f} GPa") +print(f"Young's modulus: {elastic_tensor.y_mod:.1f} GPa") +``` + +**Magnetic ordering:** +```python +from pymatgen.transformations.advanced_transformations import MagOrderingTransformation + +# Enumerate magnetic orderings +trans = MagOrderingTransformation({"Fe": 5.0}) +mag_structs = trans.apply_transformation(struct, return_ranked_list=True) + +# Get lowest energy magnetic structure +lowest_energy_struct = mag_structs[0]['structure'] +``` + +**Reference:** See `references/analysis_modules.md` for comprehensive analysis module documentation. 
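As a small end-to-end illustration of the analysis tools above, the following sketch computes a powder XRD pattern for a structure file and writes the peak list to CSV. The filenames are placeholders, and Cu Kα is also the calculator's default wavelength:

```python
import csv
from pymatgen.core import Structure
from pymatgen.analysis.diffraction.xrd import XRDCalculator

struct = Structure.from_file("my_structure.cif")  # placeholder filename

xrd = XRDCalculator(wavelength="CuKa")
pattern = xrd.get_pattern(struct, two_theta_range=(10, 80))

# pattern.x holds the 2-theta values, pattern.y the relative intensities
with open("xrd_peaks.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["two_theta_deg", "relative_intensity"])
    writer.writerows(zip(pattern.x, pattern.y))
```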
+ +## Bundled Resources + +### Scripts (`scripts/`) + +Executable Python scripts for common tasks: + +- **`structure_converter.py`**: Convert between structure file formats + - Supports batch conversion and automatic format detection + - Usage: `python scripts/structure_converter.py POSCAR structure.cif` + +- **`structure_analyzer.py`**: Comprehensive structure analysis + - Symmetry, coordination, lattice parameters, distance matrix + - Usage: `python scripts/structure_analyzer.py structure.cif --symmetry --neighbors` + +- **`phase_diagram_generator.py`**: Generate phase diagrams from Materials Project + - Stability analysis and thermodynamic properties + - Usage: `python scripts/phase_diagram_generator.py Li-Fe-O --analyze "LiFeO2"` + +All scripts include detailed help: `python scripts/script_name.py --help` + +### References (`references/`) + +Comprehensive documentation loaded into context as needed: + +- **`core_classes.md`**: Element, Structure, Lattice, Molecule, Composition classes +- **`io_formats.md`**: File format support and code integration (VASP, Gaussian, etc.) +- **`analysis_modules.md`**: Phase diagrams, surfaces, electronic structure, symmetry +- **`materials_project_api.md`**: Complete Materials Project API guide +- **`transformations_workflows.md`**: Transformations framework and common workflows + +Load references when detailed information is needed about specific modules or workflows. + +## Common Workflows + +### High-Throughput Structure Generation + +```python +from pymatgen.transformations.standard_transformations import SubstitutionTransformation +from pymatgen.io.vasp.sets import MPRelaxSet + +# Generate doped structures +base_struct = Structure.from_file("POSCAR") +dopants = ["Mn", "Co", "Ni", "Cu"] + +for dopant in dopants: + trans = SubstitutionTransformation({"Fe": dopant}) + doped_struct = trans.apply_transformation(base_struct) + + # Generate VASP inputs + vasp_input = MPRelaxSet(doped_struct) + vasp_input.write_input(f"./calcs/Fe_{dopant}") +``` + +### Band Structure Calculation Workflow + +```python +# 1. Relaxation +relax = MPRelaxSet(struct) +relax.write_input("./1_relax") + +# 2. Static (after relaxation) +relaxed = Structure.from_file("1_relax/CONTCAR") +static = MPStaticSet(relaxed) +static.write_input("./2_static") + +# 3. Band structure (non-self-consistent) +nscf = MPNonSCFSet(relaxed, mode="line") +nscf.write_input("./3_bandstructure") + +# 4. Analysis +from pymatgen.io.vasp import Vasprun +vasprun = Vasprun("3_bandstructure/vasprun.xml") +bs = vasprun.get_band_structure() +bs.get_band_gap() +``` + +### Surface Energy Calculation + +```python +# 1. Get bulk energy +bulk_vasprun = Vasprun("bulk/vasprun.xml") +bulk_E_per_atom = bulk_vasprun.final_energy / len(bulk) + +# 2. Generate and calculate slabs +slabgen = SlabGenerator(bulk, (1,1,1), 10, 15) +slab = slabgen.get_slabs()[0] + +MPRelaxSet(slab).write_input("./slab_calc") + +# 3. Calculate surface energy (after calculation) +slab_vasprun = Vasprun("slab_calc/vasprun.xml") +E_surf = (slab_vasprun.final_energy - len(slab) * bulk_E_per_atom) / (2 * slab.surface_area) +E_surf *= 16.021766 # Convert eV/Ų to J/m² +``` + +**More workflows:** See `references/transformations_workflows.md` for 10 detailed workflow examples. + +## Best Practices + +### Structure Handling + +1. **Use automatic format detection**: `Structure.from_file()` handles most formats +2. **Prefer immutable structures**: Use `IStructure` when structure shouldn't change +3. 
**Check symmetry**: Use `SpacegroupAnalyzer` to reduce to primitive cell +4. **Validate structures**: Check for overlapping atoms or unreasonable bond lengths + +### File I/O + +1. **Use convenience methods**: `from_file()` and `to()` are preferred +2. **Specify formats explicitly**: When automatic detection fails +3. **Handle exceptions**: Wrap file I/O in try-except blocks +4. **Use serialization**: `as_dict()`/`from_dict()` for version-safe storage + +### Materials Project API + +1. **Use context manager**: Always use `with MPRester() as mpr:` +2. **Batch queries**: Request multiple items at once +3. **Cache results**: Save frequently used data locally +4. **Filter effectively**: Use property filters to reduce data transfer + +### Computational Workflows + +1. **Use input sets**: Prefer `MPRelaxSet`, `MPStaticSet` over manual INCAR +2. **Check convergence**: Always verify calculations converged +3. **Track transformations**: Use `TransformedStructure` for provenance +4. **Organize calculations**: Use clear directory structures + +### Performance + +1. **Reduce symmetry**: Use primitive cells when possible +2. **Limit neighbor searches**: Specify reasonable cutoff radii +3. **Use appropriate methods**: Different analysis tools have different speed/accuracy tradeoffs +4. **Parallelize when possible**: Many operations can be parallelized + +## Units and Conventions + +Pymatgen uses atomic units throughout: +- **Lengths**: Angstroms (Å) +- **Energies**: Electronvolts (eV) +- **Angles**: Degrees (°) +- **Magnetic moments**: Bohr magnetons (μB) +- **Time**: Femtoseconds (fs) + +Convert units using `pymatgen.core.units` when needed. + +## Integration with Other Tools + +Pymatgen integrates seamlessly with: +- **ASE** (Atomic Simulation Environment) +- **Phonopy** (phonon calculations) +- **BoltzTraP** (transport properties) +- **Atomate/Fireworks** (workflow management) +- **AiiDA** (provenance tracking) +- **Zeo++** (pore analysis) +- **OpenBabel** (molecule conversion) + +## Troubleshooting + +**Import errors**: Install missing dependencies +```bash +pip install pymatgen[analysis,vis] +``` + +**API key not found**: Set MP_API_KEY environment variable +```bash +export MP_API_KEY="your_key_here" +``` + +**Structure read failures**: Check file format and syntax +```python +# Try explicit format specification +struct = Structure.from_file("file.txt", fmt="cif") +``` + +**Symmetry analysis fails**: Structure may have numerical precision issues +```python +# Increase tolerance +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +sga = SpacegroupAnalyzer(struct, symprec=0.1) +``` + +## Additional Resources + +- **Documentation**: https://pymatgen.org/ +- **Materials Project**: https://materialsproject.org/ +- **GitHub**: https://github.com/materialsproject/pymatgen +- **Forum**: https://matsci.org/ +- **Example notebooks**: https://matgenb.materialsvirtuallab.org/ + +## Version Notes + +This skill is designed for pymatgen 2024.x and later. For the Materials Project API, use the `mp-api` package (separate from legacy `pymatgen.ext.matproj`). 
+ +Requirements: +- Python 3.10 or higher +- pymatgen >= 2023.x +- mp-api (for Materials Project access) diff --git a/scientific-packages/pymatgen/references/analysis_modules.md b/scientific-packages/pymatgen/references/analysis_modules.md new file mode 100644 index 0000000..3fb980d --- /dev/null +++ b/scientific-packages/pymatgen/references/analysis_modules.md @@ -0,0 +1,530 @@ +# Pymatgen Analysis Modules Reference + +This reference documents pymatgen's extensive analysis capabilities for materials characterization, property prediction, and computational analysis. + +## Phase Diagrams and Thermodynamics + +### Phase Diagram Construction + +```python +from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter +from pymatgen.entries.computed_entries import ComputedEntry + +# Create entries (composition and energy per atom) +entries = [ + ComputedEntry("Fe", -8.4), + ComputedEntry("O2", -4.9), + ComputedEntry("FeO", -6.7), + ComputedEntry("Fe2O3", -8.3), + ComputedEntry("Fe3O4", -9.1), +] + +# Build phase diagram +pd = PhaseDiagram(entries) + +# Get stable entries +stable_entries = pd.stable_entries + +# Get energy above hull (stability) +entry_to_test = ComputedEntry("Fe2O3", -8.0) +energy_above_hull = pd.get_e_above_hull(entry_to_test) + +# Get decomposition products +decomp = pd.get_decomposition(entry_to_test.composition) +# Returns: {entry1: fraction1, entry2: fraction2, ...} + +# Get equilibrium reaction energy +rxn_energy = pd.get_equilibrium_reaction_energy(entry_to_test) + +# Plot phase diagram +plotter = PDPlotter(pd) +plotter.show() +plotter.write_image("phase_diagram.png") +``` + +### Chemical Potential Diagrams + +```python +from pymatgen.analysis.phase_diagram import ChemicalPotentialDiagram + +# Create chemical potential diagram +cpd = ChemicalPotentialDiagram(entries, limits={"O": (-10, 0)}) + +# Get domains (stability regions) +domains = cpd.domains +``` + +### Pourbaix Diagrams + +Electrochemical phase diagrams with pH and potential axes. + +```python +from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter +from pymatgen.entries.computed_entries import ComputedEntry + +# Create entries with corrections for aqueous species +entries = [...] # Include solids and ions + +# Build Pourbaix diagram +pb = PourbaixDiagram(entries) + +# Get stable entry at specific pH and potential +stable_entry = pb.get_stable_entry(pH=7, V=0) + +# Plot +plotter = PourbaixPlotter(pb) +plotter.show() +``` + +## Structure Analysis + +### Structure Matching and Comparison + +```python +from pymatgen.analysis.structure_matcher import StructureMatcher + +matcher = StructureMatcher() + +# Check if structures match +is_match = matcher.fit(struct1, struct2) + +# Get mapping between structures +mapping = matcher.get_mapping(struct1, struct2) + +# Group similar structures +grouped = matcher.group_structures([struct1, struct2, struct3, ...]) +``` + +### Ewald Summation + +Calculate electrostatic energy of ionic structures. 
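Note that the Ewald sum needs charged sites: the structure must carry oxidation states (or explicit site charges) before it is passed in. A minimal sketch of decorating a structure by hand, using a placeholder file; `BVAnalyzer` (below) can guess the states instead:

```python
from pymatgen.core import Structure

struct = Structure.from_file("NaCl.cif")  # placeholder filename

# EwaldSummation needs charges; assign formal oxidation states per element.
struct.add_oxidation_state_by_element({"Na": 1, "Cl": -1})
```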
+ +```python +from pymatgen.analysis.ewald import EwaldSummation + +ewald = EwaldSummation(struct) +total_energy = ewald.total_energy # In eV +forces = ewald.forces # Forces on each site +``` + +### Symmetry Analysis + +```python +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +sga = SpacegroupAnalyzer(struct) + +# Get space group information +spacegroup_symbol = sga.get_space_group_symbol() # e.g., "Fm-3m" +spacegroup_number = sga.get_space_group_number() # e.g., 225 +crystal_system = sga.get_crystal_system() # e.g., "cubic" + +# Get symmetrized structure +sym_struct = sga.get_symmetrized_structure() +equivalent_sites = sym_struct.equivalent_sites + +# Get conventional/primitive cells +conventional = sga.get_conventional_standard_structure() +primitive = sga.get_primitive_standard_structure() + +# Get symmetry operations +symmetry_ops = sga.get_symmetry_operations() +``` + +## Local Environment Analysis + +### Coordination Environment + +```python +from pymatgen.analysis.local_env import ( + VoronoiNN, # Voronoi tessellation + CrystalNN, # Crystal-based + MinimumDistanceNN, # Distance cutoff + BrunnerNN_real, # Brunner method +) + +# Voronoi nearest neighbors +voronoi = VoronoiNN() +neighbors = voronoi.get_nn_info(struct, n=0) # Neighbors of site 0 + +# CrystalNN (recommended for most cases) +crystalnn = CrystalNN() +neighbors = crystalnn.get_nn_info(struct, n=0) + +# Analyze all sites +for i, site in enumerate(struct): + neighbors = voronoi.get_nn_info(struct, i) + coordination_number = len(neighbors) + print(f"Site {i} ({site.species_string}): CN = {coordination_number}") +``` + +### Coordination Geometry (ChemEnv) + +Detailed coordination environment identification. + +```python +from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder +from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import SimplestChemenvStrategy + +lgf = LocalGeometryFinder() +lgf.setup_structure(struct) + +# Get coordination environment for site +se = lgf.compute_structure_environments(only_indices=[0]) +strategy = SimplestChemenvStrategy() +lse = strategy.get_site_coordination_environment(se[0]) + +print(f"Coordination: {lse}") +``` + +### Bond Valence Sum + +```python +from pymatgen.analysis.bond_valence import BVAnalyzer + +bva = BVAnalyzer() + +# Calculate oxidation states +valences = bva.get_valences(struct) + +# Get structure with oxidation states +struct_with_oxi = bva.get_oxi_state_decorated_structure(struct) +``` + +## Surface and Interface Analysis + +### Surface (Slab) Generation + +```python +from pymatgen.core.surface import SlabGenerator, generate_all_slabs + +# Generate slabs for a specific Miller index +slabgen = SlabGenerator( + struct, + miller_index=(1, 1, 1), + min_slab_size=10.0, # Minimum slab thickness (Å) + min_vacuum_size=10.0, # Minimum vacuum thickness (Å) + center_slab=True +) + +slabs = slabgen.get_slabs() + +# Generate all slabs up to a Miller index +all_slabs = generate_all_slabs( + struct, + max_index=2, + min_slab_size=10.0, + min_vacuum_size=10.0 +) +``` + +### Wulff Shape Construction + +```python +from pymatgen.analysis.wulff import WulffShape + +# Define surface energies (J/m²) +surface_energies = { + (1, 0, 0): 1.0, + (1, 1, 0): 1.1, + (1, 1, 1): 0.9, +} + +wulff = WulffShape(struct.lattice, surface_energies, symm_reduce=True) + +# Get effective radius and surface area +effective_radius = wulff.effective_radius +surface_area = wulff.surface_area +volume = wulff.volume + +# Visualize 
+wulff.show() +``` + +### Adsorption Site Finding + +```python +from pymatgen.analysis.adsorption import AdsorbateSiteFinder + +asf = AdsorbateSiteFinder(slab) + +# Find adsorption sites +ads_sites = asf.find_adsorption_sites() +# Returns dictionary: {"ontop": [...], "bridge": [...], "hollow": [...]} + +# Generate structures with adsorbates +from pymatgen.core import Molecule +adsorbate = Molecule("O", [[0, 0, 0]]) + +ads_structs = asf.generate_adsorption_structures( + adsorbate, + repeat=[2, 2, 1], # Supercell to reduce adsorbate coverage +) +``` + +### Interface Construction + +```python +from pymatgen.analysis.interfaces.coherent_interfaces import CoherentInterfaceBuilder + +# Build interface between two materials +builder = CoherentInterfaceBuilder( + substrate_structure=substrate, + film_structure=film, + substrate_miller=(0, 0, 1), + film_miller=(1, 1, 1), +) + +interfaces = builder.get_interfaces() +``` + +## Magnetism + +### Magnetic Structure Analysis + +```python +from pymatgen.analysis.magnetism import CollinearMagneticStructureAnalyzer + +analyzer = CollinearMagneticStructureAnalyzer(struct) + +# Get magnetic ordering +ordering = analyzer.ordering # e.g., "FM" (ferromagnetic), "AFM", "FiM" + +# Get magnetic space group +mag_space_group = analyzer.get_structure_with_spin().get_space_group_info() +``` + +### Magnetic Ordering Enumeration + +```python +from pymatgen.transformations.advanced_transformations import MagOrderingTransformation + +# Enumerate possible magnetic orderings +mag_trans = MagOrderingTransformation({"Fe": 5.0}) # Magnetic moment in μB +transformed_structures = mag_trans.apply_transformation(struct, return_ranked_list=True) +``` + +## Electronic Structure Analysis + +### Band Structure Analysis + +```python +from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine +from pymatgen.electronic_structure.plotter import BSPlotter + +# Read band structure from VASP calculation +from pymatgen.io.vasp import Vasprun +vasprun = Vasprun("vasprun.xml") +bs = vasprun.get_band_structure() + +# Get band gap +band_gap = bs.get_band_gap() +# Returns: {'energy': gap_value, 'direct': True/False, 'transition': '...'} + +# Check if metal +is_metal = bs.is_metal() + +# Get VBM and CBM +vbm = bs.get_vbm() +cbm = bs.get_cbm() + +# Plot band structure +plotter = BSPlotter(bs) +plotter.show() +plotter.save_plot("band_structure.png") +``` + +### Density of States (DOS) + +```python +from pymatgen.electronic_structure.dos import CompleteDos +from pymatgen.electronic_structure.plotter import DosPlotter + +# Read DOS from VASP calculation +vasprun = Vasprun("vasprun.xml") +dos = vasprun.complete_dos + +# Get total DOS +total_dos = dos.densities + +# Get projected DOS +pdos = dos.get_element_dos() # By element +site_dos = dos.get_site_dos(struct[0]) # For specific site +spd_dos = dos.get_spd_dos() # By orbital (s, p, d) + +# Plot DOS +plotter = DosPlotter() +plotter.add_dos("Total", dos) +plotter.show() +``` + +### Fermi Surface + +```python +from pymatgen.electronic_structure.boltztrap2 import BoltztrapRunner + +runner = BoltztrapRunner(struct, nelec=n_electrons) +runner.run() + +# Get transport properties at different temperatures +results = runner.get_results() +``` + +## Diffraction + +### X-ray Diffraction (XRD) + +```python +from pymatgen.analysis.diffraction.xrd import XRDCalculator + +xrd = XRDCalculator() + +pattern = xrd.get_pattern(struct, two_theta_range=(0, 90)) + +# Get peak data +for peak in pattern.hkls: + print(f"2θ = {peak['2theta']:.2f}°, hkl = 
{peak['hkl']}, I = {peak['intensity']:.1f}") + +# Plot pattern +pattern.plot() +``` + +### Neutron Diffraction + +```python +from pymatgen.analysis.diffraction.neutron import NDCalculator + +nd = NDCalculator() +pattern = nd.get_pattern(struct) +``` + +## Elasticity and Mechanical Properties + +```python +from pymatgen.analysis.elasticity import ElasticTensor, Stress, Strain + +# Create elastic tensor from matrix +elastic_tensor = ElasticTensor([[...]]) # 6x6 or 3x3x3x3 matrix + +# Get mechanical properties +bulk_modulus = elastic_tensor.k_voigt # Voigt bulk modulus (GPa) +shear_modulus = elastic_tensor.g_voigt # Shear modulus (GPa) +youngs_modulus = elastic_tensor.y_mod # Young's modulus (GPa) + +# Apply strain +strain = Strain([[0.01, 0, 0], [0, 0, 0], [0, 0, 0]]) +stress = elastic_tensor.calculate_stress(strain) +``` + +## Reaction Analysis + +### Reaction Computation + +```python +from pymatgen.analysis.reaction_calculator import ComputedReaction + +reactants = [ComputedEntry("Fe", -8.4), ComputedEntry("O2", -4.9)] +products = [ComputedEntry("Fe2O3", -8.3)] + +rxn = ComputedReaction(reactants, products) + +# Get balanced equation +balanced_rxn = rxn.normalized_repr # e.g., "2 Fe + 1.5 O2 -> Fe2O3" + +# Get reaction energy +energy = rxn.calculated_reaction_energy # eV per formula unit +``` + +### Reaction Path Finding + +```python +from pymatgen.analysis.path_finder import ChgcarPotential, NEBPathfinder + +# Read charge density +chgcar_potential = ChgcarPotential.from_file("CHGCAR") + +# Find diffusion path +neb_path = NEBPathfinder( + start_struct, + end_struct, + relax_sites=[i for i in range(len(start_struct))], + v=chgcar_potential +) + +images = neb_path.images # Interpolated structures for NEB +``` + +## Molecular Analysis + +### Bond Analysis + +```python +# Get covalent bonds +bonds = mol.get_covalent_bonds() + +for bond in bonds: + print(f"{bond.site1.species_string} - {bond.site2.species_string}: {bond.length:.2f} Å") +``` + +### Molecule Graph + +```python +from pymatgen.analysis.graphs import MoleculeGraph +from pymatgen.analysis.local_env import OpenBabelNN + +# Build molecule graph +mg = MoleculeGraph.with_local_env_strategy(mol, OpenBabelNN()) + +# Get fragments +fragments = mg.get_disconnected_fragments() + +# Find rings +rings = mg.find_rings() +``` + +## Spectroscopy + +### X-ray Absorption Spectroscopy (XAS) + +```python +from pymatgen.analysis.xas.spectrum import XAS + +# Read XAS spectrum +xas = XAS.from_file("xas.dat") + +# Normalize and process +xas.normalize() +``` + +## Additional Analysis Tools + +### Grain Boundaries + +```python +from pymatgen.analysis.gb.grain import GrainBoundaryGenerator + +gb_gen = GrainBoundaryGenerator(struct) +gb_structures = gb_gen.generate_grain_boundaries( + rotation_axis=[0, 0, 1], + rotation_angle=36.87, # degrees +) +``` + +### Prototypes and Structure Matching + +```python +from pymatgen.analysis.prototypes import AflowPrototypeMatcher + +matcher = AflowPrototypeMatcher() +prototype = matcher.get_prototypes(struct) +``` + +## Best Practices + +1. **Start simple**: Use basic analysis before advanced methods +2. **Validate results**: Cross-check analysis with multiple methods +3. **Consider symmetry**: Use `SpacegroupAnalyzer` to reduce computational cost +4. **Check convergence**: Ensure input structures are well-converged +5. **Use appropriate methods**: Different analyses have different accuracy/speed tradeoffs +6. **Visualize results**: Use built-in plotters for quick validation +7. 
**Save intermediate results**: Complex analyses can be time-consuming diff --git a/scientific-packages/pymatgen/references/core_classes.md b/scientific-packages/pymatgen/references/core_classes.md new file mode 100644 index 0000000..8758c46 --- /dev/null +++ b/scientific-packages/pymatgen/references/core_classes.md @@ -0,0 +1,318 @@ +# Pymatgen Core Classes Reference + +This reference documents the fundamental classes in `pymatgen.core` that form the foundation for materials analysis. + +## Architecture Principles + +Pymatgen follows an object-oriented design where elements, sites, and structures are represented as objects. The framework emphasizes periodic boundary conditions for crystal representation while maintaining flexibility for molecular systems. + +**Unit Conventions**: All units in pymatgen are typically assumed to be in atomic units: +- Lengths: angstroms (Å) +- Energies: electronvolts (eV) +- Angles: degrees + +## Element and Periodic Table + +### Element +Represents periodic table elements with comprehensive properties. + +**Creation methods:** +```python +from pymatgen.core import Element + +# Create from symbol +si = Element("Si") +# Create from atomic number +si = Element.from_Z(14) +# Create from name +si = Element.from_name("silicon") +``` + +**Key properties:** +- `atomic_mass`: Atomic mass in amu +- `atomic_radius`: Atomic radius in angstroms +- `electronegativity`: Pauling electronegativity +- `ionization_energy`: First ionization energy in eV +- `common_oxidation_states`: List of common oxidation states +- `is_metal`, `is_halogen`, `is_noble_gas`, etc.: Boolean properties +- `X`: Element symbol as string + +### Species +Extends Element for charged ions and specific oxidation states. + +```python +from pymatgen.core import Species + +# Create an Fe2+ ion +fe2 = Species("Fe", 2) +# Or with explicit sign +fe2 = Species("Fe", +2) +``` + +### DummySpecies +Placeholder atoms for special structural representations (e.g., vacancies). + +```python +from pymatgen.core import DummySpecies + +vacancy = DummySpecies("X") +``` + +## Composition + +Represents chemical formulas and compositions, enabling chemical analysis and manipulation. + +### Creation +```python +from pymatgen.core import Composition + +# From string formula +comp = Composition("Fe2O3") +# From dictionary +comp = Composition({"Fe": 2, "O": 3}) +# From weight dictionary +comp = Composition.from_weight_dict({"Fe": 111.69, "O": 48.00}) +``` + +### Key methods +- `get_reduced_formula_and_factor()`: Returns reduced formula and multiplication factor +- `oxi_state_guesses()`: Attempts to determine oxidation states +- `replace(replacements_dict)`: Replace elements +- `add_charges_from_oxi_state_guesses()`: Infer and add oxidation states +- `is_element`: Check if composition is a single element + +### Key properties +- `weight`: Molecular weight +- `reduced_formula`: Reduced chemical formula +- `hill_formula`: Formula in Hill notation (C, H, then alphabetical) +- `num_atoms`: Total number of atoms +- `chemical_system`: Alphabetically sorted elements (e.g., "Fe-O") +- `element_composition`: Dictionary of element to amount + +## Lattice + +Defines unit cell geometry for crystal structures. 
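Besides storing cell parameters, a `Lattice` converts between fractional and Cartesian coordinates; a quick sketch (constructors are covered under Creation below):

```python
from pymatgen.core import Lattice

lattice = Lattice.cubic(3.84)

# Fractional -> Cartesian and back (Cartesian coordinates in angstroms)
cart = lattice.get_cartesian_coords([0.5, 0.5, 0.5])
frac = lattice.get_fractional_coords(cart)
print(cart, frac)
```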
+ +### Creation +```python +from pymatgen.core import Lattice + +# From lattice parameters +lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, + alpha=120, beta=90, gamma=60) + +# From matrix (row vectors are lattice vectors) +lattice = Lattice([[3.84, 0, 0], + [0, 3.84, 0], + [0, 0, 3.84]]) + +# Cubic lattice +lattice = Lattice.cubic(3.84) +# Hexagonal lattice +lattice = Lattice.hexagonal(a=2.95, c=4.68) +``` + +### Key methods +- `get_niggli_reduced_lattice()`: Returns Niggli-reduced lattice +- `get_distance_and_image(frac_coords1, frac_coords2)`: Distance between fractional coordinates with periodic boundary conditions +- `get_all_distances(frac_coords1, frac_coords2)`: Distances including periodic images + +### Key properties +- `volume`: Volume of the unit cell (ų) +- `abc`: Lattice parameters (a, b, c) as tuple +- `angles`: Lattice angles (alpha, beta, gamma) as tuple +- `matrix`: 3x3 matrix of lattice vectors +- `reciprocal_lattice`: Reciprocal lattice object +- `is_orthogonal`: Whether lattice vectors are orthogonal + +## Sites + +### Site +Represents an atomic position in non-periodic systems. + +```python +from pymatgen.core import Site + +site = Site("Si", [0.0, 0.0, 0.0]) # Species and Cartesian coordinates +``` + +### PeriodicSite +Represents an atomic position in a periodic lattice with fractional coordinates. + +```python +from pymatgen.core import PeriodicSite + +site = PeriodicSite("Si", [0.5, 0.5, 0.5], lattice) # Species, fractional coords, lattice +``` + +**Key methods:** +- `distance(other_site)`: Distance to another site +- `is_periodic_image(other_site)`: Check if sites are periodic images + +**Key properties:** +- `species`: Species or element at the site +- `coords`: Cartesian coordinates +- `frac_coords`: Fractional coordinates (for PeriodicSite) +- `x`, `y`, `z`: Individual Cartesian coordinates + +## Structure + +Represents a crystal structure as a collection of periodic sites. `Structure` is mutable, while `IStructure` is immutable. 
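The mutable/immutable split matters in practice: a `Structure` can be edited in place, while an `IStructure` built from the same sites cannot. A short sketch of what that looks like (the species and lattice here are arbitrary):

```python
from pymatgen.core import Structure, IStructure, Lattice

struct = Structure(Lattice.cubic(3.84), ["Si", "Si"],
                   [[0, 0, 0], [0.25, 0.25, 0.25]])

# Mutable: sites can be reassigned or perturbed in place
struct[0] = "Ge"
struct.perturb(0.01)

# Immutable snapshot of the same sites, safe to share across code
frozen = IStructure.from_sites(struct.sites)
```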
+ +### Creation +```python +from pymatgen.core import Structure, Lattice + +# From scratch +coords = [[0, 0, 0], [0.75, 0.5, 0.75]] +lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, + alpha=120, beta=90, gamma=60) +struct = Structure(lattice, ["Si", "Si"], coords) + +# From file (automatic format detection) +struct = Structure.from_file("POSCAR") +struct = Structure.from_file("structure.cif") + +# From spacegroup +struct = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3.5), + ["Si"], [[0, 0, 0]]) +``` + +### File I/O +```python +# Write to file (format inferred from extension) +struct.to(filename="output.cif") +struct.to(filename="POSCAR") +struct.to(filename="structure.xyz") + +# Get string representation +cif_string = struct.to(fmt="cif") +poscar_string = struct.to(fmt="poscar") +``` + +### Key methods + +**Structure modification:** +- `append(species, coords)`: Add a site +- `insert(i, species, coords)`: Insert site at index +- `remove_sites(indices)`: Remove sites by index +- `replace(i, species)`: Replace species at index +- `apply_strain(strain)`: Apply strain to structure +- `perturb(distance)`: Randomly perturb atomic positions +- `make_supercell(scaling_matrix)`: Create supercell +- `get_primitive_structure()`: Get primitive cell + +**Analysis:** +- `get_distance(i, j)`: Distance between sites i and j +- `get_neighbors(site, r)`: Get neighbors within radius r +- `get_all_neighbors(r)`: Get all neighbors for all sites +- `get_space_group_info()`: Get space group information +- `matches(other_struct)`: Check if structures match + +**Interpolation:** +- `interpolate(end_structure, nimages)`: Interpolate between structures + +### Key properties +- `lattice`: Lattice object +- `species`: List of species at each site +- `sites`: List of PeriodicSite objects +- `num_sites`: Number of sites +- `volume`: Volume of the structure +- `density`: Density in g/cm³ +- `composition`: Composition object +- `formula`: Chemical formula +- `distance_matrix`: Matrix of pairwise distances + +## Molecule + +Represents non-periodic collections of atoms. `Molecule` is mutable, while `IMolecule` is immutable. + +### Creation +```python +from pymatgen.core import Molecule + +# From scratch +coords = [[0.00, 0.00, 0.00], + [0.00, 0.00, 1.08]] +mol = Molecule(["C", "O"], coords) + +# From file +mol = Molecule.from_file("molecule.xyz") +mol = Molecule.from_file("molecule.mol") +``` + +### Key methods +- `get_covalent_bonds()`: Returns bonds based on covalent radii +- `get_neighbors(site, r)`: Get neighbors within radius +- `get_zmatrix()`: Get Z-matrix representation +- `get_distance(i, j)`: Distance between sites +- `get_centered_molecule()`: Center molecule at origin + +### Key properties +- `species`: List of species +- `sites`: List of Site objects +- `num_sites`: Number of atoms +- `charge`: Total charge of molecule +- `spin_multiplicity`: Spin multiplicity +- `center_of_mass`: Center of mass coordinates + +## Serialization + +All core objects implement `as_dict()` and `from_dict()` methods for robust JSON/YAML persistence. + +```python +# Serialize to dictionary +struct_dict = struct.as_dict() + +# Write to JSON +import json +with open("structure.json", "w") as f: + json.dump(struct_dict, f) + +# Read from JSON +with open("structure.json", "r") as f: + struct_dict = json.load(f) + struct = Structure.from_dict(struct_dict) +``` + +This approach addresses limitations of Python pickling and maintains compatibility across pymatgen versions. 
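For nested objects (lists or dicts that mix structures, entries, and plain data), the `monty` helpers that pymatgen depends on handle the `as_dict()`/`from_dict()` round trip automatically. A sketch, assuming `monty` is available (it ships as a pymatgen dependency):

```python
from monty.serialization import dumpfn, loadfn
from pymatgen.core import Lattice, Structure

struct = Structure(Lattice.cubic(3.84), ["Si", "Si"],
                   [[0, 0, 0], [0.25, 0.25, 0.25]])

# dumpfn/loadfn pick the format from the extension (.json, .yaml, ...) and
# use MontyEncoder/MontyDecoder to rebuild pymatgen objects on load.
dumpfn({"label": "silicon", "structure": struct}, "record.json")
record = loadfn("record.json")
print(type(record["structure"]))  # -> pymatgen Structure
```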
+ +## Additional Core Classes + +### CovalentBond +Represents bonds in molecules. + +**Key properties:** +- `length`: Bond length +- `get_bond_order()`: Returns bond order (single, double, triple) + +### Ion +Represents charged ionic species with oxidation states. + +```python +from pymatgen.core import Ion + +# Create Fe2+ ion +fe2_ion = Ion.from_formula("Fe2+") +``` + +### Interface +Represents substrate-film combinations for heterojunction analysis. + +### GrainBoundary +Represents crystallographic grain boundaries. + +### Spectrum +Represents spectroscopic data with methods for normalization and processing. + +**Key methods:** +- `normalize(mode="max")`: Normalize spectrum +- `smear(sigma)`: Apply Gaussian smearing + +## Best Practices + +1. **Immutability**: Use immutable versions (`IStructure`, `IMolecule`) when structures shouldn't be modified +2. **Serialization**: Prefer `as_dict()`/`from_dict()` over pickle for long-term storage +3. **Units**: Always work in atomic units (Å, eV) - conversions are available in `pymatgen.core.units` +4. **File I/O**: Use `from_file()` for automatic format detection +5. **Coordinates**: Pay attention to whether methods expect Cartesian or fractional coordinates diff --git a/scientific-packages/pymatgen/references/io_formats.md b/scientific-packages/pymatgen/references/io_formats.md new file mode 100644 index 0000000..2552c4c --- /dev/null +++ b/scientific-packages/pymatgen/references/io_formats.md @@ -0,0 +1,469 @@ +# Pymatgen I/O and File Format Reference + +This reference documents pymatgen's extensive input/output capabilities for reading and writing structural and computational data across 100+ file formats. + +## General I/O Philosophy + +Pymatgen provides a unified interface for file operations through the `from_file()` and `to()` methods, with automatic format detection based on file extensions. + +### Reading Files + +```python +from pymatgen.core import Structure, Molecule + +# Automatic format detection +struct = Structure.from_file("POSCAR") +struct = Structure.from_file("structure.cif") +mol = Molecule.from_file("molecule.xyz") + +# Explicit format specification +struct = Structure.from_file("file.txt", fmt="cif") +``` + +### Writing Files + +```python +# Write to file (format inferred from extension) +struct.to(filename="output.cif") +struct.to(filename="POSCAR") +struct.to(filename="structure.xyz") + +# Get string representation without writing +cif_string = struct.to(fmt="cif") +poscar_string = struct.to(fmt="poscar") +``` + +## Structure File Formats + +### CIF (Crystallographic Information File) +Standard format for crystallographic data. + +```python +from pymatgen.io.cif import CifParser, CifWriter + +# Reading +parser = CifParser("structure.cif") +structure = parser.get_structures()[0] # Returns list of structures + +# Writing +writer = CifWriter(struct) +writer.write_file("output.cif") + +# Or using convenience methods +struct = Structure.from_file("structure.cif") +struct.to(filename="output.cif") +``` + +**Key features:** +- Supports symmetry information +- Can contain multiple structures +- Preserves space group and symmetry operations +- Handles partial occupancies + +### POSCAR/CONTCAR (VASP) +VASP's structure format. 
+ +```python +from pymatgen.io.vasp import Poscar + +# Reading +poscar = Poscar.from_file("POSCAR") +structure = poscar.structure + +# Writing +poscar = Poscar(struct) +poscar.write_file("POSCAR") + +# Or using convenience methods +struct = Structure.from_file("POSCAR") +struct.to(filename="POSCAR") +``` + +**Key features:** +- Supports selective dynamics +- Can include velocities (XDATCAR format) +- Preserves lattice and coordinate precision + +### XYZ +Simple molecular coordinates format. + +```python +# For molecules +mol = Molecule.from_file("molecule.xyz") +mol.to(filename="output.xyz") + +# For structures (Cartesian coordinates) +struct.to(filename="structure.xyz") +``` + +### PDB (Protein Data Bank) +Common format for biomolecules. + +```python +mol = Molecule.from_file("protein.pdb") +mol.to(filename="output.pdb") +``` + +### JSON/YAML +Serialization via dictionaries. + +```python +import json +import yaml + +# JSON +with open("structure.json", "w") as f: + json.dump(struct.as_dict(), f) + +with open("structure.json", "r") as f: + struct = Structure.from_dict(json.load(f)) + +# YAML +with open("structure.yaml", "w") as f: + yaml.dump(struct.as_dict(), f) + +with open("structure.yaml", "r") as f: + struct = Structure.from_dict(yaml.safe_load(f)) +``` + +## Electronic Structure Code I/O + +### VASP + +The most comprehensive integration in pymatgen. + +#### Input Files + +```python +from pymatgen.io.vasp.inputs import Incar, Poscar, Potcar, Kpoints, VaspInput + +# INCAR (calculation parameters) +incar = Incar.from_file("INCAR") +incar = Incar({"ENCUT": 520, "ISMEAR": 0, "SIGMA": 0.05}) +incar.write_file("INCAR") + +# KPOINTS (k-point mesh) +from pymatgen.io.vasp.inputs import Kpoints +kpoints = Kpoints.automatic(20) # 20x20x20 Gamma-centered mesh +kpoints = Kpoints.automatic_density(struct, 1000) # By density +kpoints.write_file("KPOINTS") + +# POTCAR (pseudopotentials) +potcar = Potcar(["Fe_pv", "O"]) # Specify functional variants + +# Complete input set +vasp_input = VaspInput(incar, kpoints, poscar, potcar) +vasp_input.write_input("./vasp_calc") +``` + +#### Output Files + +```python +from pymatgen.io.vasp.outputs import Vasprun, Outcar, Oszicar, Eigenval + +# vasprun.xml (comprehensive output) +vasprun = Vasprun("vasprun.xml") +final_structure = vasprun.final_structure +energy = vasprun.final_energy +band_structure = vasprun.get_band_structure() +dos = vasprun.complete_dos + +# OUTCAR +outcar = Outcar("OUTCAR") +magnetization = outcar.total_mag +elastic_tensor = outcar.elastic_tensor + +# OSZICAR (convergence information) +oszicar = Oszicar("OSZICAR") +``` + +#### Input Sets + +Pymatgen provides pre-configured input sets for common calculations: + +```python +from pymatgen.io.vasp.sets import ( + MPRelaxSet, # Materials Project relaxation + MPStaticSet, # Static calculation + MPNonSCFSet, # Non-self-consistent (band structure) + MPSOCSet, # Spin-orbit coupling + MPHSERelaxSet, # HSE06 hybrid functional +) + +# Create input set +relax = MPRelaxSet(struct) +relax.write_input("./relax_calc") + +# Customize parameters +static = MPStaticSet(struct, user_incar_settings={"ENCUT": 600}) +static.write_input("./static_calc") +``` + +### Gaussian + +Quantum chemistry package integration. 
+ +```python +from pymatgen.io.gaussian import GaussianInput, GaussianOutput + +# Input +gin = GaussianInput( + mol, + charge=0, + spin_multiplicity=1, + functional="B3LYP", + basis_set="6-31G(d)", + route_parameters={"Opt": None, "Freq": None} +) +gin.write_file("input.gjf") + +# Output +gout = GaussianOutput("output.log") +final_mol = gout.final_structure +energy = gout.final_energy +frequencies = gout.frequencies +``` + +### LAMMPS + +Classical molecular dynamics. + +```python +from pymatgen.io.lammps.data import LammpsData +from pymatgen.io.lammps.inputs import LammpsInputFile + +# Structure to LAMMPS data file +lammps_data = LammpsData.from_structure(struct) +lammps_data.write_file("data.lammps") + +# LAMMPS input script +lammps_input = LammpsInputFile.from_file("in.lammps") +``` + +### Quantum ESPRESSO + +```python +from pymatgen.io.pwscf import PWInput, PWOutput + +# Input +pwin = PWInput( + struct, + control={"calculation": "scf"}, + system={"ecutwfc": 50, "ecutrho": 400}, + electrons={"conv_thr": 1e-8} +) +pwin.write_file("pw.in") + +# Output +pwout = PWOutput("pw.out") +final_structure = pwout.final_structure +energy = pwout.final_energy +``` + +### ABINIT + +```python +from pymatgen.io.abinit import AbinitInput + +abin = AbinitInput(struct, pseudos) +abin.set_vars(ecut=10, nband=10) +abin.write("abinit.in") +``` + +### CP2K + +```python +from pymatgen.io.cp2k.inputs import Cp2kInput +from pymatgen.io.cp2k.outputs import Cp2kOutput + +# Input +cp2k_input = Cp2kInput.from_file("cp2k.inp") + +# Output +cp2k_output = Cp2kOutput("cp2k.out") +``` + +### FEFF (XAS/XANES) + +```python +from pymatgen.io.feff import FeffInput + +feff_input = FeffInput(struct, absorbing_atom="Fe") +feff_input.write_file("feff.inp") +``` + +### LMTO (Stuttgart TB-LMTO-ASA) + +```python +from pymatgen.io.lmto import LMTOCtrl + +ctrl = LMTOCtrl.from_file("CTRL") +ctrl.structure +``` + +### Q-Chem + +```python +from pymatgen.io.qchem.inputs import QCInput +from pymatgen.io.qchem.outputs import QCOutput + +# Input +qc_input = QCInput( + mol, + rem={"method": "B3LYP", "basis": "6-31G*", "job_type": "opt"} +) +qc_input.write_file("mol.qin") + +# Output +qc_output = QCOutput("mol.qout") +``` + +### Exciting + +```python +from pymatgen.io.exciting import ExcitingInput + +exc_input = ExcitingInput(struct) +exc_input.write_file("input.xml") +``` + +### ATAT (Alloy Theoretic Automated Toolkit) + +```python +from pymatgen.io.atat import Mcsqs + +mcsqs = Mcsqs(struct) +mcsqs.write_input(".") +``` + +## Special Purpose Formats + +### Phonopy + +```python +from pymatgen.io.phonopy import get_phonopy_structure, get_pmg_structure + +# Convert to phonopy structure +phonopy_struct = get_phonopy_structure(struct) + +# Convert from phonopy +struct = get_pmg_structure(phonopy_struct) +``` + +### ASE (Atomic Simulation Environment) + +```python +from pymatgen.io.ase import AseAtomsAdaptor + +adaptor = AseAtomsAdaptor() + +# Pymatgen to ASE +atoms = adaptor.get_atoms(struct) + +# ASE to Pymatgen +struct = adaptor.get_structure(atoms) +``` + +### Zeo++ (Porous Materials) + +```python +from pymatgen.io.zeopp import get_voronoi_nodes, get_high_accuracy_voronoi_nodes + +# Analyze pore structure +vor_nodes = get_voronoi_nodes(struct) +``` + +### BabelMolAdaptor (OpenBabel) + +```python +from pymatgen.io.babel import BabelMolAdaptor + +adaptor = BabelMolAdaptor(mol) + +# Convert to different formats +pdb_str = adaptor.pdbstring +sdf_str = adaptor.write_file("mol.sdf", file_format="sdf") + +# Generate 3D coordinates 
+adaptor.add_hydrogen() +adaptor.make3d() +``` + +## Alchemy and Transformation I/O + +### TransformedStructure + +Structures that track their transformation history. + +```python +from pymatgen.alchemy.materials import TransformedStructure +from pymatgen.transformations.standard_transformations import ( + SupercellTransformation, + SubstitutionTransformation +) + +# Create transformed structure +ts = TransformedStructure(struct, []) +ts.append_transformation(SupercellTransformation([[2,0,0],[0,2,0],[0,0,2]])) +ts.append_transformation(SubstitutionTransformation({"Fe": "Mn"})) + +# Write with history +ts.write_vasp_input("./calc_dir") + +# Read from SNL (Structure Notebook Language) +ts = TransformedStructure.from_snl(snl) +``` + +## Batch Operations + +### CifTransmuter + +Process multiple CIF files. + +```python +from pymatgen.alchemy.transmuters import CifTransmuter + +transmuter = CifTransmuter.from_filenames( + ["structure1.cif", "structure2.cif"], + [SupercellTransformation([[2,0,0],[0,2,0],[0,0,2]])] +) + +# Write all structures +transmuter.write_vasp_input("./batch_calc") +``` + +### PoscarTransmuter + +Similar for POSCAR files. + +```python +from pymatgen.alchemy.transmuters import PoscarTransmuter + +transmuter = PoscarTransmuter.from_filenames( + ["POSCAR1", "POSCAR2"], + [transformation1, transformation2] +) +``` + +## Best Practices + +1. **Automatic format detection**: Use `from_file()` and `to()` methods whenever possible +2. **Error handling**: Always wrap file I/O in try-except blocks +3. **Format-specific parsers**: Use specialized parsers (e.g., `Vasprun`) for detailed output analysis +4. **Input sets**: Prefer pre-configured input sets over manual parameter specification +5. **Serialization**: Use JSON/YAML for long-term storage and version control +6. **Batch processing**: Use transmuters for applying transformations to multiple structures + +## Supported Format Summary + +### Structure formats: +CIF, POSCAR/CONTCAR, XYZ, PDB, XSF, PWMAT, Res, CSSR, JSON, YAML + +### Electronic structure codes: +VASP, Gaussian, LAMMPS, Quantum ESPRESSO, ABINIT, CP2K, FEFF, Q-Chem, LMTO, Exciting, NWChem, AIMS, Crystallographic data formats + +### Molecular formats: +XYZ, PDB, MOL, SDF, PQR, via OpenBabel (many additional formats) + +### Special purpose: +Phonopy, ASE, Zeo++, Lobster, BoltzTraP diff --git a/scientific-packages/pymatgen/references/materials_project_api.md b/scientific-packages/pymatgen/references/materials_project_api.md new file mode 100644 index 0000000..1dc621c --- /dev/null +++ b/scientific-packages/pymatgen/references/materials_project_api.md @@ -0,0 +1,517 @@ +# Materials Project API Reference + +This reference documents how to access and use the Materials Project database through pymatgen's API integration. + +## Overview + +The Materials Project is a comprehensive database of computed materials properties, containing data on hundreds of thousands of inorganic crystals and molecules. The API provides programmatic access to this data through the `MPRester` client. + +## Installation and Setup + +The Materials Project API client is now in a separate package: + +```bash +pip install mp-api +``` + +### Getting an API Key + +1. Visit https://next-gen.materialsproject.org/ +2. Create an account or log in +3. Navigate to your dashboard/settings +4. Generate an API key +5. Store it as an environment variable: + +```bash +export MP_API_KEY="your_api_key_here" +``` + +Or add to your shell configuration file (~/.bashrc, ~/.zshrc, etc.) 
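To confirm the key is actually visible to Python before opening a session, a quick check like the following can save confusion (a sketch using only the standard library plus `mp-api`):

```python
import os
from mp_api.client import MPRester

api_key = os.environ.get("MP_API_KEY")
if not api_key:
    raise SystemExit("MP_API_KEY is not set in this shell/environment")

# Passing the key explicitly is equivalent to relying on the environment variable
with MPRester(api_key) as mpr:
    print("Connected to the Materials Project API")
```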
+ +## Basic Usage + +### Initialization + +```python +from mp_api.client import MPRester + +# Using environment variable (recommended) +with MPRester() as mpr: + # Perform queries + pass + +# Or explicitly pass API key +with MPRester("your_api_key_here") as mpr: + # Perform queries + pass +``` + +**Important**: Always use the `with` context manager to ensure sessions are properly closed. + +## Querying Materials Data + +### Search by Formula + +```python +with MPRester() as mpr: + # Get all materials with formula + materials = mpr.materials.summary.search(formula="Fe2O3") + + for mat in materials: + print(f"Material ID: {mat.material_id}") + print(f"Formula: {mat.formula_pretty}") + print(f"Energy above hull: {mat.energy_above_hull} eV/atom") + print(f"Band gap: {mat.band_gap} eV") + print() +``` + +### Search by Material ID + +```python +with MPRester() as mpr: + # Get specific material + material = mpr.materials.summary.search(material_ids=["mp-149"])[0] + + print(f"Formula: {material.formula_pretty}") + print(f"Space group: {material.symmetry.symbol}") + print(f"Density: {material.density} g/cm³") +``` + +### Search by Chemical System + +```python +with MPRester() as mpr: + # Get all materials in Fe-O system + materials = mpr.materials.summary.search(chemsys="Fe-O") + + # Get materials in ternary system + materials = mpr.materials.summary.search(chemsys="Li-Fe-O") +``` + +### Search by Elements + +```python +with MPRester() as mpr: + # Materials containing Fe and O + materials = mpr.materials.summary.search(elements=["Fe", "O"]) + + # Materials containing ONLY Fe and O (excluding others) + materials = mpr.materials.summary.search( + elements=["Fe", "O"], + exclude_elements=True + ) +``` + +## Getting Structures + +### Structure from Material ID + +```python +with MPRester() as mpr: + # Get structure + structure = mpr.get_structure_by_material_id("mp-149") + + # Get multiple structures + structures = mpr.get_structures(["mp-149", "mp-510", "mp-19017"]) +``` + +### All Structures for a Formula + +```python +with MPRester() as mpr: + # Get all Fe2O3 structures + materials = mpr.materials.summary.search(formula="Fe2O3") + + for mat in materials: + structure = mpr.get_structure_by_material_id(mat.material_id) + print(f"{mat.material_id}: {structure.get_space_group_info()}") +``` + +## Advanced Queries + +### Property Filtering + +```python +with MPRester() as mpr: + # Materials with specific property ranges + materials = mpr.materials.summary.search( + chemsys="Li-Fe-O", + energy_above_hull=(0, 0.05), # Stable or near-stable + band_gap=(1.0, 3.0), # Semiconducting + ) + + # Magnetic materials + materials = mpr.materials.summary.search( + elements=["Fe"], + is_magnetic=True + ) + + # Metals only + materials = mpr.materials.summary.search( + chemsys="Fe-Ni", + is_metal=True + ) +``` + +### Sorting and Limiting + +```python +with MPRester() as mpr: + # Get most stable materials + materials = mpr.materials.summary.search( + chemsys="Li-Fe-O", + sort_fields=["energy_above_hull"], + num_chunks=1, + chunk_size=10 # Limit to 10 results + ) +``` + +## Electronic Structure Data + +### Band Structure + +```python +with MPRester() as mpr: + # Get band structure + bs = mpr.get_bandstructure_by_material_id("mp-149") + + # Analyze band structure + if bs: + print(f"Band gap: {bs.get_band_gap()}") + print(f"Is metal: {bs.is_metal()}") + print(f"Direct gap: {bs.get_band_gap()['direct']}") + + # Plot + from pymatgen.electronic_structure.plotter import BSPlotter + plotter = BSPlotter(bs) + plotter.show() +``` + 
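+For non-interactive (headless) environments, the figure can be written to disk instead of shown; a small variant using the same `BSPlotter.save_plot` call that appears in the band-structure workflow reference:
+
+```python
+# Non-interactive alternative to plotter.show()
+plotter = BSPlotter(bs)
+plotter.save_plot("band_structure.png")
+```
+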
+### Density of States + +```python +with MPRester() as mpr: + # Get DOS + dos = mpr.get_dos_by_material_id("mp-149") + + if dos: + # Get band gap from DOS + gap = dos.get_gap() + print(f"Band gap from DOS: {gap} eV") + + # Plot DOS + from pymatgen.electronic_structure.plotter import DosPlotter + plotter = DosPlotter() + plotter.add_dos("Total DOS", dos) + plotter.show() +``` + +### Fermi Surface + +```python +with MPRester() as mpr: + # Get electronic structure data for Fermi surface + bs = mpr.get_bandstructure_by_material_id("mp-149", line_mode=False) +``` + +## Thermodynamic Data + +### Phase Diagram Construction + +```python +from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter + +with MPRester() as mpr: + # Get entries for phase diagram + entries = mpr.get_entries_in_chemsys("Li-Fe-O") + + # Build phase diagram + pd = PhaseDiagram(entries) + + # Plot + plotter = PDPlotter(pd) + plotter.show() +``` + +### Pourbaix Diagram + +```python +from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter + +with MPRester() as mpr: + # Get entries for Pourbaix diagram + entries = mpr.get_pourbaix_entries(["Fe"]) + + # Build Pourbaix diagram + pb = PourbaixDiagram(entries) + + # Plot + plotter = PourbaixPlotter(pb) + plotter.show() +``` + +### Formation Energy + +```python +with MPRester() as mpr: + materials = mpr.materials.summary.search(material_ids=["mp-149"]) + + for mat in materials: + print(f"Formation energy: {mat.formation_energy_per_atom} eV/atom") + print(f"Energy above hull: {mat.energy_above_hull} eV/atom") +``` + +## Elasticity and Mechanical Properties + +```python +with MPRester() as mpr: + # Search for materials with elastic data + materials = mpr.materials.elasticity.search( + chemsys="Fe-O", + bulk_modulus_vrh=(100, 300) # GPa + ) + + for mat in materials: + print(f"{mat.material_id}: K = {mat.bulk_modulus_vrh} GPa") +``` + +## Dielectric Properties + +```python +with MPRester() as mpr: + # Get dielectric data + materials = mpr.materials.dielectric.search( + material_ids=["mp-149"] + ) + + for mat in materials: + print(f"Dielectric constant: {mat.e_electronic}") + print(f"Refractive index: {mat.n}") +``` + +## Piezoelectric Properties + +```python +with MPRester() as mpr: + # Get piezoelectric materials + materials = mpr.materials.piezoelectric.search( + piezoelectric_modulus=(1, 100) + ) +``` + +## Surface Properties + +```python +with MPRester() as mpr: + # Get surface data + surfaces = mpr.materials.surface_properties.search( + material_ids=["mp-149"] + ) +``` + +## Molecule Data (For Molecular Materials) + +```python +with MPRester() as mpr: + # Search molecules + molecules = mpr.molecules.summary.search( + formula="H2O" + ) + + for mol in molecules: + print(f"Molecule ID: {mol.molecule_id}") + print(f"Formula: {mol.formula_pretty}") +``` + +## Bulk Data Download + +### Download All Data for Materials + +```python +with MPRester() as mpr: + # Get comprehensive data + materials = mpr.materials.summary.search( + material_ids=["mp-149"], + fields=[ + "material_id", + "formula_pretty", + "structure", + "energy_above_hull", + "band_gap", + "density", + "symmetry", + "elasticity", + "magnetic_ordering" + ] + ) +``` + +## Provenance and Calculation Details + +```python +with MPRester() as mpr: + # Get calculation details + materials = mpr.materials.summary.search( + material_ids=["mp-149"], + fields=["material_id", "origins"] + ) + + for mat in materials: + print(f"Origins: {mat.origins}") +``` + +## Working with Entries + +### ComputedEntry for 
Thermodynamic Analysis + +```python +with MPRester() as mpr: + # Get entries (includes energy and composition) + entries = mpr.get_entries_in_chemsys("Li-Fe-O") + + # Entries can be used directly in phase diagram analysis + from pymatgen.analysis.phase_diagram import PhaseDiagram + pd = PhaseDiagram(entries) + + # Check stability + for entry in entries[:5]: + e_above_hull = pd.get_e_above_hull(entry) + print(f"{entry.composition.reduced_formula}: {e_above_hull:.3f} eV/atom") +``` + +## Rate Limiting and Best Practices + +### Rate Limits + +The Materials Project API has rate limits to ensure fair usage: +- Be mindful of request frequency +- Use batch queries when possible +- Cache results locally for repeated analysis + +### Efficient Querying + +```python +# Bad: Multiple separate queries +with MPRester() as mpr: + for mp_id in ["mp-149", "mp-510", "mp-19017"]: + struct = mpr.get_structure_by_material_id(mp_id) # 3 API calls + +# Good: Single batch query +with MPRester() as mpr: + structs = mpr.get_structures(["mp-149", "mp-510", "mp-19017"]) # 1 API call +``` + +### Caching Results + +```python +import json + +# Save results for later use +with MPRester() as mpr: + materials = mpr.materials.summary.search(chemsys="Li-Fe-O") + + # Save to file + with open("li_fe_o_materials.json", "w") as f: + json.dump([mat.dict() for mat in materials], f) + +# Load cached results +with open("li_fe_o_materials.json", "r") as f: + cached_data = json.load(f) +``` + +## Error Handling + +```python +from mp_api.client.core.client import MPRestError + +try: + with MPRester() as mpr: + materials = mpr.materials.summary.search(material_ids=["invalid-id"]) +except MPRestError as e: + print(f"API Error: {e}") +except Exception as e: + print(f"Unexpected error: {e}") +``` + +## Common Use Cases + +### Finding Stable Compounds + +```python +with MPRester() as mpr: + # Get all stable compounds in a chemical system + materials = mpr.materials.summary.search( + chemsys="Li-Fe-O", + energy_above_hull=(0, 0.001) # Essentially on convex hull + ) + + print(f"Found {len(materials)} stable compounds") + for mat in materials: + print(f" {mat.formula_pretty} ({mat.material_id})") +``` + +### Battery Material Screening + +```python +with MPRester() as mpr: + # Screen for potential cathode materials + materials = mpr.materials.summary.search( + elements=["Li"], # Must contain Li + energy_above_hull=(0, 0.05), # Near stable + band_gap=(0, 0.5), # Metallic or small gap + ) + + print(f"Found {len(materials)} potential cathode materials") +``` + +### Finding Materials with Specific Crystal Structure + +```python +with MPRester() as mpr: + # Find materials with specific space group + materials = mpr.materials.summary.search( + chemsys="Fe-O", + spacegroup_number=167 # R-3c (corundum structure) + ) +``` + +## Integration with Other Pymatgen Features + +All data retrieved from the Materials Project can be directly used with pymatgen's analysis tools: + +```python +with MPRester() as mpr: + # Get structure + struct = mpr.get_structure_by_material_id("mp-149") + + # Use with pymatgen analysis + from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + sga = SpacegroupAnalyzer(struct) + + # Generate surfaces + from pymatgen.core.surface import SlabGenerator + slabgen = SlabGenerator(struct, (1,0,0), 10, 10) + slabs = slabgen.get_slabs() + + # Phase diagram analysis + entries = mpr.get_entries_in_chemsys(struct.composition.chemical_system) + from pymatgen.analysis.phase_diagram import PhaseDiagram + pd = PhaseDiagram(entries) +``` + +## 
Additional Resources + +- **API Documentation**: https://docs.materialsproject.org/ +- **Materials Project Website**: https://next-gen.materialsproject.org/ +- **GitHub**: https://github.com/materialsproject/api +- **Forum**: https://matsci.org/ + +## Best Practices Summary + +1. **Always use context manager**: Use `with MPRester() as mpr:` +2. **Store API key as environment variable**: Never hardcode API keys +3. **Batch queries**: Request multiple items at once when possible +4. **Cache results**: Save frequently used data locally +5. **Handle errors**: Wrap API calls in try-except blocks +6. **Be specific**: Use filters to limit results and reduce data transfer +7. **Check data availability**: Not all properties are available for all materials diff --git a/scientific-packages/pymatgen/references/transformations_workflows.md b/scientific-packages/pymatgen/references/transformations_workflows.md new file mode 100644 index 0000000..745509b --- /dev/null +++ b/scientific-packages/pymatgen/references/transformations_workflows.md @@ -0,0 +1,591 @@ +# Pymatgen Transformations and Common Workflows + +This reference documents pymatgen's transformation framework and provides recipes for common materials science workflows. + +## Transformation Framework + +Transformations provide a systematic way to modify structures while tracking the history of modifications. + +### Standard Transformations + +Located in `pymatgen.transformations.standard_transformations`. + +#### SupercellTransformation + +Create supercells with arbitrary scaling matrices. + +```python +from pymatgen.transformations.standard_transformations import SupercellTransformation + +# Simple 2x2x2 supercell +trans = SupercellTransformation([[2,0,0], [0,2,0], [0,0,2]]) +new_struct = trans.apply_transformation(struct) + +# Non-orthogonal supercell +trans = SupercellTransformation([[2,1,0], [0,2,0], [0,0,2]]) +new_struct = trans.apply_transformation(struct) +``` + +#### SubstitutionTransformation + +Replace species in a structure. + +```python +from pymatgen.transformations.standard_transformations import SubstitutionTransformation + +# Replace all Fe with Mn +trans = SubstitutionTransformation({"Fe": "Mn"}) +new_struct = trans.apply_transformation(struct) + +# Partial substitution (50% Fe -> Mn) +trans = SubstitutionTransformation({"Fe": {"Mn": 0.5, "Fe": 0.5}}) +new_struct = trans.apply_transformation(struct) +``` + +#### RemoveSpeciesTransformation + +Remove specific species from structure. + +```python +from pymatgen.transformations.standard_transformations import RemoveSpeciesTransformation + +trans = RemoveSpeciesTransformation(["H"]) # Remove all hydrogen +new_struct = trans.apply_transformation(struct) +``` + +#### OrderDisorderedStructureTransformation + +Order disordered structures with partial occupancies. + +```python +from pymatgen.transformations.standard_transformations import OrderDisorderedStructureTransformation + +trans = OrderDisorderedStructureTransformation() +new_struct = trans.apply_transformation(disordered_struct) +``` + +#### PrimitiveCellTransformation + +Convert to primitive cell. + +```python +from pymatgen.transformations.standard_transformations import PrimitiveCellTransformation + +trans = PrimitiveCellTransformation() +primitive_struct = trans.apply_transformation(struct) +``` + +#### ConventionalCellTransformation + +Convert to conventional cell. 
+ +```python +from pymatgen.transformations.standard_transformations import ConventionalCellTransformation + +trans = ConventionalCellTransformation() +conventional_struct = trans.apply_transformation(struct) +``` + +#### RotationTransformation + +Rotate structure. + +```python +from pymatgen.transformations.standard_transformations import RotationTransformation + +# Rotate by axis and angle +trans = RotationTransformation([0, 0, 1], 45) # 45° around z-axis +new_struct = trans.apply_transformation(struct) +``` + +#### ScaleToRelaxedTransformation + +Scale lattice to match a relaxed structure. + +```python +from pymatgen.transformations.standard_transformations import ScaleToRelaxedTransformation + +trans = ScaleToRelaxedTransformation(relaxed_struct) +scaled_struct = trans.apply_transformation(unrelaxed_struct) +``` + +### Advanced Transformations + +Located in `pymatgen.transformations.advanced_transformations`. + +#### EnumerateStructureTransformation + +Enumerate all symmetrically distinct ordered structures from a disordered structure. + +```python +from pymatgen.transformations.advanced_transformations import EnumerateStructureTransformation + +# Enumerate structures up to max 8 atoms per unit cell +trans = EnumerateStructureTransformation(max_cell_size=8) +structures = trans.apply_transformation(struct, return_ranked_list=True) + +# Returns list of ranked structures +for s in structures[:5]: # Top 5 structures + print(f"Energy: {s['energy']}, Structure: {s['structure']}") +``` + +#### MagOrderingTransformation + +Enumerate magnetic orderings. + +```python +from pymatgen.transformations.advanced_transformations import MagOrderingTransformation + +# Specify magnetic moments for each species +trans = MagOrderingTransformation({"Fe": 5.0, "Ni": 2.0}) +mag_structures = trans.apply_transformation(struct, return_ranked_list=True) +``` + +#### DopingTransformation + +Systematically dope a structure. + +```python +from pymatgen.transformations.advanced_transformations import DopingTransformation + +# Replace 12.5% of Fe sites with Mn +trans = DopingTransformation("Mn", min_length=10) +doped_structs = trans.apply_transformation(struct, return_ranked_list=True) +``` + +#### ChargeBalanceTransformation + +Balance charge in a structure by oxidation state manipulation. + +```python +from pymatgen.transformations.advanced_transformations import ChargeBalanceTransformation + +trans = ChargeBalanceTransformation("Li") +charged_struct = trans.apply_transformation(struct) +``` + +#### SlabTransformation + +Generate surface slabs. + +```python +from pymatgen.transformations.advanced_transformations import SlabTransformation + +trans = SlabTransformation( + miller_index=[1, 0, 0], + min_slab_size=10, + min_vacuum_size=10, + shift=0, + lll_reduce=True +) +slab = trans.apply_transformation(struct) +``` + +### Chaining Transformations + +```python +from pymatgen.alchemy.materials import TransformedStructure + +# Create transformed structure that tracks history +ts = TransformedStructure(struct, []) + +# Apply multiple transformations +ts.append_transformation(SupercellTransformation([[2,0,0],[0,2,0],[0,0,2]])) +ts.append_transformation(SubstitutionTransformation({"Fe": "Mn"})) +ts.append_transformation(PrimitiveCellTransformation()) + +# Get final structure +final_struct = ts.final_structure + +# View transformation history +print(ts.history) +``` + +## Common Workflows + +### Workflow 1: High-Throughput Structure Generation + +Generate multiple structures for screening studies. 
+ +```python +from pymatgen.core import Structure +from pymatgen.transformations.standard_transformations import ( + SubstitutionTransformation, + SupercellTransformation +) +from pymatgen.io.vasp.sets import MPRelaxSet + +# Starting structure +base_struct = Structure.from_file("POSCAR") + +# Define substitutions +dopants = ["Mn", "Co", "Ni", "Cu"] +structures = {} + +for dopant in dopants: + # Create substituted structure + trans = SubstitutionTransformation({"Fe": dopant}) + new_struct = trans.apply_transformation(base_struct) + + # Generate VASP inputs + vasp_input = MPRelaxSet(new_struct) + vasp_input.write_input(f"./calcs/Fe_{dopant}") + + structures[dopant] = new_struct + +print(f"Generated {len(structures)} structures") +``` + +### Workflow 2: Phase Diagram Construction + +Build and analyze phase diagrams from Materials Project data. + +```python +from mp_api.client import MPRester +from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter +from pymatgen.core import Composition + +# Get data from Materials Project +with MPRester() as mpr: + entries = mpr.get_entries_in_chemsys("Li-Fe-O") + +# Build phase diagram +pd = PhaseDiagram(entries) + +# Analyze specific composition +comp = Composition("LiFeO2") +e_above_hull = pd.get_e_above_hull(entries[0]) + +# Get decomposition products +decomp = pd.get_decomposition(comp) +print(f"Decomposition: {decomp}") + +# Visualize +plotter = PDPlotter(pd) +plotter.show() +``` + +### Workflow 3: Surface Energy Calculation + +Calculate surface energies from slab calculations. + +```python +from pymatgen.core.surface import SlabGenerator, generate_all_slabs +from pymatgen.io.vasp.sets import MPStaticSet, MPRelaxSet +from pymatgen.core import Structure + +# Read bulk structure +bulk = Structure.from_file("bulk_POSCAR") + +# Get bulk energy (from previous calculation) +from pymatgen.io.vasp import Vasprun +bulk_vasprun = Vasprun("bulk/vasprun.xml") +bulk_energy_per_atom = bulk_vasprun.final_energy / len(bulk) + +# Generate slabs +miller_indices = [(1,0,0), (1,1,0), (1,1,1)] +surface_energies = {} + +for miller in miller_indices: + slabgen = SlabGenerator( + bulk, + miller_index=miller, + min_slab_size=10, + min_vacuum_size=15, + center_slab=True + ) + + slab = slabgen.get_slabs()[0] + + # Write VASP input for slab + relax = MPRelaxSet(slab) + relax.write_input(f"./slab_{miller[0]}{miller[1]}{miller[2]}") + + # After calculation, compute surface energy: + # slab_vasprun = Vasprun(f"slab_{miller[0]}{miller[1]}{miller[2]}/vasprun.xml") + # slab_energy = slab_vasprun.final_energy + # n_atoms = len(slab) + # area = slab.surface_area # in Ų + # + # # Surface energy (J/m²) + # surf_energy = (slab_energy - n_atoms * bulk_energy_per_atom) / (2 * area) + # surf_energy *= 16.021766 # Convert eV/Ų to J/m² + # surface_energies[miller] = surf_energy + +print(f"Set up calculations for {len(miller_indices)} surfaces") +``` + +### Workflow 4: Band Structure Calculation + +Complete workflow for band structure calculations. 
+ +```python +from pymatgen.core import Structure +from pymatgen.io.vasp.sets import MPRelaxSet, MPStaticSet, MPNonSCFSet +from pymatgen.symmetry.bandstructure import HighSymmKpath + +# Step 1: Relaxation +struct = Structure.from_file("initial_POSCAR") +relax = MPRelaxSet(struct) +relax.write_input("./1_relax") + +# After relaxation, read structure +relaxed_struct = Structure.from_file("1_relax/CONTCAR") + +# Step 2: Static calculation +static = MPStaticSet(relaxed_struct) +static.write_input("./2_static") + +# Step 3: Band structure (non-self-consistent) +kpath = HighSymmKpath(relaxed_struct) +nscf = MPNonSCFSet(relaxed_struct, mode="line") # Band structure mode +nscf.write_input("./3_bandstructure") + +# After calculations, analyze +from pymatgen.io.vasp import Vasprun +from pymatgen.electronic_structure.plotter import BSPlotter + +vasprun = Vasprun("3_bandstructure/vasprun.xml") +bs = vasprun.get_band_structure(line_mode=True) + +print(f"Band gap: {bs.get_band_gap()}") + +plotter = BSPlotter(bs) +plotter.save_plot("band_structure.png") +``` + +### Workflow 5: Molecular Dynamics Setup + +Set up and analyze molecular dynamics simulations. + +```python +from pymatgen.core import Structure +from pymatgen.io.vasp.sets import MVLRelaxSet +from pymatgen.io.vasp.inputs import Incar + +# Read structure +struct = Structure.from_file("POSCAR") + +# Create 2x2x2 supercell for MD +from pymatgen.transformations.standard_transformations import SupercellTransformation +trans = SupercellTransformation([[2,0,0],[0,2,0],[0,0,2]]) +supercell = trans.apply_transformation(struct) + +# Set up VASP input +md_input = MVLRelaxSet(supercell) + +# Modify INCAR for MD +incar = md_input.incar +incar.update({ + "IBRION": 0, # Molecular dynamics + "NSW": 1000, # Number of steps + "POTIM": 2, # Time step (fs) + "TEBEG": 300, # Initial temperature (K) + "TEEND": 300, # Final temperature (K) + "SMASS": 0, # NVT ensemble + "MDALGO": 2, # Nose-Hoover thermostat +}) + +md_input.incar = incar +md_input.write_input("./md_calc") +``` + +### Workflow 6: Diffusion Analysis + +Analyze ion diffusion from AIMD trajectories. + +```python +from pymatgen.io.vasp import Xdatcar +from pymatgen.analysis.diffusion.analyzer import DiffusionAnalyzer + +# Read trajectory from XDATCAR +xdatcar = Xdatcar("XDATCAR") +structures = xdatcar.structures + +# Analyze diffusion for specific species (e.g., Li) +analyzer = DiffusionAnalyzer.from_structures( + structures, + specie="Li", + temperature=300, # K + time_step=2, # fs + step_skip=10 # Skip initial equilibration +) + +# Get diffusivity +diffusivity = analyzer.diffusivity # cm²/s +conductivity = analyzer.conductivity # mS/cm + +# Get mean squared displacement +msd = analyzer.msd + +# Plot MSD +analyzer.plot_msd() + +print(f"Diffusivity: {diffusivity:.2e} cm²/s") +print(f"Conductivity: {conductivity:.2e} mS/cm") +``` + +### Workflow 7: Structure Prediction and Enumeration + +Predict and enumerate possible structures. 
+ +```python +from pymatgen.core import Structure, Lattice +from pymatgen.transformations.advanced_transformations import ( + EnumerateStructureTransformation, + SubstitutionTransformation +) + +# Start with a known structure type (e.g., rocksalt) +lattice = Lattice.cubic(4.2) +struct = Structure.from_spacegroup("Fm-3m", lattice, ["Li", "O"], [[0,0,0], [0.5,0.5,0.5]]) + +# Create disordered structure +from pymatgen.core import Species +species_on_site = {Species("Li"): 0.5, Species("Na"): 0.5} +struct[0] = species_on_site # Mixed occupancy on Li site + +# Enumerate all ordered structures +trans = EnumerateStructureTransformation(max_cell_size=4) +ordered_structs = trans.apply_transformation(struct, return_ranked_list=True) + +print(f"Found {len(ordered_structs)} distinct ordered structures") + +# Write all structures +for i, s_dict in enumerate(ordered_structs[:10]): # Top 10 + s_dict['structure'].to(filename=f"ordered_struct_{i}.cif") +``` + +### Workflow 8: Elastic Constant Calculation + +Calculate elastic constants using the stress-strain method. + +```python +from pymatgen.core import Structure +from pymatgen.transformations.standard_transformations import DeformStructureTransformation +from pymatgen.io.vasp.sets import MPStaticSet + +# Read equilibrium structure +struct = Structure.from_file("relaxed_POSCAR") + +# Generate deformed structures +strains = [0.00, 0.01, 0.02, -0.01, -0.02] # Applied strains +deformation_sets = [] + +for strain in strains: + # Apply strain in different directions + trans = DeformStructureTransformation([[1+strain, 0, 0], [0, 1, 0], [0, 0, 1]]) + deformed = trans.apply_transformation(struct) + + # Set up VASP calculation + static = MPStaticSet(deformed) + static.write_input(f"./strain_{strain:.2f}") + +# After calculations, fit stress vs strain to get elastic constants +# from pymatgen.analysis.elasticity import ElasticTensor +# ... (collect stress tensors from OUTCAR) +# elastic_tensor = ElasticTensor.from_stress_list(stress_list) +``` + +### Workflow 9: Adsorption Energy Calculation + +Calculate adsorption energies on surfaces. + +```python +from pymatgen.core import Structure, Molecule +from pymatgen.core.surface import SlabGenerator +from pymatgen.analysis.adsorption import AdsorbateSiteFinder +from pymatgen.io.vasp.sets import MPRelaxSet + +# Generate slab +bulk = Structure.from_file("bulk_POSCAR") +slabgen = SlabGenerator(bulk, (1,1,1), 10, 10) +slab = slabgen.get_slabs()[0] + +# Find adsorption sites +asf = AdsorbateSiteFinder(slab) +ads_sites = asf.find_adsorption_sites() + +# Create adsorbate +adsorbate = Molecule("O", [[0, 0, 0]]) + +# Generate structures with adsorbate +ads_structs = asf.add_adsorbate(adsorbate, ads_sites["ontop"][0]) + +# Set up calculations +relax_slab = MPRelaxSet(slab) +relax_slab.write_input("./slab") + +relax_ads = MPRelaxSet(ads_structs) +relax_ads.write_input("./slab_with_adsorbate") + +# After calculations: +# E_ads = E(slab+adsorbate) - E(slab) - E(adsorbate_gas) +``` + +### Workflow 10: High-Throughput Materials Screening + +Screen materials database for specific properties. 
+ +```python +from mp_api.client import MPRester +from pymatgen.core import Structure +import pandas as pd + +# Define screening criteria +def screen_material(material): + """Screen for potential battery cathode materials""" + criteria = { + "has_li": "Li" in material.composition.elements, + "stable": material.energy_above_hull < 0.05, + "good_voltage": 2.5 < material.formation_energy_per_atom < 4.5, + "electronically_conductive": material.band_gap < 0.5 + } + return all(criteria.values()), criteria + +# Query Materials Project +with MPRester() as mpr: + # Get potential materials + materials = mpr.materials.summary.search( + elements=["Li"], + energy_above_hull=(0, 0.05), + ) + + results = [] + for mat in materials: + passes, criteria = screen_material(mat) + if passes: + results.append({ + "material_id": mat.material_id, + "formula": mat.formula_pretty, + "energy_above_hull": mat.energy_above_hull, + "band_gap": mat.band_gap, + }) + + # Save results + df = pd.DataFrame(results) + df.to_csv("screened_materials.csv", index=False) + + print(f"Found {len(results)} promising materials") +``` + +## Best Practices for Workflows + +1. **Modular design**: Break workflows into discrete steps +2. **Error handling**: Check file existence and calculation convergence +3. **Documentation**: Track transformation history using `TransformedStructure` +4. **Version control**: Store input parameters and scripts in git +5. **Automation**: Use workflow managers (Fireworks, AiiDA) for complex pipelines +6. **Data management**: Organize calculations in clear directory structures +7. **Validation**: Always validate intermediate results before proceeding + +## Integration with Workflow Tools + +Pymatgen integrates with several workflow management systems: + +- **Atomate**: Pre-built VASP workflows +- **Fireworks**: Workflow execution engine +- **AiiDA**: Provenance tracking and workflow management +- **Custodian**: Error correction and job monitoring + +These tools provide robust automation for production calculations. diff --git a/scientific-packages/pymatgen/scripts/phase_diagram_generator.py b/scientific-packages/pymatgen/scripts/phase_diagram_generator.py new file mode 100644 index 0000000..ba2902f --- /dev/null +++ b/scientific-packages/pymatgen/scripts/phase_diagram_generator.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +""" +Phase diagram generator using Materials Project data. + +This script generates phase diagrams for chemical systems using data from the +Materials Project database via pymatgen's MPRester. + +Usage: + python phase_diagram_generator.py chemical_system [options] + +Examples: + python phase_diagram_generator.py Li-Fe-O + python phase_diagram_generator.py Li-Fe-O --output li_fe_o_pd.png + python phase_diagram_generator.py Fe-O --show + python phase_diagram_generator.py Li-Fe-O --analyze "LiFeO2" +""" + +import argparse +import os +import sys +from pathlib import Path + +try: + from pymatgen.core import Composition + from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter +except ImportError: + print("Error: pymatgen is not installed. Install with: pip install pymatgen") + sys.exit(1) + +try: + from mp_api.client import MPRester +except ImportError: + print("Error: mp-api is not installed. 
Install with: pip install mp-api") + sys.exit(1) + + +def get_api_key() -> str: + """Get Materials Project API key from environment.""" + api_key = os.environ.get("MP_API_KEY") + if not api_key: + print("Error: MP_API_KEY environment variable not set.") + print("Get your API key from https://next-gen.materialsproject.org/") + print("Then set it with: export MP_API_KEY='your_key_here'") + sys.exit(1) + return api_key + + +def generate_phase_diagram(chemsys: str, args): + """ + Generate and analyze phase diagram for a chemical system. + + Args: + chemsys: Chemical system (e.g., "Li-Fe-O") + args: Command line arguments + """ + api_key = get_api_key() + + print(f"\n{'='*60}") + print(f"PHASE DIAGRAM: {chemsys}") + print(f"{'='*60}\n") + + # Get entries from Materials Project + print("Fetching data from Materials Project...") + with MPRester(api_key) as mpr: + entries = mpr.get_entries_in_chemsys(chemsys) + + print(f"✓ Retrieved {len(entries)} entries") + + if len(entries) == 0: + print(f"Error: No entries found for chemical system {chemsys}") + sys.exit(1) + + # Build phase diagram + print("Building phase diagram...") + pd = PhaseDiagram(entries) + + # Get stable entries + stable_entries = pd.stable_entries + print(f"✓ Phase diagram constructed with {len(stable_entries)} stable phases") + + # Print stable phases + print("\n--- STABLE PHASES ---") + for entry in stable_entries: + formula = entry.composition.reduced_formula + energy = entry.energy_per_atom + print(f" {formula:<20} E = {energy:.4f} eV/atom") + + # Analyze specific composition if requested + if args.analyze: + print(f"\n--- STABILITY ANALYSIS: {args.analyze} ---") + try: + comp = Composition(args.analyze) + + # Find closest entry + closest_entry = None + min_distance = float('inf') + + for entry in entries: + if entry.composition.reduced_formula == comp.reduced_formula: + closest_entry = entry + break + + if closest_entry: + # Calculate energy above hull + e_above_hull = pd.get_e_above_hull(closest_entry) + print(f"Energy above hull: {e_above_hull:.4f} eV/atom") + + if e_above_hull < 0.001: + print(f"Status: STABLE (on convex hull)") + elif e_above_hull < 0.05: + print(f"Status: METASTABLE (nearly stable)") + else: + print(f"Status: UNSTABLE") + + # Get decomposition + decomp = pd.get_decomposition(comp) + print(f"\nDecomposes to:") + for entry, fraction in decomp.items(): + formula = entry.composition.reduced_formula + print(f" {fraction:.3f} × {formula}") + + # Get reaction energy + rxn_energy = pd.get_equilibrium_reaction_energy(closest_entry) + print(f"\nDecomposition energy: {rxn_energy:.4f} eV/atom") + + else: + print(f"No entry found for composition {args.analyze}") + print("Checking stability of hypothetical composition...") + + # Analyze hypothetical composition + decomp = pd.get_decomposition(comp) + print(f"\nWould decompose to:") + for entry, fraction in decomp.items(): + formula = entry.composition.reduced_formula + print(f" {fraction:.3f} × {formula}") + + except Exception as e: + print(f"Error analyzing composition: {e}") + + # Get chemical potentials + if args.chemical_potentials: + print("\n--- CHEMICAL POTENTIALS ---") + print("(at stability regions)") + try: + chempots = pd.get_all_chempots() + for element, potentials in chempots.items(): + print(f"\n{element}:") + for potential in potentials[:5]: # Show first 5 + print(f" {potential:.4f} eV") + except Exception as e: + print(f"Could not calculate chemical potentials: {e}") + + # Plot phase diagram + print("\n--- GENERATING PLOT ---") + plotter = PDPlotter(pd, 
show_unstable=args.show_unstable) + + if args.output: + output_path = Path(args.output) + plotter.write_image(str(output_path), image_format=output_path.suffix[1:]) + print(f"✓ Phase diagram saved to {output_path}") + + if args.show: + print("Opening interactive plot...") + plotter.show() + + print(f"\n{'='*60}\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate phase diagrams using Materials Project data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Requirements: + - Materials Project API key (set MP_API_KEY environment variable) + - mp-api package: pip install mp-api + +Examples: + %(prog)s Li-Fe-O + %(prog)s Li-Fe-O --output li_fe_o_phase_diagram.png + %(prog)s Fe-O --show --analyze "Fe2O3" + %(prog)s Li-Fe-O --analyze "LiFeO2" --show-unstable + """ + ) + + parser.add_argument( + "chemsys", + help="Chemical system (e.g., Li-Fe-O, Fe-O)" + ) + + parser.add_argument( + "--output", "-o", + help="Output file for phase diagram plot (PNG, PDF, SVG)" + ) + + parser.add_argument( + "--show", "-s", + action="store_true", + help="Show interactive plot" + ) + + parser.add_argument( + "--analyze", "-a", + help="Analyze stability of specific composition (e.g., LiFeO2)" + ) + + parser.add_argument( + "--show-unstable", + action="store_true", + help="Include unstable phases in plot" + ) + + parser.add_argument( + "--chemical-potentials", + action="store_true", + help="Calculate chemical potentials" + ) + + args = parser.parse_args() + + # Validate chemical system format + elements = args.chemsys.split("-") + if len(elements) < 2: + print("Error: Chemical system must contain at least 2 elements") + print("Example: Li-Fe-O") + sys.exit(1) + + # Generate phase diagram + generate_phase_diagram(args.chemsys, args) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pymatgen/scripts/structure_analyzer.py b/scientific-packages/pymatgen/scripts/structure_analyzer.py new file mode 100644 index 0000000..677ec19 --- /dev/null +++ b/scientific-packages/pymatgen/scripts/structure_analyzer.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +""" +Structure analysis tool using pymatgen. + +Analyzes crystal structures and provides comprehensive information including: +- Composition and formula +- Space group and symmetry +- Lattice parameters +- Density +- Coordination environment +- Bond lengths and angles + +Usage: + python structure_analyzer.py structure_file [options] + +Examples: + python structure_analyzer.py POSCAR + python structure_analyzer.py structure.cif --symmetry --neighbors + python structure_analyzer.py POSCAR --export json +""" + +import argparse +import json +import sys +from pathlib import Path + +try: + from pymatgen.core import Structure + from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + from pymatgen.analysis.local_env import CrystalNN +except ImportError: + print("Error: pymatgen is not installed. Install with: pip install pymatgen") + sys.exit(1) + + +def analyze_structure(struct: Structure, args) -> dict: + """ + Perform comprehensive structure analysis. 
+ + Args: + struct: Pymatgen Structure object + args: Command line arguments + + Returns: + Dictionary containing analysis results + """ + results = {} + + # Basic information + print("\n" + "="*60) + print("STRUCTURE ANALYSIS") + print("="*60) + + print("\n--- COMPOSITION ---") + print(f"Formula (reduced): {struct.composition.reduced_formula}") + print(f"Formula (full): {struct.composition.formula}") + print(f"Formula (Hill): {struct.composition.hill_formula}") + print(f"Chemical system: {struct.composition.chemical_system}") + print(f"Number of sites: {len(struct)}") + print(f"Number of species: {len(struct.composition.elements)}") + print(f"Molecular weight: {struct.composition.weight:.2f} amu") + + results['composition'] = { + 'reduced_formula': struct.composition.reduced_formula, + 'formula': struct.composition.formula, + 'hill_formula': struct.composition.hill_formula, + 'chemical_system': struct.composition.chemical_system, + 'num_sites': len(struct), + 'molecular_weight': struct.composition.weight, + } + + # Lattice information + print("\n--- LATTICE ---") + print(f"a = {struct.lattice.a:.4f} Å") + print(f"b = {struct.lattice.b:.4f} Å") + print(f"c = {struct.lattice.c:.4f} Å") + print(f"α = {struct.lattice.alpha:.2f}°") + print(f"β = {struct.lattice.beta:.2f}°") + print(f"γ = {struct.lattice.gamma:.2f}°") + print(f"Volume: {struct.volume:.2f} ų") + print(f"Density: {struct.density:.3f} g/cm³") + + results['lattice'] = { + 'a': struct.lattice.a, + 'b': struct.lattice.b, + 'c': struct.lattice.c, + 'alpha': struct.lattice.alpha, + 'beta': struct.lattice.beta, + 'gamma': struct.lattice.gamma, + 'volume': struct.volume, + 'density': struct.density, + } + + # Symmetry analysis + if args.symmetry: + print("\n--- SYMMETRY ---") + try: + sga = SpacegroupAnalyzer(struct) + + spacegroup_symbol = sga.get_space_group_symbol() + spacegroup_number = sga.get_space_group_number() + crystal_system = sga.get_crystal_system() + point_group = sga.get_point_group_symbol() + + print(f"Space group: {spacegroup_symbol} (#{spacegroup_number})") + print(f"Crystal system: {crystal_system}") + print(f"Point group: {point_group}") + + # Get symmetry operations + symm_ops = sga.get_symmetry_operations() + print(f"Symmetry operations: {len(symm_ops)}") + + results['symmetry'] = { + 'spacegroup_symbol': spacegroup_symbol, + 'spacegroup_number': spacegroup_number, + 'crystal_system': crystal_system, + 'point_group': point_group, + 'num_symmetry_ops': len(symm_ops), + } + + # Show equivalent sites + sym_struct = sga.get_symmetrized_structure() + print(f"Symmetry-equivalent site groups: {len(sym_struct.equivalent_sites)}") + + except Exception as e: + print(f"Could not determine symmetry: {e}") + + # Site information + print("\n--- SITES ---") + print(f"{'Index':<6} {'Species':<10} {'Wyckoff':<10} {'Frac Coords':<30}") + print("-" * 60) + + for i, site in enumerate(struct): + coords_str = f"[{site.frac_coords[0]:.4f}, {site.frac_coords[1]:.4f}, {site.frac_coords[2]:.4f}]" + wyckoff = "N/A" + + if args.symmetry: + try: + sga = SpacegroupAnalyzer(struct) + sym_struct = sga.get_symmetrized_structure() + wyckoff = sym_struct.equivalent_sites[0][0].species_string # Simplified + except: + pass + + print(f"{i:<6} {site.species_string:<10} {wyckoff:<10} {coords_str:<30}") + + # Neighbor analysis + if args.neighbors: + print("\n--- COORDINATION ENVIRONMENT ---") + try: + cnn = CrystalNN() + + for i, site in enumerate(struct): + neighbors = cnn.get_nn_info(struct, i) + print(f"\nSite {i} ({site.species_string}):") + print(f" 
Coordination number: {len(neighbors)}") + + if len(neighbors) > 0 and len(neighbors) <= 12: + print(f" Neighbors:") + for j, neighbor in enumerate(neighbors): + neighbor_site = struct[neighbor['site_index']] + distance = site.distance(neighbor_site) + print(f" {neighbor_site.species_string} at {distance:.3f} Å") + + except Exception as e: + print(f"Could not analyze coordination: {e}") + + # Distance matrix (for small structures) + if args.distances and len(struct) <= 20: + print("\n--- DISTANCE MATRIX (Å) ---") + distance_matrix = struct.distance_matrix + + # Print header + print(f"{'':>4}", end="") + for i in range(len(struct)): + print(f"{i:>8}", end="") + print() + + # Print matrix + for i in range(len(struct)): + print(f"{i:>4}", end="") + for j in range(len(struct)): + if i == j: + print(f"{'---':>8}", end="") + else: + print(f"{distance_matrix[i][j]:>8.3f}", end="") + print() + + print("\n" + "="*60) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze crystal structures using pymatgen", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "structure_file", + help="Structure file to analyze (CIF, POSCAR, etc.)" + ) + + parser.add_argument( + "--symmetry", "-s", + action="store_true", + help="Perform symmetry analysis" + ) + + parser.add_argument( + "--neighbors", "-n", + action="store_true", + help="Analyze coordination environment" + ) + + parser.add_argument( + "--distances", "-d", + action="store_true", + help="Show distance matrix (for structures with ≤20 atoms)" + ) + + parser.add_argument( + "--export", "-e", + choices=["json", "yaml"], + help="Export analysis results to file" + ) + + parser.add_argument( + "--output", "-o", + help="Output file for exported results" + ) + + args = parser.parse_args() + + # Read structure + try: + struct = Structure.from_file(args.structure_file) + except Exception as e: + print(f"Error reading structure file: {e}") + sys.exit(1) + + # Analyze structure + results = analyze_structure(struct, args) + + # Export results + if args.export: + output_file = args.output or f"analysis.{args.export}" + + if args.export == "json": + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + print(f"\n✓ Analysis exported to {output_file}") + + elif args.export == "yaml": + try: + import yaml + with open(output_file, "w") as f: + yaml.dump(results, f, default_flow_style=False) + print(f"\n✓ Analysis exported to {output_file}") + except ImportError: + print("Error: PyYAML is not installed. Install with: pip install pyyaml") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pymatgen/scripts/structure_converter.py b/scientific-packages/pymatgen/scripts/structure_converter.py new file mode 100644 index 0000000..d2ec405 --- /dev/null +++ b/scientific-packages/pymatgen/scripts/structure_converter.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Structure file format converter using pymatgen. + +This script converts between different structure file formats supported by pymatgen. +Supports automatic format detection and batch conversion. 
+ +Usage: + python structure_converter.py input_file output_file + python structure_converter.py input_file --format cif + python structure_converter.py *.cif --output-dir ./converted --format poscar + +Examples: + python structure_converter.py POSCAR structure.cif + python structure_converter.py structure.cif --format json + python structure_converter.py *.vasp --output-dir ./cif_files --format cif +""" + +import argparse +import sys +from pathlib import Path +from typing import List + +try: + from pymatgen.core import Structure +except ImportError: + print("Error: pymatgen is not installed. Install with: pip install pymatgen") + sys.exit(1) + + +def convert_structure(input_path: Path, output_path: Path = None, output_format: str = None) -> bool: + """ + Convert a structure file to a different format. + + Args: + input_path: Path to input structure file + output_path: Path to output file (optional if output_format is specified) + output_format: Target format (e.g., 'cif', 'poscar', 'json', 'yaml') + + Returns: + True if conversion succeeded, False otherwise + """ + try: + # Read structure with automatic format detection + struct = Structure.from_file(str(input_path)) + print(f"✓ Read structure: {struct.composition.reduced_formula} from {input_path}") + + # Determine output path + if output_path is None and output_format: + output_path = input_path.with_suffix(f".{output_format}") + elif output_path is None: + print("Error: Must specify either output_path or output_format") + return False + + # Write structure + struct.to(filename=str(output_path)) + print(f"✓ Wrote structure to {output_path}") + + return True + + except Exception as e: + print(f"✗ Error converting {input_path}: {e}") + return False + + +def batch_convert(input_files: List[Path], output_dir: Path, output_format: str) -> None: + """ + Convert multiple structure files to a common format. + + Args: + input_files: List of input structure files + output_dir: Directory for output files + output_format: Target format for all files + """ + output_dir.mkdir(parents=True, exist_ok=True) + + success_count = 0 + for input_file in input_files: + output_file = output_dir / f"{input_file.stem}.{output_format}" + if convert_structure(input_file, output_file): + success_count += 1 + + print(f"\n{'='*60}") + print(f"Conversion complete: {success_count}/{len(input_files)} files converted successfully") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert structure files between different formats using pymatgen", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Supported formats: + Input: CIF, POSCAR, CONTCAR, XYZ, PDB, JSON, YAML, and many more + Output: CIF, POSCAR, XYZ, PDB, JSON, YAML, XSF, and many more + +Examples: + %(prog)s POSCAR structure.cif + %(prog)s structure.cif --format json + %(prog)s *.cif --output-dir ./poscar_files --format poscar + """ + ) + + parser.add_argument( + "input", + nargs="+", + help="Input structure file(s). Supports wildcards for batch conversion." 
+ ) + + parser.add_argument( + "output", + nargs="?", + help="Output structure file (ignored if --output-dir is used)" + ) + + parser.add_argument( + "--format", "-f", + help="Output format (e.g., cif, poscar, json, yaml, xyz)" + ) + + parser.add_argument( + "--output-dir", "-o", + type=Path, + help="Output directory for batch conversion" + ) + + args = parser.parse_args() + + # Expand wildcards and convert to Path objects + input_files = [] + for pattern in args.input: + matches = list(Path.cwd().glob(pattern)) + if matches: + input_files.extend(matches) + else: + input_files.append(Path(pattern)) + + # Filter to files only + input_files = [f for f in input_files if f.is_file()] + + if not input_files: + print("Error: No input files found") + sys.exit(1) + + # Batch conversion mode + if args.output_dir or len(input_files) > 1: + if not args.format: + print("Error: --format is required for batch conversion") + sys.exit(1) + + output_dir = args.output_dir or Path("./converted") + batch_convert(input_files, output_dir, args.format) + + # Single file conversion + elif len(input_files) == 1: + input_file = input_files[0] + + if args.output: + output_file = Path(args.output) + convert_structure(input_file, output_file) + elif args.format: + convert_structure(input_file, output_format=args.format) + else: + print("Error: Must specify output file or --format") + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pymc/SKILL.md b/scientific-packages/pymc/SKILL.md new file mode 100644 index 0000000..9ccbdf4 --- /dev/null +++ b/scientific-packages/pymc/SKILL.md @@ -0,0 +1,566 @@ +--- +name: pymc-bayesian-modeling +description: Comprehensive toolkit for building, fitting, and analyzing Bayesian models using PyMC. This skill should be used when working with probabilistic programming, Bayesian statistics, MCMC sampling, hierarchical models, model comparison, or any task involving uncertainty quantification through Bayesian inference. Use for linear regression, logistic regression, hierarchical/multilevel models, time series, mixture models, model diagnostics, and posterior predictive checks. +--- + +# PyMC Bayesian Modeling + +## Overview + +PyMC is a Python library for Bayesian modeling and probabilistic programming. This skill provides comprehensive guidance for building, fitting, validating, and comparing Bayesian models using PyMC's modern API (version 5.x+). It includes workflows for common model types, diagnostic procedures, and best practices for Bayesian inference. + +## When to Use This Skill + +Use this skill when: +- Building Bayesian models (linear/logistic regression, hierarchical models, time series, etc.) +- Performing MCMC sampling or variational inference +- Conducting prior/posterior predictive checks +- Diagnosing sampling issues (divergences, convergence, ESS) +- Comparing multiple models using information criteria (LOO, WAIC) +- Implementing uncertainty quantification through Bayesian methods +- Working with hierarchical/multilevel data structures +- Handling missing data or measurement error in a principled way + +## Standard Bayesian Workflow + +Follow this workflow for building and validating Bayesian models: + +### 1. Data Preparation + +```python +import pymc as pm +import arviz as az +import numpy as np + +# Load and prepare data +X = ... # Predictors +y = ... 
# Outcomes + +# Standardize predictors for better sampling +X_mean = X.mean(axis=0) +X_std = X.std(axis=0) +X_scaled = (X - X_mean) / X_std +``` + +**Key practices:** +- Standardize continuous predictors (improves sampling efficiency) +- Center outcomes when possible +- Handle missing data explicitly (treat as parameters) +- Use named dimensions with `coords` for clarity + +### 2. Model Building + +```python +coords = { + 'predictors': ['var1', 'var2', 'var3'], + 'obs_id': np.arange(len(y)) +} + +with pm.Model(coords=coords) as model: + # Priors + alpha = pm.Normal('alpha', mu=0, sigma=1) + beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors') + sigma = pm.HalfNormal('sigma', sigma=1) + + # Linear predictor + mu = alpha + pm.math.dot(X_scaled, beta) + + # Likelihood + y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id') +``` + +**Key practices:** +- Use weakly informative priors (not flat priors) +- Use `HalfNormal` or `Exponential` for scale parameters +- Use named dimensions (`dims`) instead of `shape` when possible +- Use `pm.Data()` for values that will be updated for predictions + +### 3. Prior Predictive Check + +**Always validate priors before fitting:** + +```python +with model: + prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42) + +# Visualize +az.plot_ppc(prior_pred, group='prior') +``` + +**Check:** +- Do prior predictions span reasonable values? +- Are extreme values plausible given domain knowledge? +- If priors generate implausible data, adjust and re-check + +### 4. Fit Model + +```python +with model: + # Optional: Quick exploration with ADVI + # approx = pm.fit(n=20000) + + # Full MCMC inference + idata = pm.sample( + draws=2000, + tune=1000, + chains=4, + target_accept=0.9, + random_seed=42, + idata_kwargs={'log_likelihood': True} # For model comparison + ) +``` + +**Key parameters:** +- `draws=2000`: Number of samples per chain +- `tune=1000`: Warmup samples (discarded) +- `chains=4`: Run 4 chains for convergence checking +- `target_accept=0.9`: Higher for difficult posteriors (0.95-0.99) +- Include `log_likelihood=True` for model comparison + +### 5. Check Diagnostics + +**Use the diagnostic script:** + +```python +from scripts.model_diagnostics import check_diagnostics + +results = check_diagnostics(idata, var_names=['alpha', 'beta', 'sigma']) +``` + +**Check:** +- **R-hat < 1.01**: Chains have converged +- **ESS > 400**: Sufficient effective samples +- **No divergences**: NUTS sampled successfully +- **Trace plots**: Chains should mix well (fuzzy caterpillar) + +**If issues arise:** +- Divergences → Increase `target_accept=0.95`, use non-centered parameterization +- Low ESS → Sample more draws, reparameterize to reduce correlation +- High R-hat → Run longer, check for multimodality + +### 6. Posterior Predictive Check + +**Validate model fit:** + +```python +with model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42) + +# Visualize +az.plot_ppc(idata) +``` + +**Check:** +- Do posterior predictions capture observed data patterns? +- Are systematic deviations evident (model misspecification)? +- Consider alternative models if fit is poor + +### 7. Analyze Results + +```python +# Summary statistics +print(az.summary(idata, var_names=['alpha', 'beta', 'sigma'])) + +# Posterior distributions +az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma']) + +# Coefficient estimates +az.plot_forest(idata, var_names=['beta'], combined=True) +``` + +### 8. Make Predictions + +```python +X_new = ... 
# New predictor values +X_new_scaled = (X_new - X_mean) / X_std + +with model: + pm.set_data({'X_scaled': X_new_scaled}) + post_pred = pm.sample_posterior_predictive( + idata.posterior, + var_names=['y_obs'], + random_seed=42 + ) + +# Extract prediction intervals +y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw']) +y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs']) +``` + +## Common Model Patterns + +### Linear Regression + +For continuous outcomes with linear relationships: + +```python +with pm.Model() as linear_model: + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + sigma = pm.HalfNormal('sigma', sigma=1) + + mu = alpha + pm.math.dot(X, beta) + y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs) +``` + +**Use template:** `assets/linear_regression_template.py` + +### Logistic Regression + +For binary outcomes: + +```python +with pm.Model() as logistic_model: + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + + logit_p = alpha + pm.math.dot(X, beta) + y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs) +``` + +### Hierarchical Models + +For grouped data (use non-centered parameterization): + +```python +with pm.Model(coords={'groups': group_names}) as hierarchical_model: + # Hyperpriors + mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10) + sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1) + + # Group-level (non-centered) + alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups') + alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups') + + # Observation-level + mu = alpha[group_idx] + sigma = pm.HalfNormal('sigma', sigma=1) + y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs) +``` + +**Use template:** `assets/hierarchical_model_template.py` + +**Critical:** Always use non-centered parameterization for hierarchical models to avoid divergences. + +### Poisson Regression + +For count data: + +```python +with pm.Model() as poisson_model: + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + + log_lambda = alpha + pm.math.dot(X, beta) + y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs) +``` + +For overdispersed counts, use `NegativeBinomial` instead. 
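+
+A minimal sketch of that overdispersed variant, reusing the same `X`, `y_obs`, and `n_predictors` from the Poisson block and PyMC's `mu`/`alpha` parameterization (parameter names here differ from the Poisson example so the dispersion term does not clash with the intercept):
+
+```python
+with pm.Model() as negbin_model:
+    intercept = pm.Normal('intercept', mu=0, sigma=10)
+    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
+    dispersion = pm.HalfNormal('dispersion', sigma=1)  # larger values approach Poisson
+
+    log_mu = intercept + pm.math.dot(X, beta)
+    y = pm.NegativeBinomial('y', mu=pm.math.exp(log_mu), alpha=dispersion,
+                            observed=y_obs)
+```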
+ +### Time Series + +For autoregressive processes: + +```python +with pm.Model() as ar_model: + sigma = pm.HalfNormal('sigma', sigma=1) + rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order) + init_dist = pm.Normal.dist(mu=0, sigma=sigma) + + y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs) +``` + +## Model Comparison + +### Comparing Models + +Use LOO or WAIC for model comparison: + +```python +from scripts.model_comparison import compare_models, check_loo_reliability + +# Fit models with log_likelihood +models = { + 'Model1': idata1, + 'Model2': idata2, + 'Model3': idata3 +} + +# Compare using LOO +comparison = compare_models(models, ic='loo') + +# Check reliability +check_loo_reliability(models) +``` + +**Interpretation:** +- **Δloo < 2**: Models are similar, choose simpler model +- **2 < Δloo < 4**: Weak evidence for better model +- **4 < Δloo < 10**: Moderate evidence +- **Δloo > 10**: Strong evidence for better model + +**Check Pareto-k values:** +- k < 0.7: LOO reliable +- k > 0.7: Consider WAIC or k-fold CV + +### Model Averaging + +When models are similar, average predictions: + +```python +from scripts.model_comparison import model_averaging + +averaged_pred, weights = model_averaging(models, var_name='y_obs') +``` + +## Distribution Selection Guide + +### For Priors + +**Scale parameters** (σ, τ): +- `pm.HalfNormal('sigma', sigma=1)` - Default choice +- `pm.Exponential('sigma', lam=1)` - Alternative +- `pm.Gamma('sigma', alpha=2, beta=1)` - More informative + +**Unbounded parameters**: +- `pm.Normal('theta', mu=0, sigma=1)` - For standardized data +- `pm.StudentT('theta', nu=3, mu=0, sigma=1)` - Robust to outliers + +**Positive parameters**: +- `pm.LogNormal('theta', mu=0, sigma=1)` +- `pm.Gamma('theta', alpha=2, beta=1)` + +**Probabilities**: +- `pm.Beta('p', alpha=2, beta=2)` - Weakly informative +- `pm.Uniform('p', lower=0, upper=1)` - Non-informative (use sparingly) + +**Correlation matrices**: +- `pm.LKJCorr('corr', n=n_vars, eta=2)` - eta=1 uniform, eta>1 prefers identity + +### For Likelihoods + +**Continuous outcomes**: +- `pm.Normal('y', mu=mu, sigma=sigma)` - Default for continuous data +- `pm.StudentT('y', nu=nu, mu=mu, sigma=sigma)` - Robust to outliers + +**Count data**: +- `pm.Poisson('y', mu=lambda)` - Equidispersed counts +- `pm.NegativeBinomial('y', mu=mu, alpha=alpha)` - Overdispersed counts +- `pm.ZeroInflatedPoisson('y', psi=psi, mu=mu)` - Excess zeros + +**Binary outcomes**: +- `pm.Bernoulli('y', p=p)` or `pm.Bernoulli('y', logit_p=logit_p)` + +**Categorical outcomes**: +- `pm.Categorical('y', p=probs)` + +**See:** `references/distributions.md` for comprehensive distribution reference + +## Sampling and Inference + +### MCMC with NUTS + +Default and recommended for most models: + +```python +idata = pm.sample( + draws=2000, + tune=1000, + chains=4, + target_accept=0.9, + random_seed=42 +) +``` + +**Adjust when needed:** +- Divergences → `target_accept=0.95` or higher +- Slow sampling → Use ADVI for initialization +- Discrete parameters → Use `pm.Metropolis()` for discrete vars + +### Variational Inference + +Fast approximation for exploration or initialization: + +```python +with model: + approx = pm.fit(n=20000, method='advi') + + # Use for initialization + start = approx.sample(return_inferencedata=False)[0] + idata = pm.sample(start=start) +``` + +**Trade-offs:** +- Much faster than MCMC +- Approximate (may underestimate uncertainty) +- Good for large models or quick exploration + +**See:** `references/sampling_inference.md` 
for detailed sampling guide + +## Diagnostic Scripts + +### Comprehensive Diagnostics + +```python +from scripts.model_diagnostics import create_diagnostic_report + +create_diagnostic_report( + idata, + var_names=['alpha', 'beta', 'sigma'], + output_dir='diagnostics/' +) +``` + +Creates: +- Trace plots +- Rank plots (mixing check) +- Autocorrelation plots +- Energy plots +- ESS evolution +- Summary statistics CSV + +### Quick Diagnostic Check + +```python +from scripts.model_diagnostics import check_diagnostics + +results = check_diagnostics(idata) +``` + +Checks R-hat, ESS, divergences, and tree depth. + +## Common Issues and Solutions + +### Divergences + +**Symptom:** `idata.sample_stats.diverging.sum() > 0` + +**Solutions:** +1. Increase `target_accept=0.95` or `0.99` +2. Use non-centered parameterization (hierarchical models) +3. Add stronger priors to constrain parameters +4. Check for model misspecification + +### Low Effective Sample Size + +**Symptom:** `ESS < 400` + +**Solutions:** +1. Sample more draws: `draws=5000` +2. Reparameterize to reduce posterior correlation +3. Use QR decomposition for regression with correlated predictors + +### High R-hat + +**Symptom:** `R-hat > 1.01` + +**Solutions:** +1. Run longer chains: `tune=2000, draws=5000` +2. Check for multimodality +3. Improve initialization with ADVI + +### Slow Sampling + +**Solutions:** +1. Use ADVI initialization +2. Reduce model complexity +3. Increase parallelization: `cores=8, chains=8` +4. Use variational inference if appropriate + +## Best Practices + +### Model Building + +1. **Always standardize predictors** for better sampling +2. **Use weakly informative priors** (not flat) +3. **Use named dimensions** (`dims`) for clarity +4. **Non-centered parameterization** for hierarchical models +5. **Check prior predictive** before fitting + +### Sampling + +1. **Run multiple chains** (at least 4) for convergence +2. **Use `target_accept=0.9`** as baseline (higher if needed) +3. **Include `log_likelihood=True`** for model comparison +4. **Set random seed** for reproducibility + +### Validation + +1. **Check diagnostics** before interpretation (R-hat, ESS, divergences) +2. **Posterior predictive check** for model validation +3. **Compare multiple models** when appropriate +4. **Report uncertainty** (HDI intervals, not just point estimates) + +### Workflow + +1. Start simple, add complexity gradually +2. Prior predictive check → Fit → Diagnostics → Posterior predictive check +3. Iterate on model specification based on checks +4. Document assumptions and prior choices + +## Resources + +This skill includes: + +### References (`references/`) + +- **`distributions.md`**: Comprehensive catalog of PyMC distributions organized by category (continuous, discrete, multivariate, mixture, time series). Use when selecting priors or likelihoods. + +- **`sampling_inference.md`**: Detailed guide to sampling algorithms (NUTS, Metropolis, SMC), variational inference (ADVI, SVGD), and handling sampling issues. Use when encountering convergence problems or choosing inference methods. + +- **`workflows.md`**: Complete workflow examples and code patterns for common model types, data preparation, prior selection, and model validation. Use as a cookbook for standard Bayesian analyses. + +### Scripts (`scripts/`) + +- **`model_diagnostics.py`**: Automated diagnostic checking and report generation. Functions: `check_diagnostics()` for quick checks, `create_diagnostic_report()` for comprehensive analysis with plots. 
+ +- **`model_comparison.py`**: Model comparison utilities using LOO/WAIC. Functions: `compare_models()`, `check_loo_reliability()`, `model_averaging()`. + +### Templates (`assets/`) + +- **`linear_regression_template.py`**: Complete template for Bayesian linear regression with full workflow (data prep, prior checks, fitting, diagnostics, predictions). + +- **`hierarchical_model_template.py`**: Complete template for hierarchical/multilevel models with non-centered parameterization and group-level analysis. + +## Quick Reference + +### Model Building +```python +with pm.Model(coords={'var': names}) as model: + # Priors + param = pm.Normal('param', mu=0, sigma=1, dims='var') + # Likelihood + y = pm.Normal('y', mu=..., sigma=..., observed=data) +``` + +### Sampling +```python +idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9) +``` + +### Diagnostics +```python +from scripts.model_diagnostics import check_diagnostics +check_diagnostics(idata) +``` + +### Model Comparison +```python +from scripts.model_comparison import compare_models +compare_models({'m1': idata1, 'm2': idata2}, ic='loo') +``` + +### Predictions +```python +with model: + pm.set_data({'X': X_new}) + pred = pm.sample_posterior_predictive(idata.posterior) +``` + +## Additional Notes + +- PyMC integrates with ArviZ for visualization and diagnostics +- Use `pm.model_to_graphviz(model)` to visualize model structure +- Save results with `idata.to_netcdf('results.nc')` +- Load with `az.from_netcdf('results.nc')` +- For very large models, consider minibatch ADVI or data subsampling diff --git a/scientific-packages/pymc/assets/hierarchical_model_template.py b/scientific-packages/pymc/assets/hierarchical_model_template.py new file mode 100644 index 0000000..d6215de --- /dev/null +++ b/scientific-packages/pymc/assets/hierarchical_model_template.py @@ -0,0 +1,333 @@ +""" +PyMC Hierarchical/Multilevel Model Template + +This template provides a complete workflow for Bayesian hierarchical models, +useful for grouped/nested data (e.g., students within schools, patients within hospitals). + +Customize the sections marked with # TODO +""" + +import pymc as pm +import arviz as az +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +# ============================================================================= +# 1. DATA PREPARATION +# ============================================================================= + +# TODO: Load your data with group structure +# Example: +# df = pd.read_csv('data.csv') +# groups = df['group_id'].values +# X = df['predictor'].values +# y = df['outcome'].values + +# For demonstration: Generate hierarchical data +np.random.seed(42) +n_groups = 10 +n_per_group = 20 +n_obs = n_groups * n_per_group + +# True hierarchical structure +true_mu_alpha = 5.0 +true_sigma_alpha = 2.0 +true_mu_beta = 1.5 +true_sigma_beta = 0.5 +true_sigma = 1.0 + +group_alphas = np.random.normal(true_mu_alpha, true_sigma_alpha, n_groups) +group_betas = np.random.normal(true_mu_beta, true_sigma_beta, n_groups) + +# Generate data +groups = np.repeat(np.arange(n_groups), n_per_group) +X = np.random.randn(n_obs) +y = group_alphas[groups] + group_betas[groups] * X + np.random.randn(n_obs) * true_sigma + +# TODO: Customize group names +group_names = [f'Group_{i}' for i in range(n_groups)] + +# ============================================================================= +# 2. 
BUILD HIERARCHICAL MODEL +# ============================================================================= + +print("Building hierarchical model...") + +coords = { + 'groups': group_names, + 'obs': np.arange(n_obs) +} + +with pm.Model(coords=coords) as hierarchical_model: + # Data containers (for later predictions) + X_data = pm.Data('X_data', X) + groups_data = pm.Data('groups_data', groups) + + # Hyperpriors (population-level parameters) + # TODO: Adjust hyperpriors based on your domain knowledge + mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10) + sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=5) + + mu_beta = pm.Normal('mu_beta', mu=0, sigma=10) + sigma_beta = pm.HalfNormal('sigma_beta', sigma=5) + + # Group-level parameters (non-centered parameterization) + # Non-centered parameterization improves sampling efficiency + alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups') + alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups') + + beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='groups') + beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='groups') + + # Observation-level model + mu = alpha[groups_data] + beta[groups_data] * X_data + + # Observation noise + sigma = pm.HalfNormal('sigma', sigma=5) + + # Likelihood + y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs') + +print("Model built successfully!") +print(f"Groups: {n_groups}") +print(f"Observations: {n_obs}") + +# ============================================================================= +# 3. PRIOR PREDICTIVE CHECK +# ============================================================================= + +print("\nRunning prior predictive check...") +with hierarchical_model: + prior_pred = pm.sample_prior_predictive(samples=500, random_seed=42) + +# Visualize prior predictions +fig, ax = plt.subplots(figsize=(10, 6)) +az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax) +ax.set_title('Prior Predictive Check') +plt.tight_layout() +plt.savefig('hierarchical_prior_check.png', dpi=300, bbox_inches='tight') +print("Prior predictive check saved to 'hierarchical_prior_check.png'") + +# ============================================================================= +# 4. FIT MODEL +# ============================================================================= + +print("\nFitting hierarchical model...") +print("(This may take a few minutes due to model complexity)") + +with hierarchical_model: + # MCMC sampling with higher target_accept for hierarchical models + idata = pm.sample( + draws=2000, + tune=2000, # More tuning for hierarchical models + chains=4, + target_accept=0.95, # Higher for better convergence + random_seed=42, + idata_kwargs={'log_likelihood': True} + ) + +print("Sampling complete!") + +# ============================================================================= +# 5. 
CHECK DIAGNOSTICS +# ============================================================================= + +print("\n" + "="*60) +print("DIAGNOSTICS") +print("="*60) + +# Summary for key parameters +summary = az.summary( + idata, + var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma', 'alpha', 'beta'] +) +print("\nParameter Summary:") +print(summary) + +# Check convergence +bad_rhat = summary[summary['r_hat'] > 1.01] +if len(bad_rhat) > 0: + print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01") + print(bad_rhat[['r_hat']]) +else: + print("\n✓ All R-hat values < 1.01 (good convergence)") + +# Check effective sample size +low_ess = summary[summary['ess_bulk'] < 400] +if len(low_ess) > 0: + print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400") + print(low_ess[['ess_bulk']].head(10)) +else: + print("\n✓ All ESS values > 400 (sufficient samples)") + +# Check divergences +divergences = idata.sample_stats.diverging.sum().item() +if divergences > 0: + print(f"\n⚠️ WARNING: {divergences} divergent transitions") + print(" This is common in hierarchical models - non-centered parameterization already applied") + print(" Consider even higher target_accept or stronger hyperpriors") +else: + print("\n✓ No divergences") + +# Trace plots for hyperparameters +fig, axes = plt.subplots(5, 2, figsize=(12, 12)) +az.plot_trace( + idata, + var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma'], + axes=axes +) +plt.tight_layout() +plt.savefig('hierarchical_trace_plots.png', dpi=300, bbox_inches='tight') +print("\nTrace plots saved to 'hierarchical_trace_plots.png'") + +# ============================================================================= +# 6. POSTERIOR PREDICTIVE CHECK +# ============================================================================= + +print("\nRunning posterior predictive check...") +with hierarchical_model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42) + +# Visualize fit +fig, ax = plt.subplots(figsize=(10, 6)) +az.plot_ppc(idata, num_pp_samples=100, ax=ax) +ax.set_title('Posterior Predictive Check') +plt.tight_layout() +plt.savefig('hierarchical_posterior_check.png', dpi=300, bbox_inches='tight') +print("Posterior predictive check saved to 'hierarchical_posterior_check.png'") + +# ============================================================================= +# 7. 
ANALYZE HIERARCHICAL STRUCTURE +# ============================================================================= + +print("\n" + "="*60) +print("POPULATION-LEVEL (HYPERPARAMETER) ESTIMATES") +print("="*60) + +# Population-level estimates +hyper_summary = summary.loc[['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta', 'sigma']] +print(hyper_summary[['mean', 'sd', 'hdi_3%', 'hdi_97%']]) + +# Forest plot for group-level parameters +fig, axes = plt.subplots(1, 2, figsize=(14, 8)) + +# Group intercepts +az.plot_forest(idata, var_names=['alpha'], combined=True, ax=axes[0]) +axes[0].set_title('Group-Level Intercepts (α)') +axes[0].set_yticklabels(group_names) +axes[0].axvline(idata.posterior['mu_alpha'].mean().item(), color='red', linestyle='--', label='Population mean') +axes[0].legend() + +# Group slopes +az.plot_forest(idata, var_names=['beta'], combined=True, ax=axes[1]) +axes[1].set_title('Group-Level Slopes (β)') +axes[1].set_yticklabels(group_names) +axes[1].axvline(idata.posterior['mu_beta'].mean().item(), color='red', linestyle='--', label='Population mean') +axes[1].legend() + +plt.tight_layout() +plt.savefig('group_level_estimates.png', dpi=300, bbox_inches='tight') +print("\nGroup-level estimates saved to 'group_level_estimates.png'") + +# Shrinkage visualization +fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + +# Intercepts +alpha_samples = idata.posterior['alpha'].values.reshape(-1, n_groups) +alpha_means = alpha_samples.mean(axis=0) +mu_alpha_mean = idata.posterior['mu_alpha'].mean().item() + +axes[0].scatter(range(n_groups), alpha_means, alpha=0.6) +axes[0].axhline(mu_alpha_mean, color='red', linestyle='--', label='Population mean') +axes[0].set_xlabel('Group') +axes[0].set_ylabel('Intercept') +axes[0].set_title('Group Intercepts (showing shrinkage to population mean)') +axes[0].legend() + +# Slopes +beta_samples = idata.posterior['beta'].values.reshape(-1, n_groups) +beta_means = beta_samples.mean(axis=0) +mu_beta_mean = idata.posterior['mu_beta'].mean().item() + +axes[1].scatter(range(n_groups), beta_means, alpha=0.6) +axes[1].axhline(mu_beta_mean, color='red', linestyle='--', label='Population mean') +axes[1].set_xlabel('Group') +axes[1].set_ylabel('Slope') +axes[1].set_title('Group Slopes (showing shrinkage to population mean)') +axes[1].legend() + +plt.tight_layout() +plt.savefig('shrinkage_plot.png', dpi=300, bbox_inches='tight') +print("Shrinkage plot saved to 'shrinkage_plot.png'") + +# ============================================================================= +# 8. 
PREDICTIONS FOR NEW DATA +# ============================================================================= + +# TODO: Specify new data +# For existing groups: +# new_X = np.array([...]) +# new_groups = np.array([0, 1, 2, ...]) # Existing group indices + +# For a new group (predict using population-level parameters): +# Just use mu_alpha and mu_beta + +print("\n" + "="*60) +print("PREDICTIONS FOR NEW DATA") +print("="*60) + +# Example: Predict for existing groups +new_X = np.array([-2, -1, 0, 1, 2]) +new_groups = np.array([0, 2, 4, 6, 8]) # Select some groups + +with hierarchical_model: + pm.set_data({'X_data': new_X, 'groups_data': new_groups, 'obs': np.arange(len(new_X))}) + + post_pred = pm.sample_posterior_predictive( + idata.posterior, + var_names=['y_obs'], + random_seed=42 + ) + +y_pred_samples = post_pred.posterior_predictive['y_obs'] +y_pred_mean = y_pred_samples.mean(dim=['chain', 'draw']).values +y_pred_hdi = az.hdi(y_pred_samples, hdi_prob=0.95).values + +print(f"Predictions for existing groups:") +print(f"{'Group':<10} {'X':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}") +print("-"*65) +for i, g in enumerate(new_groups): + print(f"{group_names[g]:<10} {new_X[i]:<10.2f} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}") + +# Predict for a new group (using population parameters) +print(f"\nPrediction for a NEW group (using population-level parameters):") +new_X_newgroup = np.array([0.0]) + +# Manually compute using population parameters +mu_alpha_samples = idata.posterior['mu_alpha'].values.flatten() +mu_beta_samples = idata.posterior['mu_beta'].values.flatten() +sigma_samples = idata.posterior['sigma'].values.flatten() + +# Predicted mean for new group +y_pred_newgroup = mu_alpha_samples + mu_beta_samples * new_X_newgroup[0] +y_pred_mean_newgroup = y_pred_newgroup.mean() +y_pred_hdi_newgroup = az.hdi(y_pred_newgroup, hdi_prob=0.95) + +print(f"X = {new_X_newgroup[0]:.2f}") +print(f"Predicted mean: {y_pred_mean_newgroup:.3f}") +print(f"95% HDI: [{y_pred_hdi_newgroup[0]:.3f}, {y_pred_hdi_newgroup[1]:.3f}]") + +# ============================================================================= +# 9. SAVE RESULTS +# ============================================================================= + +idata.to_netcdf('hierarchical_model_results.nc') +print("\nResults saved to 'hierarchical_model_results.nc'") + +summary.to_csv('hierarchical_model_summary.csv') +print("Summary saved to 'hierarchical_model_summary.csv'") + +print("\n" + "="*60) +print("ANALYSIS COMPLETE") +print("="*60) diff --git a/scientific-packages/pymc/assets/linear_regression_template.py b/scientific-packages/pymc/assets/linear_regression_template.py new file mode 100644 index 0000000..63d47e8 --- /dev/null +++ b/scientific-packages/pymc/assets/linear_regression_template.py @@ -0,0 +1,241 @@ +""" +PyMC Linear Regression Template + +This template provides a complete workflow for Bayesian linear regression, +including data preparation, model building, diagnostics, and predictions. + +Customize the sections marked with # TODO +""" + +import pymc as pm +import arviz as az +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +# ============================================================================= +# 1. 
DATA PREPARATION +# ============================================================================= + +# TODO: Load your data +# Example: +# df = pd.read_csv('data.csv') +# X = df[['predictor1', 'predictor2', 'predictor3']].values +# y = df['outcome'].values + +# For demonstration: +np.random.seed(42) +n_samples = 100 +n_predictors = 3 + +X = np.random.randn(n_samples, n_predictors) +true_beta = np.array([1.5, -0.8, 2.1]) +true_alpha = 0.5 +y = true_alpha + X @ true_beta + np.random.randn(n_samples) * 0.5 + +# Standardize predictors for better sampling +X_mean = X.mean(axis=0) +X_std = X.std(axis=0) +X_scaled = (X - X_mean) / X_std + +# ============================================================================= +# 2. BUILD MODEL +# ============================================================================= + +# TODO: Customize predictor names +predictor_names = ['predictor1', 'predictor2', 'predictor3'] + +coords = { + 'predictors': predictor_names, + 'obs_id': np.arange(len(y)) +} + +with pm.Model(coords=coords) as linear_model: + # Priors + # TODO: Adjust prior parameters based on your domain knowledge + alpha = pm.Normal('alpha', mu=0, sigma=1) + beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors') + sigma = pm.HalfNormal('sigma', sigma=1) + + # Linear predictor + mu = alpha + pm.math.dot(X_scaled, beta) + + # Likelihood + y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id') + +# ============================================================================= +# 3. PRIOR PREDICTIVE CHECK +# ============================================================================= + +print("Running prior predictive check...") +with linear_model: + prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42) + +# Visualize prior predictions +fig, ax = plt.subplots(figsize=(10, 6)) +az.plot_ppc(prior_pred, group='prior', num_pp_samples=100, ax=ax) +ax.set_title('Prior Predictive Check') +plt.tight_layout() +plt.savefig('prior_predictive_check.png', dpi=300, bbox_inches='tight') +print("Prior predictive check saved to 'prior_predictive_check.png'") + +# ============================================================================= +# 4. FIT MODEL +# ============================================================================= + +print("\nFitting model...") +with linear_model: + # Optional: Quick ADVI exploration + # approx = pm.fit(n=20000, random_seed=42) + + # MCMC sampling + idata = pm.sample( + draws=2000, + tune=1000, + chains=4, + target_accept=0.9, + random_seed=42, + idata_kwargs={'log_likelihood': True} + ) + +print("Sampling complete!") + +# ============================================================================= +# 5. 
CHECK DIAGNOSTICS +# ============================================================================= + +print("\n" + "="*60) +print("DIAGNOSTICS") +print("="*60) + +# Summary statistics +summary = az.summary(idata, var_names=['alpha', 'beta', 'sigma']) +print("\nParameter Summary:") +print(summary) + +# Check convergence +bad_rhat = summary[summary['r_hat'] > 1.01] +if len(bad_rhat) > 0: + print(f"\n⚠️ WARNING: {len(bad_rhat)} parameters with R-hat > 1.01") + print(bad_rhat[['r_hat']]) +else: + print("\n✓ All R-hat values < 1.01 (good convergence)") + +# Check effective sample size +low_ess = summary[summary['ess_bulk'] < 400] +if len(low_ess) > 0: + print(f"\n⚠️ WARNING: {len(low_ess)} parameters with ESS < 400") + print(low_ess[['ess_bulk', 'ess_tail']]) +else: + print("\n✓ All ESS values > 400 (sufficient samples)") + +# Check divergences +divergences = idata.sample_stats.diverging.sum().item() +if divergences > 0: + print(f"\n⚠️ WARNING: {divergences} divergent transitions") + print(" Consider increasing target_accept or reparameterizing") +else: + print("\n✓ No divergences") + +# Trace plots +fig, axes = plt.subplots(len(['alpha', 'beta', 'sigma']), 2, figsize=(12, 8)) +az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma'], axes=axes) +plt.tight_layout() +plt.savefig('trace_plots.png', dpi=300, bbox_inches='tight') +print("\nTrace plots saved to 'trace_plots.png'") + +# ============================================================================= +# 6. POSTERIOR PREDICTIVE CHECK +# ============================================================================= + +print("\nRunning posterior predictive check...") +with linear_model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42) + +# Visualize fit +fig, ax = plt.subplots(figsize=(10, 6)) +az.plot_ppc(idata, num_pp_samples=100, ax=ax) +ax.set_title('Posterior Predictive Check') +plt.tight_layout() +plt.savefig('posterior_predictive_check.png', dpi=300, bbox_inches='tight') +print("Posterior predictive check saved to 'posterior_predictive_check.png'") + +# ============================================================================= +# 7. ANALYZE RESULTS +# ============================================================================= + +# Posterior distributions +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'], ax=axes) +plt.tight_layout() +plt.savefig('posterior_distributions.png', dpi=300, bbox_inches='tight') +print("Posterior distributions saved to 'posterior_distributions.png'") + +# Forest plot for coefficients +fig, ax = plt.subplots(figsize=(8, 6)) +az.plot_forest(idata, var_names=['beta'], combined=True, ax=ax) +ax.set_title('Coefficient Estimates (95% HDI)') +ax.set_yticklabels(predictor_names) +plt.tight_layout() +plt.savefig('coefficient_forest_plot.png', dpi=300, bbox_inches='tight') +print("Forest plot saved to 'coefficient_forest_plot.png'") + +# Print coefficient estimates +print("\n" + "="*60) +print("COEFFICIENT ESTIMATES") +print("="*60) +beta_samples = idata.posterior['beta'] +for i, name in enumerate(predictor_names): + mean = beta_samples.sel(predictors=name).mean().item() + hdi = az.hdi(beta_samples.sel(predictors=name), hdi_prob=0.95) + print(f"{name:20s}: {mean:7.3f} [95% HDI: {hdi.values[0]:7.3f}, {hdi.values[1]:7.3f}]") + +# ============================================================================= +# 8. 
PREDICTIONS FOR NEW DATA
+# =============================================================================
+
+# TODO: Provide new data for predictions
+# X_new = np.array([[...], [...], ...])  # New predictor values
+
+# For demonstration, use some test data
+X_new = np.random.randn(10, n_predictors)
+X_new_scaled = (X_new - X_mean) / X_std
+
+# The model above uses X_scaled as a fixed NumPy array (there is no pm.Data
+# container named 'X_scaled'), so pm.set_data cannot be used here. Instead,
+# compute predictions directly from the posterior draws of alpha, beta, sigma.
+alpha_post = idata.posterior['alpha'].values.flatten()                # (n_draws,)
+beta_post = idata.posterior['beta'].values.reshape(-1, n_predictors)  # (n_draws, n_predictors)
+sigma_post = idata.posterior['sigma'].values.flatten()                # (n_draws,)
+
+mu_pred = alpha_post[:, None] + beta_post @ X_new_scaled.T            # (n_draws, n_new)
+rng = np.random.default_rng(42)
+y_pred_samples = rng.normal(mu_pred, sigma_post[:, None])             # posterior predictive draws
+
+# Extract predictions
+y_pred_mean = y_pred_samples.mean(axis=0)
+y_pred_hdi = az.hdi(y_pred_samples[np.newaxis, :, :], hdi_prob=0.95)  # shape (n_new, 2)
+
+print("\n" + "="*60)
+print("PREDICTIONS FOR NEW DATA")
+print("="*60)
+print(f"{'Index':<10} {'Mean':<15} {'95% HDI Lower':<15} {'95% HDI Upper':<15}")
+print("-"*60)
+for i in range(len(X_new)):
+    print(f"{i:<10} {y_pred_mean[i]:<15.3f} {y_pred_hdi[i, 0]:<15.3f} {y_pred_hdi[i, 1]:<15.3f}")
+
+# =============================================================================
+# 9. SAVE RESULTS
+# =============================================================================
+
+# Save InferenceData
+idata.to_netcdf('linear_regression_results.nc')
+print("\nResults saved to 'linear_regression_results.nc'")
+
+# Save summary to CSV
+summary.to_csv('model_summary.csv')
+print("Summary saved to 'model_summary.csv'")
+
+print("\n" + "="*60)
+print("ANALYSIS COMPLETE")
+print("="*60)
diff --git a/scientific-packages/pymc/references/distributions.md b/scientific-packages/pymc/references/distributions.md
new file mode 100644
index 0000000..2d9e314
--- /dev/null
+++ b/scientific-packages/pymc/references/distributions.md
@@ -0,0 +1,320 @@
+# PyMC Distributions Reference
+
+This reference provides a comprehensive catalog of probability distributions available in PyMC, organized by category. Use this to select appropriate distributions for priors and likelihoods when building Bayesian models.
+
+## Continuous Distributions
+
+Continuous distributions define probability densities over real-valued domains.
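+
+As a quick orientation, a minimal sketch (hypothetical parameter names, standardized data assumed) showing how a few of the distributions listed below are typically used as priors:
+
+```python
+import pymc as pm
+
+with pm.Model() as example_model:
+    effect = pm.Normal('effect', mu=0, sigma=1)   # unbounded location/effect parameter
+    scale = pm.HalfNormal('scale', sigma=1)       # positive scale parameter
+    rate = pm.Gamma('rate', alpha=2, beta=1)      # positive, more informative
+    frac = pm.Beta('frac', alpha=2, beta=2)       # proportion constrained to [0, 1]
+```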
+ +### Common Continuous Distributions + +**`pm.Normal(name, mu, sigma)`** +- Normal (Gaussian) distribution +- Parameters: `mu` (mean), `sigma` (standard deviation) +- Support: (-∞, ∞) +- Common uses: Default prior for unbounded parameters, likelihood for continuous data with additive noise + +**`pm.HalfNormal(name, sigma)`** +- Half-normal distribution (positive half of normal) +- Parameters: `sigma` (standard deviation) +- Support: [0, ∞) +- Common uses: Prior for scale/standard deviation parameters + +**`pm.Uniform(name, lower, upper)`** +- Uniform distribution +- Parameters: `lower`, `upper` (bounds) +- Support: [lower, upper] +- Common uses: Weakly informative prior when parameter must be bounded + +**`pm.Beta(name, alpha, beta)`** +- Beta distribution +- Parameters: `alpha`, `beta` (shape parameters) +- Support: [0, 1] +- Common uses: Prior for probabilities and proportions + +**`pm.Gamma(name, alpha, beta)`** +- Gamma distribution +- Parameters: `alpha` (shape), `beta` (rate) +- Support: (0, ∞) +- Common uses: Prior for positive parameters, rate parameters + +**`pm.Exponential(name, lam)`** +- Exponential distribution +- Parameters: `lam` (rate parameter) +- Support: [0, ∞) +- Common uses: Prior for scale parameters, waiting times + +**`pm.LogNormal(name, mu, sigma)`** +- Log-normal distribution +- Parameters: `mu`, `sigma` (parameters of underlying normal) +- Support: (0, ∞) +- Common uses: Prior for positive parameters with multiplicative effects + +**`pm.StudentT(name, nu, mu, sigma)`** +- Student's t-distribution +- Parameters: `nu` (degrees of freedom), `mu` (location), `sigma` (scale) +- Support: (-∞, ∞) +- Common uses: Robust alternative to normal for outlier-resistant models + +**`pm.Cauchy(name, alpha, beta)`** +- Cauchy distribution +- Parameters: `alpha` (location), `beta` (scale) +- Support: (-∞, ∞) +- Common uses: Heavy-tailed alternative to normal + +### Specialized Continuous Distributions + +**`pm.Laplace(name, mu, b)`** - Laplace (double exponential) distribution + +**`pm.AsymmetricLaplace(name, kappa, mu, b)`** - Asymmetric Laplace distribution + +**`pm.InverseGamma(name, alpha, beta)`** - Inverse gamma distribution + +**`pm.Weibull(name, alpha, beta)`** - Weibull distribution for reliability analysis + +**`pm.Logistic(name, mu, s)`** - Logistic distribution + +**`pm.LogitNormal(name, mu, sigma)`** - Logit-normal distribution for (0,1) support + +**`pm.Pareto(name, alpha, m)`** - Pareto distribution for power-law phenomena + +**`pm.ChiSquared(name, nu)`** - Chi-squared distribution + +**`pm.ExGaussian(name, mu, sigma, nu)`** - Exponentially modified Gaussian + +**`pm.VonMises(name, mu, kappa)`** - Von Mises (circular normal) distribution + +**`pm.SkewNormal(name, mu, sigma, alpha)`** - Skew-normal distribution + +**`pm.Triangular(name, lower, c, upper)`** - Triangular distribution + +**`pm.Gumbel(name, mu, beta)`** - Gumbel distribution for extreme values + +**`pm.Rice(name, nu, sigma)`** - Rice (Rician) distribution + +**`pm.Moyal(name, mu, sigma)`** - Moyal distribution + +**`pm.Kumaraswamy(name, a, b)`** - Kumaraswamy distribution (Beta alternative) + +**`pm.Interpolated(name, x_points, pdf_points)`** - Custom distribution from interpolation + +## Discrete Distributions + +Discrete distributions define probabilities over integer-valued domains. 
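+
+For orientation, a minimal sketch (with hypothetical `successes` and `trials` arrays) pairing a Beta prior with a Binomial likelihood for success counts:
+
+```python
+import numpy as np
+import pymc as pm
+
+successes = np.array([7, 12, 9, 15])   # hypothetical observed successes
+trials = np.array([20, 25, 18, 30])    # hypothetical trials per experiment
+
+with pm.Model() as binomial_model:
+    p = pm.Beta('p', alpha=2, beta=2)   # prior on the shared success probability
+    k = pm.Binomial('k', n=trials, p=p, observed=successes)
+```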
+ +### Common Discrete Distributions + +**`pm.Bernoulli(name, p)`** +- Bernoulli distribution (binary outcome) +- Parameters: `p` (success probability) +- Support: {0, 1} +- Common uses: Binary classification, coin flips + +**`pm.Binomial(name, n, p)`** +- Binomial distribution +- Parameters: `n` (number of trials), `p` (success probability) +- Support: {0, 1, ..., n} +- Common uses: Number of successes in fixed trials + +**`pm.Poisson(name, mu)`** +- Poisson distribution +- Parameters: `mu` (rate parameter) +- Support: {0, 1, 2, ...} +- Common uses: Count data, rates, occurrences + +**`pm.Categorical(name, p)`** +- Categorical distribution +- Parameters: `p` (probability vector) +- Support: {0, 1, ..., K-1} +- Common uses: Multi-class classification + +**`pm.DiscreteUniform(name, lower, upper)`** +- Discrete uniform distribution +- Parameters: `lower`, `upper` (bounds) +- Support: {lower, ..., upper} +- Common uses: Uniform prior over finite integers + +**`pm.NegativeBinomial(name, mu, alpha)`** +- Negative binomial distribution +- Parameters: `mu` (mean), `alpha` (dispersion) +- Support: {0, 1, 2, ...} +- Common uses: Overdispersed count data + +**`pm.Geometric(name, p)`** +- Geometric distribution +- Parameters: `p` (success probability) +- Support: {0, 1, 2, ...} +- Common uses: Number of failures before first success + +### Specialized Discrete Distributions + +**`pm.BetaBinomial(name, alpha, beta, n)`** - Beta-binomial (overdispersed binomial) + +**`pm.HyperGeometric(name, N, k, n)`** - Hypergeometric distribution + +**`pm.DiscreteWeibull(name, q, beta)`** - Discrete Weibull distribution + +**`pm.OrderedLogistic(name, eta, cutpoints)`** - Ordered logistic for ordinal data + +**`pm.OrderedProbit(name, eta, cutpoints)`** - Ordered probit for ordinal data + +## Multivariate Distributions + +Multivariate distributions define joint probability distributions over vector-valued random variables. 
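+
+A common pattern is a multivariate normal likelihood with an LKJ prior on its covariance; the sketch below uses a hypothetical two-column `data` array and assumes a recent PyMC where `compute_corr=True` makes `LKJCholeskyCov` return the Cholesky factor, correlations, and standard deviations:
+
+```python
+import pymc as pm
+
+with pm.Model() as mv_model:
+    mu = pm.Normal('mu', mu=0, sigma=1, shape=2)
+    chol, corr, stds = pm.LKJCholeskyCov(
+        'chol_cov', n=2, eta=2.0,
+        sd_dist=pm.Exponential.dist(1.0), compute_corr=True
+    )
+    obs = pm.MvNormal('obs', mu=mu, chol=chol, observed=data)
+```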
+ +### Common Multivariate Distributions + +**`pm.MvNormal(name, mu, cov)`** +- Multivariate normal distribution +- Parameters: `mu` (mean vector), `cov` (covariance matrix) +- Common uses: Correlated continuous variables, Gaussian processes + +**`pm.Dirichlet(name, a)`** +- Dirichlet distribution +- Parameters: `a` (concentration parameters) +- Support: Simplex (sums to 1) +- Common uses: Prior for probability vectors, topic modeling + +**`pm.Multinomial(name, n, p)`** +- Multinomial distribution +- Parameters: `n` (number of trials), `p` (probability vector) +- Common uses: Count data across multiple categories + +**`pm.MvStudentT(name, nu, mu, cov)`** +- Multivariate Student's t-distribution +- Parameters: `nu` (degrees of freedom), `mu` (location), `cov` (scale matrix) +- Common uses: Robust multivariate modeling + +### Specialized Multivariate Distributions + +**`pm.LKJCorr(name, n, eta)`** - LKJ correlation matrix prior (for correlation matrices) + +**`pm.LKJCholeskyCov(name, n, eta, sd_dist)`** - LKJ prior with Cholesky decomposition + +**`pm.Wishart(name, nu, V)`** - Wishart distribution (for covariance matrices) + +**`pm.InverseWishart(name, nu, V)`** - Inverse Wishart distribution + +**`pm.MatrixNormal(name, mu, rowcov, colcov)`** - Matrix normal distribution + +**`pm.KroneckerNormal(name, mu, covs, sigma)`** - Kronecker-structured normal + +**`pm.CAR(name, mu, W, alpha, tau)`** - Conditional autoregressive (spatial) + +**`pm.ICAR(name, W, sigma)`** - Intrinsic conditional autoregressive (spatial) + +## Mixture Distributions + +Mixture distributions combine multiple component distributions. + +**`pm.Mixture(name, w, comp_dists)`** +- General mixture distribution +- Parameters: `w` (weights), `comp_dists` (component distributions) +- Common uses: Clustering, multi-modal data + +**`pm.NormalMixture(name, w, mu, sigma)`** +- Mixture of normal distributions +- Common uses: Mixture of Gaussians clustering + +### Zero-Inflated and Hurdle Models + +**`pm.ZeroInflatedPoisson(name, psi, mu)`** - Excess zeros in count data + +**`pm.ZeroInflatedBinomial(name, psi, n, p)`** - Zero-inflated binomial + +**`pm.ZeroInflatedNegativeBinomial(name, psi, mu, alpha)`** - Zero-inflated negative binomial + +**`pm.HurdlePoisson(name, psi, mu)`** - Hurdle Poisson (two-part model) + +**`pm.HurdleGamma(name, psi, alpha, beta)`** - Hurdle gamma + +**`pm.HurdleLogNormal(name, psi, mu, sigma)`** - Hurdle log-normal + +## Time Series Distributions + +Distributions designed for temporal data and sequential modeling. 
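+
+For example, a latent random-walk level observed with noise (a sketch with a hypothetical `y_series` of length `n_timesteps`; the exact `shape`/`steps` argument may differ across PyMC versions):
+
+```python
+import pymc as pm
+
+with pm.Model() as rw_model:
+    sigma_rw = pm.HalfNormal('sigma_rw', sigma=1)
+    level = pm.GaussianRandomWalk(
+        'level', sigma=sigma_rw,
+        init_dist=pm.Normal.dist(0, 1), shape=n_timesteps
+    )
+    sigma_obs = pm.HalfNormal('sigma_obs', sigma=1)
+    y = pm.Normal('y', mu=level, sigma=sigma_obs, observed=y_series)
+```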
+ +**`pm.AR(name, rho, sigma, init_dist)`** +- Autoregressive process +- Parameters: `rho` (AR coefficients), `sigma` (innovation std), `init_dist` (initial distribution) +- Common uses: Time series modeling, sequential data + +**`pm.GaussianRandomWalk(name, mu, sigma, init_dist)`** +- Gaussian random walk +- Parameters: `mu` (drift), `sigma` (step size), `init_dist` (initial value) +- Common uses: Cumulative processes, random walk priors + +**`pm.MvGaussianRandomWalk(name, mu, cov, init_dist)`** +- Multivariate Gaussian random walk + +**`pm.GARCH11(name, omega, alpha_1, beta_1)`** +- GARCH(1,1) volatility model +- Common uses: Financial time series, volatility modeling + +**`pm.EulerMaruyama(name, dt, sde_fn, sde_pars, init_dist)`** +- Stochastic differential equation via Euler-Maruyama discretization +- Common uses: Continuous-time processes + +## Special Distributions + +**`pm.Deterministic(name, var)`** +- Deterministic transformation (not a random variable) +- Use for computed quantities derived from other variables + +**`pm.Potential(name, logp)`** +- Add arbitrary log-probability contribution +- Use for custom likelihood components or constraints + +**`pm.Flat(name)`** +- Improper flat prior (constant density) +- Use sparingly; can cause sampling issues + +**`pm.HalfFlat(name)`** +- Improper flat prior on positive reals +- Use sparingly; can cause sampling issues + +## Distribution Modifiers + +**`pm.Truncated(name, dist, lower, upper)`** +- Truncate any distribution to specified bounds + +**`pm.Censored(name, dist, lower, upper)`** +- Handle censored observations (observed bounds, not exact values) + +**`pm.CustomDist(name, ..., logp, random)`** +- Define custom distributions with user-specified log-probability and random sampling functions + +**`pm.Simulator(name, fn, params, ...)`** +- Custom distributions via simulation (for likelihood-free inference) + +## Usage Tips + +### Choosing Priors + +1. **Scale parameters** (σ, τ): Use `HalfNormal`, `HalfCauchy`, `Exponential`, or `Gamma` +2. **Probabilities**: Use `Beta` or `Uniform(0, 1)` +3. **Unbounded parameters**: Use `Normal` or `StudentT` (for robustness) +4. **Positive parameters**: Use `LogNormal`, `Gamma`, or `Exponential` +5. **Correlation matrices**: Use `LKJCorr` +6. **Count data**: Use `Poisson` or `NegativeBinomial` (for overdispersion) + +### Shape Broadcasting + +PyMC distributions support NumPy-style broadcasting. Use the `shape` parameter to create vectors or arrays of random variables: + +```python +# Vector of 5 independent normals +beta = pm.Normal('beta', mu=0, sigma=1, shape=5) + +# 3x4 matrix of independent gammas +tau = pm.Gamma('tau', alpha=2, beta=1, shape=(3, 4)) +``` + +### Using dims for Named Dimensions + +Instead of shape, use `dims` for more readable models: + +```python +with pm.Model(coords={'predictors': ['age', 'income', 'education']}) as model: + beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors') +``` diff --git a/scientific-packages/pymc/references/sampling_inference.md b/scientific-packages/pymc/references/sampling_inference.md new file mode 100644 index 0000000..53d6102 --- /dev/null +++ b/scientific-packages/pymc/references/sampling_inference.md @@ -0,0 +1,424 @@ +# PyMC Sampling and Inference Methods + +This reference covers the sampling algorithms and inference methods available in PyMC for posterior inference. + +## MCMC Sampling Methods + +### Primary Sampling Function + +**`pm.sample(draws=1000, tune=1000, chains=4, **kwargs)`** + +The main interface for MCMC sampling in PyMC. 
+ +**Key Parameters:** +- `draws`: Number of samples to draw per chain (default: 1000) +- `tune`: Number of tuning/warmup samples (default: 1000, discarded) +- `chains`: Number of parallel chains (default: 4) +- `cores`: Number of CPU cores to use (default: all available) +- `target_accept`: Target acceptance rate for step size tuning (default: 0.8, increase to 0.9-0.95 for difficult posteriors) +- `random_seed`: Random seed for reproducibility +- `return_inferencedata`: Return ArviZ InferenceData object (default: True) +- `idata_kwargs`: Additional kwargs for InferenceData creation (e.g., `{"log_likelihood": True}` for model comparison) + +**Returns:** InferenceData object containing posterior samples, sampling statistics, and diagnostics + +**Example:** +```python +with pm.Model() as model: + # ... define model ... + idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9) +``` + +### Sampling Algorithms + +PyMC automatically selects appropriate samplers based on model structure, but you can specify algorithms manually. + +#### NUTS (No-U-Turn Sampler) + +**Default algorithm** for continuous parameters. Highly efficient Hamiltonian Monte Carlo variant. + +- Automatically tunes step size and mass matrix +- Adaptive: explores posterior geometry during tuning +- Best for smooth, continuous posteriors +- Can struggle with high correlation or multimodality + +**Manual specification:** +```python +with model: + idata = pm.sample(step=pm.NUTS(target_accept=0.95)) +``` + +**When to adjust:** +- Increase `target_accept` (0.9-0.99) if seeing divergences +- Use `init='adapt_diag'` for faster initialization (default) +- Use `init='jitter+adapt_diag'` for difficult initializations + +#### Metropolis + +General-purpose Metropolis-Hastings sampler. + +- Works for both continuous and discrete variables +- Less efficient than NUTS for smooth continuous posteriors +- Useful for discrete parameters or non-differentiable models +- Requires manual tuning + +**Example:** +```python +with model: + idata = pm.sample(step=pm.Metropolis()) +``` + +#### Slice Sampler + +Slice sampling for univariate distributions. + +- No tuning required +- Good for difficult univariate posteriors +- Can be slow for high dimensions + +**Example:** +```python +with model: + idata = pm.sample(step=pm.Slice()) +``` + +#### CompoundStep + +Combine different samplers for different parameters. + +**Example:** +```python +with model: + # Use NUTS for continuous params, Metropolis for discrete + step1 = pm.NUTS([continuous_var1, continuous_var2]) + step2 = pm.Metropolis([discrete_var]) + idata = pm.sample(step=[step1, step2]) +``` + +### Sampling Diagnostics + +PyMC automatically computes diagnostics. Check these before trusting results: + +#### Effective Sample Size (ESS) + +Measures independent information in correlated samples. + +- **Rule of thumb**: ESS > 400 per chain (1600 total for 4 chains) +- Low ESS indicates high autocorrelation +- Access via: `az.ess(idata)` + +#### R-hat (Gelman-Rubin statistic) + +Measures convergence across chains. + +- **Rule of thumb**: R-hat < 1.01 for all parameters +- R-hat > 1.01 indicates non-convergence +- Access via: `az.rhat(idata)` + +#### Divergences + +Indicate regions where NUTS struggled. 
+
+- **Rule of thumb**: 0 divergences (or very few)
+- Divergences suggest biased samples
+- **Fix**: Increase `target_accept`, reparameterize, or use stronger priors
+- Access via: `idata.sample_stats.diverging.sum()`
+
+#### Energy Plot
+
+Visualizes Hamiltonian Monte Carlo energy transitions.
+
+```python
+az.plot_energy(idata)
+```
+
+Good separation between energy distributions indicates healthy sampling.
+
+### Handling Sampling Issues
+
+#### Divergences
+
+```python
+# Increase target acceptance rate
+idata = pm.sample(target_accept=0.95)
+
+# Or reparameterize latent (unobserved) parameters with the
+# non-centered parameterization.
+# Centered (can cause divergences):
+mu = pm.Normal('mu', 0, 1)
+sigma = pm.HalfNormal('sigma', 1)
+theta = pm.Normal('theta', mu, sigma, shape=n_groups)
+y = pm.Normal('y', theta[group_idx], 1, observed=data)
+
+# Non-centered (usually samples better):
+mu = pm.Normal('mu', 0, 1)
+sigma = pm.HalfNormal('sigma', 1)
+theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
+theta = pm.Deterministic('theta', mu + sigma * theta_offset)
+y = pm.Normal('y', theta[group_idx], 1, observed=data)
+```
+
+Note that non-centering applies to latent parameters, not to the observed likelihood (see Reparameterization Tricks below).
+
+#### Slow Sampling
+
+```python
+# Use fewer tuning steps if model is simple
+idata = pm.sample(tune=500)
+
+# Increase cores for parallelization
+idata = pm.sample(cores=8, chains=8)
+
+# Use variational inference for initialization
+with model:
+    approx = pm.fit()  # Run ADVI
+    idata = pm.sample(start=approx.sample(return_inferencedata=False)[0])
+```
+
+#### High Autocorrelation
+
+```python
+# Increase draws
+idata = pm.sample(draws=5000)
+
+# Reparameterize to reduce correlation
+# Consider using QR decomposition for regression models
+```
+
+## Variational Inference
+
+Faster approximate inference for large models or quick exploration.
+
+### ADVI (Automatic Differentiation Variational Inference)
+
+**`pm.fit(n=10000, method='advi', **kwargs)`**
+
+Approximates posterior with simpler distribution (typically mean-field Gaussian).
+
+**Key Parameters:**
+- `n`: Number of iterations (default: 10000)
+- `method`: VI algorithm ('advi', 'fullrank_advi', 'svgd')
+- `random_seed`: Random seed
+
+**Returns:** Approximation object for sampling and analysis
+
+**Example:**
+```python
+with model:
+    approx = pm.fit(n=50000)
+    # Draw samples from approximation
+    idata = approx.sample(1000)
+    # Or sample for MCMC initialization
+    start = approx.sample(return_inferencedata=False)[0]
+```
+
+**Trade-offs:**
+- **Pros**: Much faster than MCMC, scales to large data
+- **Cons**: Approximate, may miss posterior structure, underestimates uncertainty
+
+### Full-Rank ADVI
+
+Captures correlations between parameters.
+
+```python
+with model:
+    approx = pm.fit(method='fullrank_advi')
+```
+
+More accurate than mean-field but slower.
+
+### SVGD (Stein Variational Gradient Descent)
+
+Non-parametric variational inference.
+
+```python
+with model:
+    approx = pm.fit(method='svgd', n=20000)
+```
+
+Better captures multimodality but more computationally expensive.
+
+## Prior and Posterior Predictive Sampling
+
+### Prior Predictive Sampling
+
+Sample from the prior distribution (before seeing data).
+
+**`pm.sample_prior_predictive(samples=500, **kwargs)`**
+
+**Purpose:**
+- Validate priors are reasonable
+- Check implied predictions before fitting
+- Ensure model generates plausible data
+
+**Example:**
+```python
+with model:
+    prior_pred = pm.sample_prior_predictive(samples=1000)
+
+# Visualize prior predictions
+az.plot_ppc(prior_pred, group='prior')
+```
+
+### Posterior Predictive Sampling
+
+Sample from posterior predictive distribution (after fitting).
+ +**`pm.sample_posterior_predictive(trace, **kwargs)`** + +**Purpose:** +- Model validation via posterior predictive checks +- Generate predictions for new data +- Assess goodness-of-fit + +**Example:** +```python +with model: + # After sampling + idata = pm.sample() + + # Add posterior predictive samples + pm.sample_posterior_predictive(idata, extend_inferencedata=True) + +# Posterior predictive check +az.plot_ppc(idata) +``` + +### Predictions for New Data + +Update data and sample predictive distribution: + +```python +with model: + # Original model fit + idata = pm.sample() + + # Update with new predictor values + pm.set_data({'X': X_new}) + + # Sample predictions + post_pred_new = pm.sample_posterior_predictive( + idata.posterior, + var_names=['y_pred'] + ) +``` + +## Maximum A Posteriori (MAP) Estimation + +Find posterior mode (point estimate). + +**`pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)`** + +**When to use:** +- Quick point estimates +- Initialization for MCMC +- When full posterior not needed + +**Example:** +```python +with model: + map_estimate = pm.find_MAP() + print(map_estimate) +``` + +**Limitations:** +- Doesn't quantify uncertainty +- Can find local optima in multimodal posteriors +- Sensitive to prior specification + +## Inference Recommendations + +### Standard Workflow + +1. **Start with ADVI** for quick exploration: + ```python + approx = pm.fit(n=20000) + ``` + +2. **Run MCMC** for full inference: + ```python + idata = pm.sample(draws=2000, tune=1000) + ``` + +3. **Check diagnostics**: + ```python + az.summary(idata, var_names=['~mu_log__']) # Exclude transformed vars + ``` + +4. **Sample posterior predictive**: + ```python + pm.sample_posterior_predictive(idata, extend_inferencedata=True) + ``` + +### Choosing Inference Method + +| Scenario | Recommended Method | +|----------|-------------------| +| Small-medium models, need full uncertainty | MCMC with NUTS | +| Large models, initial exploration | ADVI | +| Discrete parameters | Metropolis or marginalize | +| Hierarchical models with divergences | Non-centered parameterization + NUTS | +| Very large data | Minibatch ADVI | +| Quick point estimates | MAP or ADVI | + +### Reparameterization Tricks + +**Non-centered parameterization** for hierarchical models: + +```python +# Centered (can cause divergences): +mu = pm.Normal('mu', 0, 10) +sigma = pm.HalfNormal('sigma', 1) +theta = pm.Normal('theta', mu, sigma, shape=n_groups) + +# Non-centered (better sampling): +mu = pm.Normal('mu', 0, 10) +sigma = pm.HalfNormal('sigma', 1) +theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups) +theta = pm.Deterministic('theta', mu + sigma * theta_offset) +``` + +**QR decomposition** for correlated predictors: + +```python +import numpy as np + +# QR decomposition +Q, R = np.linalg.qr(X) + +with pm.Model(): + # Uncorrelated coefficients + beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p) + + # Transform back to original scale + beta = pm.Deterministic('beta', pm.math.solve(R, beta_tilde)) + + mu = pm.math.dot(Q, beta_tilde) + sigma = pm.HalfNormal('sigma', 1) + y = pm.Normal('y', mu, sigma, observed=y_obs) +``` + +## Advanced Sampling + +### Sequential Monte Carlo (SMC) + +For complex posteriors or model evidence estimation: + +```python +with model: + idata = pm.sample_smc(draws=2000, chains=4) +``` + +Good for multimodal posteriors or when NUTS struggles. 
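+
+If the goal is model evidence, recent PyMC versions store a log marginal likelihood estimate in the SMC sampling statistics; the field name below is an assumption to verify against the installed version (inspect `idata_smc.sample_stats` if it is absent):
+
+```python
+with model:
+    idata_smc = pm.sample_smc(draws=2000, chains=4, random_seed=42)
+
+# Assumed field name for the SMC evidence estimate; not produced by pm.sample().
+log_marginal = idata_smc.sample_stats['log_marginal_likelihood']
+print(log_marginal)
+```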
+ +### Custom Initialization + +Provide starting values: + +```python +start = {'mu': 0, 'sigma': 1} +with model: + idata = pm.sample(start=start) +``` + +Or use MAP estimate: + +```python +with model: + start = pm.find_MAP() + idata = pm.sample(start=start) +``` diff --git a/scientific-packages/pymc/references/workflows.md b/scientific-packages/pymc/references/workflows.md new file mode 100644 index 0000000..764d9b8 --- /dev/null +++ b/scientific-packages/pymc/references/workflows.md @@ -0,0 +1,526 @@ +# PyMC Workflows and Common Patterns + +This reference provides standard workflows and patterns for building, validating, and analyzing Bayesian models in PyMC. + +## Standard Bayesian Workflow + +### Complete Workflow Template + +```python +import pymc as pm +import arviz as az +import numpy as np +import matplotlib.pyplot as plt + +# 1. PREPARE DATA +# =============== +X = ... # Predictor variables +y = ... # Observed outcomes + +# Standardize predictors for better sampling +X_scaled = (X - X.mean(axis=0)) / X.std(axis=0) + +# 2. BUILD MODEL +# ============== +with pm.Model() as model: + # Define coordinates for named dimensions + coords = { + 'predictors': ['var1', 'var2', 'var3'], + 'obs_id': np.arange(len(y)) + } + + # Priors + alpha = pm.Normal('alpha', mu=0, sigma=1) + beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors') + sigma = pm.HalfNormal('sigma', sigma=1) + + # Linear predictor + mu = alpha + pm.math.dot(X_scaled, beta) + + # Likelihood + y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id') + +# 3. PRIOR PREDICTIVE CHECK +# ========================== +with model: + prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42) + +# Visualize prior predictions +az.plot_ppc(prior_pred, group='prior', num_pp_samples=100) +plt.title('Prior Predictive Check') +plt.show() + +# 4. FIT MODEL +# ============ +with model: + # Quick VI exploration (optional) + approx = pm.fit(n=20000, random_seed=42) + + # Full MCMC inference + idata = pm.sample( + draws=2000, + tune=1000, + chains=4, + target_accept=0.9, + random_seed=42, + idata_kwargs={'log_likelihood': True} # For model comparison + ) + +# 5. CHECK DIAGNOSTICS +# ==================== +# Summary statistics +print(az.summary(idata, var_names=['alpha', 'beta', 'sigma'])) + +# R-hat and ESS +summary = az.summary(idata) +if (summary['r_hat'] > 1.01).any(): + print("WARNING: Some R-hat values > 1.01, chains may not have converged") + +if (summary['ess_bulk'] < 400).any(): + print("WARNING: Some ESS values < 400, consider more samples") + +# Check divergences +divergences = idata.sample_stats.diverging.sum().item() +print(f"Number of divergences: {divergences}") + +# Trace plots +az.plot_trace(idata, var_names=['alpha', 'beta', 'sigma']) +plt.tight_layout() +plt.show() + +# 6. POSTERIOR PREDICTIVE CHECK +# ============================== +with model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42) + +# Visualize fit +az.plot_ppc(idata, num_pp_samples=100) +plt.title('Posterior Predictive Check') +plt.show() + +# 7. ANALYZE RESULTS +# ================== +# Posterior distributions +az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma']) +plt.tight_layout() +plt.show() + +# Forest plot for coefficients +az.plot_forest(idata, var_names=['beta'], combined=True) +plt.title('Coefficient Estimates') +plt.show() + +# 8. PREDICTIONS FOR NEW DATA +# ============================ +X_new = ... 
# New predictor values +X_new_scaled = (X_new - X.mean(axis=0)) / X.std(axis=0) + +with model: + # Update data + pm.set_data({'X': X_new_scaled}) + + # Sample predictions + post_pred = pm.sample_posterior_predictive( + idata.posterior, + var_names=['y_obs'], + random_seed=42 + ) + +# Prediction intervals +y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw']) +y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs']) + +# 9. SAVE RESULTS +# =============== +idata.to_netcdf('model_results.nc') # Save for later +``` + +## Model Building Patterns + +### Linear Regression + +```python +with pm.Model() as linear_model: + # Priors + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + sigma = pm.HalfNormal('sigma', sigma=1) + + # Linear predictor + mu = alpha + pm.math.dot(X, beta) + + # Likelihood + y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs) +``` + +### Logistic Regression + +```python +with pm.Model() as logistic_model: + # Priors + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + + # Linear predictor + logit_p = alpha + pm.math.dot(X, beta) + + # Likelihood + y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs) +``` + +### Hierarchical/Multilevel Model + +```python +with pm.Model(coords={'group': group_names, 'obs': np.arange(n_obs)}) as hierarchical_model: + # Hyperpriors + mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10) + sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1) + + mu_beta = pm.Normal('mu_beta', mu=0, sigma=10) + sigma_beta = pm.HalfNormal('sigma_beta', sigma=1) + + # Group-level parameters (non-centered) + alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='group') + alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='group') + + beta_offset = pm.Normal('beta_offset', mu=0, sigma=1, dims='group') + beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_offset, dims='group') + + # Observation-level model + mu = alpha[group_idx] + beta[group_idx] * X + + sigma = pm.HalfNormal('sigma', sigma=1) + y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs, dims='obs') +``` + +### Poisson Regression (Count Data) + +```python +with pm.Model() as poisson_model: + # Priors + alpha = pm.Normal('alpha', mu=0, sigma=10) + beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors) + + # Linear predictor on log scale + log_lambda = alpha + pm.math.dot(X, beta) + + # Likelihood + y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs) +``` + +### Time Series (Autoregressive) + +```python +with pm.Model() as ar_model: + # Innovation standard deviation + sigma = pm.HalfNormal('sigma', sigma=1) + + # AR coefficients + rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order) + + # Initial distribution + init_dist = pm.Normal.dist(mu=0, sigma=sigma) + + # AR process + y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs) +``` + +### Mixture Model + +```python +with pm.Model() as mixture_model: + # Component weights + w = pm.Dirichlet('w', a=np.ones(n_components)) + + # Component parameters + mu = pm.Normal('mu', mu=0, sigma=10, shape=n_components) + sigma = pm.HalfNormal('sigma', sigma=1, shape=n_components) + + # Mixture + components = [pm.Normal.dist(mu=mu[i], sigma=sigma[i]) for i in range(n_components)] + y = pm.Mixture('y', w=w, comp_dists=components, observed=y_obs) +``` + +## Data Preparation Best Practices + +### Standardization + +Standardize continuous 
predictors for better sampling: + +```python +# Standardize +X_mean = X.mean(axis=0) +X_std = X.std(axis=0) +X_scaled = (X - X_mean) / X_std + +# Model with scaled data +with pm.Model() as model: + beta_scaled = pm.Normal('beta_scaled', 0, 1) + # ... rest of model ... + +# Transform back to original scale +beta_original = beta_scaled / X_std +alpha_original = alpha - (beta_scaled * X_mean / X_std).sum() +``` + +### Handling Missing Data + +Treat missing values as parameters: + +```python +# Identify missing values +missing_idx = np.isnan(X) +X_observed = np.where(missing_idx, 0, X) # Placeholder + +with pm.Model() as model: + # Prior for missing values + X_missing = pm.Normal('X_missing', mu=0, sigma=1, shape=missing_idx.sum()) + + # Combine observed and imputed + X_complete = pm.math.switch(missing_idx.flatten(), X_missing, X_observed.flatten()) + + # ... rest of model using X_complete ... +``` + +### Centering and Scaling + +For regression models, center predictors and outcome: + +```python +# Center +X_centered = X - X.mean(axis=0) +y_centered = y - y.mean() + +with pm.Model() as model: + # Simpler prior on intercept + alpha = pm.Normal('alpha', mu=0, sigma=1) # Intercept near 0 when centered + beta = pm.Normal('beta', mu=0, sigma=1, shape=n_predictors) + + mu = alpha + pm.math.dot(X_centered, beta) + sigma = pm.HalfNormal('sigma', sigma=1) + + y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y_centered) +``` + +## Prior Selection Guidelines + +### Weakly Informative Priors + +Use when you have limited prior knowledge: + +```python +# For standardized predictors +beta = pm.Normal('beta', mu=0, sigma=1) + +# For scale parameters +sigma = pm.HalfNormal('sigma', sigma=1) + +# For probabilities +p = pm.Beta('p', alpha=2, beta=2) # Slight preference for middle values +``` + +### Informative Priors + +Use domain knowledge: + +```python +# Effect size from literature: Cohen's d ≈ 0.3 +beta = pm.Normal('beta', mu=0.3, sigma=0.1) + +# Physical constraint: probability between 0.7-0.9 +p = pm.Beta('p', alpha=8, beta=2) # Check with prior predictive! +``` + +### Prior Predictive Checks + +Always validate priors: + +```python +with model: + prior_pred = pm.sample_prior_predictive(samples=1000) + +# Check if predictions are reasonable +print(f"Prior predictive range: {prior_pred.prior_predictive['y'].min():.2f} to {prior_pred.prior_predictive['y'].max():.2f}") +print(f"Observed range: {y_obs.min():.2f} to {y_obs.max():.2f}") + +# Visualize +az.plot_ppc(prior_pred, group='prior') +``` + +## Model Comparison Workflow + +### Comparing Multiple Models + +```python +import arviz as az + +# Fit multiple models +models = {} +idatas = {} + +# Model 1: Simple linear +with pm.Model() as models['linear']: + # ... define model ... + idatas['linear'] = pm.sample(idata_kwargs={'log_likelihood': True}) + +# Model 2: With interaction +with pm.Model() as models['interaction']: + # ... define model ... + idatas['interaction'] = pm.sample(idata_kwargs={'log_likelihood': True}) + +# Model 3: Hierarchical +with pm.Model() as models['hierarchical']: + # ... define model ... 
+ idatas['hierarchical'] = pm.sample(idata_kwargs={'log_likelihood': True}) + +# Compare using LOO +comparison = az.compare(idatas, ic='loo') +print(comparison) + +# Visualize comparison +az.plot_compare(comparison) +plt.show() + +# Check LOO reliability +for name, idata in idatas.items(): + loo = az.loo(idata, pointwise=True) + high_pareto_k = (loo.pareto_k > 0.7).sum().item() + if high_pareto_k > 0: + print(f"Warning: {name} has {high_pareto_k} observations with high Pareto-k") +``` + +### Model Weights + +```python +# Get model weights (pseudo-BMA) +weights = comparison['weight'].values + +print("Model probabilities:") +for name, weight in zip(comparison.index, weights): + print(f" {name}: {weight:.2%}") + +# Model averaging (weighted predictions) +def weighted_predictions(idatas, weights): + preds = [] + for (name, idata), weight in zip(idatas.items(), weights): + pred = idata.posterior_predictive['y_obs'].mean(dim=['chain', 'draw']) + preds.append(weight * pred) + return sum(preds) + +averaged_pred = weighted_predictions(idatas, weights) +``` + +## Diagnostics and Troubleshooting + +### Diagnosing Sampling Problems + +```python +def diagnose_sampling(idata, var_names=None): + """Comprehensive sampling diagnostics""" + + # Check convergence + summary = az.summary(idata, var_names=var_names) + + print("=== Convergence Diagnostics ===") + bad_rhat = summary[summary['r_hat'] > 1.01] + if len(bad_rhat) > 0: + print(f"⚠️ {len(bad_rhat)} variables with R-hat > 1.01") + print(bad_rhat[['r_hat']]) + else: + print("✓ All R-hat values < 1.01") + + # Check effective sample size + print("\n=== Effective Sample Size ===") + low_ess = summary[summary['ess_bulk'] < 400] + if len(low_ess) > 0: + print(f"⚠️ {len(low_ess)} variables with ESS < 400") + print(low_ess[['ess_bulk', 'ess_tail']]) + else: + print("✓ All ESS values > 400") + + # Check divergences + print("\n=== Divergences ===") + divergences = idata.sample_stats.diverging.sum().item() + if divergences > 0: + print(f"⚠️ {divergences} divergent transitions") + print(" Consider: increase target_accept, reparameterize, or stronger priors") + else: + print("✓ No divergences") + + # Check tree depth + print("\n=== NUTS Statistics ===") + max_treedepth = idata.sample_stats.tree_depth.max().item() + hits_max = (idata.sample_stats.tree_depth == max_treedepth).sum().item() + if hits_max > 0: + print(f"⚠️ Hit max treedepth {hits_max} times") + print(" Consider: reparameterize or increase max_treedepth") + else: + print(f"✓ No max treedepth issues (max: {max_treedepth})") + + return summary + +# Usage +diagnose_sampling(idata, var_names=['alpha', 'beta', 'sigma']) +``` + +### Common Fixes + +| Problem | Solution | +|---------|----------| +| Divergences | Increase `target_accept=0.95`, use non-centered parameterization | +| Low ESS | Sample more draws, reparameterize to reduce correlation | +| High R-hat | Run longer chains, check for multimodality, improve initialization | +| Slow sampling | Use ADVI initialization, reparameterize, reduce model complexity | +| Biased posterior | Check prior predictive, ensure likelihood is correct | + +## Using Named Dimensions (dims) + +### Benefits of dims + +- More readable code +- Easier subsetting and analysis +- Better xarray integration + +```python +# Define coordinates +coords = { + 'predictors': ['age', 'income', 'education'], + 'groups': ['A', 'B', 'C'], + 'time': pd.date_range('2020-01-01', periods=100, freq='D') +} + +with pm.Model(coords=coords) as model: + # Use dims instead of shape + beta = 
pm.Normal('beta', mu=0, sigma=1, dims='predictors')
+    alpha = pm.Normal('alpha', mu=0, sigma=1, dims='groups')
+    y = pm.Normal('y', mu=0, sigma=1, dims=['groups', 'time'], observed=data)
+
+# After sampling, dimensions are preserved
+idata = pm.sample(model=model)  # pass the model explicitly when sampling outside its context
+
+# Easy subsetting
+beta_age = idata.posterior['beta'].sel(predictors='age')
+group_A = idata.posterior['alpha'].sel(groups='A')
+```
+
+## Saving and Loading Results
+
+```python
+# Save InferenceData
+idata.to_netcdf('results.nc')
+
+# Load InferenceData
+loaded_idata = az.from_netcdf('results.nc')
+
+# Save model for later predictions
+import pickle
+
+with open('model.pkl', 'wb') as f:
+    pickle.dump({'model': model, 'idata': idata}, f)
+
+# Load model
+with open('model.pkl', 'rb') as f:
+    saved = pickle.load(f)
+    model = saved['model']
+    idata = saved['idata']
+```
diff --git a/scientific-packages/pymc/scripts/model_comparison.py b/scientific-packages/pymc/scripts/model_comparison.py
new file mode 100644
index 0000000..5c4c537
--- /dev/null
+++ b/scientific-packages/pymc/scripts/model_comparison.py
@@ -0,0 +1,387 @@
+"""
+PyMC Model Comparison Script
+
+Utilities for comparing multiple Bayesian models using information criteria
+and cross-validation metrics.
+
+Usage:
+    from scripts.model_comparison import compare_models, plot_model_comparison
+
+    # Compare multiple models
+    comparison = compare_models(
+        {'model1': idata1, 'model2': idata2, 'model3': idata3},
+        ic='loo'
+    )
+
+    # Visualize comparison
+    plot_model_comparison(comparison, output_path='model_comparison.png')
+"""
+
+import arviz as az
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from typing import Dict
+
+
+def compare_models(models_dict: Dict[str, az.InferenceData],
+                   ic='loo',
+                   scale='deviance',
+                   verbose=True):
+    """
+    Compare multiple models using information criteria.
+
+    Parameters
+    ----------
+    models_dict : dict
+        Dictionary mapping model names to InferenceData objects.
+        All models must have log_likelihood computed.
+    ic : str
+        Information criterion to use: 'loo' (default) or 'waic'
+    scale : str
+        Scale for IC: 'deviance' (default), 'log', or 'negative_log'
+    verbose : bool
+        Print detailed comparison results (default: True)
+
+    Returns
+    -------
+    pd.DataFrame
+        Comparison DataFrame with model rankings and statistics
+
+    Notes
+    -----
+    Models must be fit with idata_kwargs={'log_likelihood': True} or
+    log-likelihood computed afterwards with pm.compute_log_likelihood().
+ """ + if verbose: + print("="*70) + print(f" " * 25 + f"MODEL COMPARISON ({ic.upper()})") + print("="*70) + + # Perform comparison + comparison = az.compare(models_dict, ic=ic, scale=scale) + + if verbose: + print("\nModel Rankings:") + print("-"*70) + print(comparison.to_string()) + + print("\n" + "="*70) + print("INTERPRETATION GUIDE") + print("="*70) + print(f"• rank: Model ranking (0 = best)") + print(f"• {ic}: {ic.upper()} estimate (lower is better)") + print(f"• p_{ic}: Effective number of parameters") + print(f"• d{ic}: Difference from best model") + print(f"• weight: Model probability (pseudo-BMA)") + print(f"• se: Standard error of {ic.upper()}") + print(f"• dse: Standard error of the difference") + print(f"• warning: True if model has reliability issues") + print(f"• scale: {scale}") + + print("\n" + "="*70) + print("MODEL SELECTION GUIDELINES") + print("="*70) + + best_model = comparison.index[0] + print(f"\n✓ Best model: {best_model}") + + # Check for clear winner + if len(comparison) > 1: + delta = comparison.iloc[1][f'd{ic}'] + delta_se = comparison.iloc[1]['dse'] + + if delta > 10: + print(f" → STRONG evidence for {best_model} (Δ{ic} > 10)") + elif delta > 4: + print(f" → MODERATE evidence for {best_model} (4 < Δ{ic} < 10)") + elif delta > 2: + print(f" → WEAK evidence for {best_model} (2 < Δ{ic} < 4)") + else: + print(f" → Models are SIMILAR (Δ{ic} < 2)") + print(f" Consider model averaging or choose based on simplicity") + + # Check if difference is significant relative to SE + if delta > 2 * delta_se: + print(f" → Difference is > 2 SE, likely reliable") + else: + print(f" → Difference is < 2 SE, uncertain distinction") + + # Check for warnings + if comparison['warning'].any(): + print("\n⚠️ WARNING: Some models have reliability issues") + warned_models = comparison[comparison['warning']].index.tolist() + print(f" Models with warnings: {', '.join(warned_models)}") + print(f" → Check Pareto-k diagnostics with check_loo_reliability()") + + return comparison + + +def check_loo_reliability(models_dict: Dict[str, az.InferenceData], + threshold=0.7, + verbose=True): + """ + Check LOO-CV reliability using Pareto-k diagnostics. 
+ + Parameters + ---------- + models_dict : dict + Dictionary mapping model names to InferenceData objects + threshold : float + Pareto-k threshold for flagging observations (default: 0.7) + verbose : bool + Print detailed diagnostics (default: True) + + Returns + ------- + dict + Dictionary with Pareto-k diagnostics for each model + """ + if verbose: + print("="*70) + print(" " * 20 + "LOO RELIABILITY CHECK") + print("="*70) + + results = {} + + for name, idata in models_dict.items(): + if verbose: + print(f"\n{name}:") + print("-"*70) + + # Compute LOO with pointwise results + loo_result = az.loo(idata, pointwise=True) + pareto_k = loo_result.pareto_k.values + + # Count problematic observations + n_high = (pareto_k > threshold).sum() + n_very_high = (pareto_k > 1.0).sum() + + results[name] = { + 'pareto_k': pareto_k, + 'n_high': n_high, + 'n_very_high': n_very_high, + 'max_k': pareto_k.max(), + 'loo': loo_result + } + + if verbose: + print(f"Pareto-k diagnostics:") + print(f" • Good (k < 0.5): {(pareto_k < 0.5).sum()} observations") + print(f" • OK (0.5 ≤ k < 0.7): {((pareto_k >= 0.5) & (pareto_k < 0.7)).sum()} observations") + print(f" • Bad (0.7 ≤ k < 1.0): {((pareto_k >= 0.7) & (pareto_k < 1.0)).sum()} observations") + print(f" • Very bad (k ≥ 1.0): {(pareto_k >= 1.0).sum()} observations") + print(f" • Maximum k: {pareto_k.max():.3f}") + + if n_high > 0: + print(f"\n⚠️ {n_high} observations with k > {threshold}") + print(" LOO approximation may be unreliable for these points") + print(" Solutions:") + print(" → Use WAIC instead (less sensitive to outliers)") + print(" → Investigate influential observations") + print(" → Consider more flexible model") + + if n_very_high > 0: + print(f"\n⚠️ {n_very_high} observations with k > 1.0") + print(" These points have very high influence") + print(" → Strongly consider K-fold CV or other validation") + else: + print(f"✓ All Pareto-k values < {threshold}") + print(" LOO estimates are reliable") + + return results + + +def plot_model_comparison(comparison, output_path=None, show=True): + """ + Visualize model comparison results. + + Parameters + ---------- + comparison : pd.DataFrame + Comparison DataFrame from az.compare() + output_path : str, optional + If provided, save plot to this path + show : bool + Whether to display plot (default: True) + + Returns + ------- + matplotlib.figure.Figure + The comparison figure + """ + fig = plt.figure(figsize=(10, 6)) + az.plot_compare(comparison) + plt.title('Model Comparison', fontsize=14, fontweight='bold') + plt.tight_layout() + + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Comparison plot saved to {output_path}") + + if show: + plt.show() + else: + plt.close() + + return fig + + +def model_averaging(models_dict: Dict[str, az.InferenceData], + weights=None, + var_name='y_obs', + ic='loo'): + """ + Perform Bayesian model averaging using model weights. + + Parameters + ---------- + models_dict : dict + Dictionary mapping model names to InferenceData objects + weights : array-like, optional + Model weights. 
If None, computed from IC (pseudo-BMA weights) + var_name : str + Name of the predicted variable (default: 'y_obs') + ic : str + Information criterion for computing weights if not provided + + Returns + ------- + np.ndarray + Averaged predictions across models + np.ndarray + Model weights used + """ + if weights is None: + comparison = az.compare(models_dict, ic=ic) + weights = comparison['weight'].values + model_names = comparison.index.tolist() + else: + model_names = list(models_dict.keys()) + weights = np.array(weights) + weights = weights / weights.sum() # Normalize + + print("="*70) + print(" " * 22 + "BAYESIAN MODEL AVERAGING") + print("="*70) + print("\nModel weights:") + for name, weight in zip(model_names, weights): + print(f" {name}: {weight:.4f} ({weight*100:.2f}%)") + + # Extract predictions and average + predictions = [] + for name in model_names: + idata = models_dict[name] + if 'posterior_predictive' in idata: + pred = idata.posterior_predictive[var_name].values + else: + print(f"Warning: {name} missing posterior_predictive, skipping") + continue + predictions.append(pred) + + # Weighted average + averaged = sum(w * p for w, p in zip(weights, predictions)) + + print(f"\n✓ Model averaging complete") + print(f" Combined predictions using {len(predictions)} models") + + return averaged, weights + + +def cross_validation_comparison(models_dict: Dict[str, az.InferenceData], + k=10, + verbose=True): + """ + Perform k-fold cross-validation comparison (conceptual guide). + + Note: This function provides guidance. Full k-fold CV requires + re-fitting models k times, which should be done in the main script. + + Parameters + ---------- + models_dict : dict + Dictionary of model names to InferenceData + k : int + Number of folds (default: 10) + verbose : bool + Print guidance + + Returns + ------- + None + """ + if verbose: + print("="*70) + print(" " * 20 + "K-FOLD CROSS-VALIDATION GUIDE") + print("="*70) + print(f"\nTo perform {k}-fold CV:") + print(""" +1. Split data into k folds +2. For each fold: + - Train all models on k-1 folds + - Compute log-likelihood on held-out fold +3. Sum log-likelihoods across folds for each model +4. Compare models using total CV score + +Example code: +------------- +from sklearn.model_selection import KFold + +kf = KFold(n_splits=k, shuffle=True, random_seed=42) +cv_scores = {name: [] for name in models_dict.keys()} + +for train_idx, test_idx in kf.split(X): + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + + for name in models_dict.keys(): + # Fit model on train set + with create_model(name, X_train, y_train) as model: + idata = pm.sample() + + # Compute log-likelihood on test set + with model: + pm.set_data({'X': X_test, 'y': y_test}) + log_lik = pm.compute_log_likelihood(idata).sum() + + cv_scores[name].append(log_lik) + +# Compare total CV scores +for name, scores in cv_scores.items(): + print(f"{name}: {np.sum(scores):.2f}") + """) + + print("\nNote: K-fold CV is expensive but most reliable for model comparison") + print(" Use when LOO has reliability issues (high Pareto-k values)") + + +# Example usage +if __name__ == '__main__': + print("This script provides model comparison utilities for PyMC.") + print("\nExample usage:") + print(""" + import pymc as pm + from scripts.model_comparison import compare_models, check_loo_reliability + + # Fit multiple models (must include log_likelihood) + with pm.Model() as model1: + # ... define model 1 ... 
+ idata1 = pm.sample(idata_kwargs={'log_likelihood': True}) + + with pm.Model() as model2: + # ... define model 2 ... + idata2 = pm.sample(idata_kwargs={'log_likelihood': True}) + + # Compare models + models = {'Simple': idata1, 'Complex': idata2} + comparison = compare_models(models, ic='loo') + + # Check reliability + reliability = check_loo_reliability(models) + + # Visualize + plot_model_comparison(comparison, output_path='comparison.png') + + # Model averaging + averaged_pred, weights = model_averaging(models, var_name='y_obs') + """) diff --git a/scientific-packages/pymc/scripts/model_diagnostics.py b/scientific-packages/pymc/scripts/model_diagnostics.py new file mode 100644 index 0000000..9064d1b --- /dev/null +++ b/scientific-packages/pymc/scripts/model_diagnostics.py @@ -0,0 +1,350 @@ +""" +PyMC Model Diagnostics Script + +Comprehensive diagnostic checks for PyMC models. +Run this after sampling to validate results before interpretation. + +Usage: + from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report + + # Quick check + check_diagnostics(idata) + + # Full report with plots + create_diagnostic_report(idata, var_names=['alpha', 'beta', 'sigma'], output_dir='diagnostics/') +""" + +import arviz as az +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + + +def check_diagnostics(idata, var_names=None, ess_threshold=400, rhat_threshold=1.01): + """ + Perform comprehensive diagnostic checks on MCMC samples. + + Parameters + ---------- + idata : arviz.InferenceData + InferenceData object from pm.sample() + var_names : list, optional + Variables to check. If None, checks all model parameters + ess_threshold : int + Minimum acceptable effective sample size (default: 400) + rhat_threshold : float + Maximum acceptable R-hat value (default: 1.01) + + Returns + ------- + dict + Dictionary with diagnostic results and flags + """ + print("="*70) + print(" " * 20 + "MCMC DIAGNOSTICS REPORT") + print("="*70) + + # Get summary statistics + summary = az.summary(idata, var_names=var_names) + + results = { + 'summary': summary, + 'has_issues': False, + 'issues': [] + } + + # 1. Check R-hat (convergence) + print("\n1. CONVERGENCE CHECK (R-hat)") + print("-" * 70) + bad_rhat = summary[summary['r_hat'] > rhat_threshold] + + if len(bad_rhat) > 0: + print(f"⚠️ WARNING: {len(bad_rhat)} parameters have R-hat > {rhat_threshold}") + print("\nTop 10 worst R-hat values:") + print(bad_rhat[['r_hat']].sort_values('r_hat', ascending=False).head(10)) + print("\n⚠️ Chains may not have converged!") + print(" → Run longer chains or check for multimodality") + results['has_issues'] = True + results['issues'].append('convergence') + else: + print(f"✓ All R-hat values ≤ {rhat_threshold}") + print(" Chains have converged successfully") + + # 2. Check Effective Sample Size + print("\n2. 
EFFECTIVE SAMPLE SIZE (ESS)") + print("-" * 70) + low_ess_bulk = summary[summary['ess_bulk'] < ess_threshold] + low_ess_tail = summary[summary['ess_tail'] < ess_threshold] + + if len(low_ess_bulk) > 0 or len(low_ess_tail) > 0: + print(f"⚠️ WARNING: Some parameters have ESS < {ess_threshold}") + + if len(low_ess_bulk) > 0: + print(f"\n Bulk ESS issues ({len(low_ess_bulk)} parameters):") + print(low_ess_bulk[['ess_bulk']].sort_values('ess_bulk').head(10)) + + if len(low_ess_tail) > 0: + print(f"\n Tail ESS issues ({len(low_ess_tail)} parameters):") + print(low_ess_tail[['ess_tail']].sort_values('ess_tail').head(10)) + + print("\n⚠️ High autocorrelation detected!") + print(" → Sample more draws or reparameterize to reduce correlation") + results['has_issues'] = True + results['issues'].append('low_ess') + else: + print(f"✓ All ESS values ≥ {ess_threshold}") + print(" Sufficient effective samples") + + # 3. Check Divergences + print("\n3. DIVERGENT TRANSITIONS") + print("-" * 70) + divergences = idata.sample_stats.diverging.sum().item() + + if divergences > 0: + total_samples = len(idata.posterior.draw) * len(idata.posterior.chain) + divergence_rate = divergences / total_samples * 100 + + print(f"⚠️ WARNING: {divergences} divergent transitions ({divergence_rate:.2f}% of samples)") + print("\n Divergences indicate biased sampling in difficult posterior regions") + print(" Solutions:") + print(" → Increase target_accept (e.g., target_accept=0.95 or 0.99)") + print(" → Use non-centered parameterization for hierarchical models") + print(" → Add stronger/more informative priors") + print(" → Check for model misspecification") + results['has_issues'] = True + results['issues'].append('divergences') + results['n_divergences'] = divergences + else: + print("✓ No divergences detected") + print(" NUTS explored the posterior successfully") + + # 4. Check Tree Depth + print("\n4. TREE DEPTH") + print("-" * 70) + tree_depth = idata.sample_stats.tree_depth + max_tree_depth = tree_depth.max().item() + + # Typical max_treedepth is 10 (default in PyMC) + hits_max = (tree_depth >= 10).sum().item() + + if hits_max > 0: + total_samples = len(idata.posterior.draw) * len(idata.posterior.chain) + hit_rate = hits_max / total_samples * 100 + + print(f"⚠️ WARNING: Hit maximum tree depth {hits_max} times ({hit_rate:.2f}% of samples)") + print("\n Model may be difficult to explore efficiently") + print(" Solutions:") + print(" → Reparameterize model to improve geometry") + print(" → Increase max_treedepth (if necessary)") + results['issues'].append('max_treedepth') + else: + print(f"✓ No maximum tree depth issues") + print(f" Maximum tree depth reached: {max_tree_depth}") + + # 5. Check Energy (if available) + if hasattr(idata.sample_stats, 'energy'): + print("\n5. 
ENERGY DIAGNOSTICS") + print("-" * 70) + print("✓ Energy statistics available") + print(" Use az.plot_energy(idata) to visualize energy transitions") + print(" Good separation indicates healthy HMC sampling") + + # Summary + print("\n" + "="*70) + print("SUMMARY") + print("="*70) + + if not results['has_issues']: + print("✓ All diagnostics passed!") + print(" Your model has sampled successfully.") + print(" Proceed with inference and interpretation.") + else: + print("⚠️ Some diagnostics failed!") + print(f" Issues found: {', '.join(results['issues'])}") + print(" Review warnings above and consider re-running with adjustments.") + + print("="*70) + + return results + + +def create_diagnostic_report(idata, var_names=None, output_dir='diagnostics/', show=False): + """ + Create comprehensive diagnostic report with plots. + + Parameters + ---------- + idata : arviz.InferenceData + InferenceData object from pm.sample() + var_names : list, optional + Variables to plot. If None, uses all model parameters + output_dir : str + Directory to save diagnostic plots + show : bool + Whether to display plots (default: False, just save) + + Returns + ------- + dict + Diagnostic results from check_diagnostics + """ + # Create output directory + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Run diagnostic checks + results = check_diagnostics(idata, var_names=var_names) + + print(f"\nGenerating diagnostic plots in '{output_dir}'...") + + # 1. Trace plots + fig, axes = plt.subplots( + len(var_names) if var_names else 5, + 2, + figsize=(12, 10) + ) + az.plot_trace(idata, var_names=var_names, axes=axes) + plt.tight_layout() + plt.savefig(output_path / 'trace_plots.png', dpi=300, bbox_inches='tight') + print(f" ✓ Saved trace plots") + if show: + plt.show() + else: + plt.close() + + # 2. Rank plots (check mixing) + fig = plt.figure(figsize=(12, 8)) + az.plot_rank(idata, var_names=var_names) + plt.tight_layout() + plt.savefig(output_path / 'rank_plots.png', dpi=300, bbox_inches='tight') + print(f" ✓ Saved rank plots") + if show: + plt.show() + else: + plt.close() + + # 3. Autocorrelation plots + fig = plt.figure(figsize=(12, 8)) + az.plot_autocorr(idata, var_names=var_names, combined=True) + plt.tight_layout() + plt.savefig(output_path / 'autocorr_plots.png', dpi=300, bbox_inches='tight') + print(f" ✓ Saved autocorrelation plots") + if show: + plt.show() + else: + plt.close() + + # 4. Energy plot (if available) + if hasattr(idata.sample_stats, 'energy'): + fig = plt.figure(figsize=(10, 6)) + az.plot_energy(idata) + plt.tight_layout() + plt.savefig(output_path / 'energy_plot.png', dpi=300, bbox_inches='tight') + print(f" ✓ Saved energy plot") + if show: + plt.show() + else: + plt.close() + + # 5. ESS plot + fig = plt.figure(figsize=(10, 6)) + az.plot_ess(idata, var_names=var_names, kind='evolution') + plt.tight_layout() + plt.savefig(output_path / 'ess_evolution.png', dpi=300, bbox_inches='tight') + print(f" ✓ Saved ESS evolution plot") + if show: + plt.show() + else: + plt.close() + + # Save summary to CSV + results['summary'].to_csv(output_path / 'summary_statistics.csv') + print(f" ✓ Saved summary statistics") + + print(f"\nDiagnostic report complete! Files saved in '{output_dir}'") + + return results + + +def compare_prior_posterior(idata, prior_idata, var_names=None, output_path=None): + """ + Compare prior and posterior distributions. 
+ + Parameters + ---------- + idata : arviz.InferenceData + InferenceData with posterior samples + prior_idata : arviz.InferenceData + InferenceData with prior samples + var_names : list, optional + Variables to compare + output_path : str, optional + If provided, save plot to this path + + Returns + ------- + None + """ + fig, axes = plt.subplots( + len(var_names) if var_names else 3, + 1, + figsize=(10, 8) + ) + + if not isinstance(axes, np.ndarray): + axes = [axes] + + for idx, var in enumerate(var_names if var_names else list(idata.posterior.data_vars)[:3]): + # Plot prior + az.plot_dist( + prior_idata.prior[var].values.flatten(), + label='Prior', + ax=axes[idx], + color='blue', + alpha=0.3 + ) + + # Plot posterior + az.plot_dist( + idata.posterior[var].values.flatten(), + label='Posterior', + ax=axes[idx], + color='green', + alpha=0.3 + ) + + axes[idx].set_title(f'{var}: Prior vs Posterior') + axes[idx].legend() + + plt.tight_layout() + + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Prior-posterior comparison saved to {output_path}") + else: + plt.show() + + +# Example usage +if __name__ == '__main__': + print("This script provides diagnostic functions for PyMC models.") + print("\nExample usage:") + print(""" + import pymc as pm + from scripts.model_diagnostics import check_diagnostics, create_diagnostic_report + + # After sampling + with pm.Model() as model: + # ... define model ... + idata = pm.sample() + + # Quick diagnostic check + results = check_diagnostics(idata) + + # Full diagnostic report with plots + create_diagnostic_report( + idata, + var_names=['alpha', 'beta', 'sigma'], + output_dir='my_diagnostics/' + ) + """) diff --git a/scientific-packages/pymoo/SKILL.md b/scientific-packages/pymoo/SKILL.md new file mode 100644 index 0000000..d717349 --- /dev/null +++ b/scientific-packages/pymoo/SKILL.md @@ -0,0 +1,565 @@ +--- +name: pymoo +description: Multi-objective optimization framework for Python. Use this skill when working with optimization problems including single-objective, multi-objective, many-objective, constrained, or dynamic optimization. Apply when tasks involve finding optimal solutions, trade-off analysis, Pareto fronts, evolutionary algorithms (NSGA-II, NSGA-III, MOEA/D), genetic operators, constraint handling, or multi-criteria decision making. Relevant for engineering design optimization, portfolio allocation, combinatorial problems, and benchmarking optimization algorithms. +--- + +# Pymoo - Multi-Objective Optimization in Python + +## Overview + +Pymoo is a comprehensive Python framework for solving optimization problems with emphasis on multi-objective optimization. The library provides state-of-the-art single-objective and multi-objective algorithms, extensive benchmark problems, customizable genetic operators, advanced visualization tools, and multi-criteria decision making methods. Pymoo excels at finding trade-off solutions (Pareto fronts) for problems with conflicting objectives. 
+ +## When to Use This Skill + +Apply this skill when: +- Solving optimization problems with one or multiple objectives +- Finding Pareto-optimal solutions and analyzing trade-offs +- Implementing evolutionary algorithms (GA, DE, PSO, NSGA-II/III) +- Working with constrained optimization problems +- Benchmarking algorithms on standard test problems (ZDT, DTLZ, WFG) +- Customizing genetic operators (crossover, mutation, selection) +- Visualizing high-dimensional optimization results +- Making decisions from multiple competing solutions +- Handling binary, discrete, continuous, or mixed-variable problems + +## Core Concepts + +### The Unified Interface + +Pymoo uses a consistent `minimize()` function for all optimization tasks: + +```python +from pymoo.optimize import minimize + +result = minimize( + problem, # What to optimize + algorithm, # How to optimize + termination, # When to stop + seed=1, + verbose=True +) +``` + +**Result object contains:** +- `result.X`: Decision variables of optimal solution(s) +- `result.F`: Objective values of optimal solution(s) +- `result.G`: Constraint violations (if constrained) +- `result.algorithm`: Algorithm object with history + +### Problem Types + +**Single-objective:** One objective to minimize/maximize +**Multi-objective:** 2-3 conflicting objectives → Pareto front +**Many-objective:** 4+ objectives → High-dimensional Pareto front +**Constrained:** Objectives + inequality/equality constraints +**Dynamic:** Time-varying objectives or constraints + +## Quick Start Workflows + +### Workflow 1: Single-Objective Optimization + +**When:** Optimizing one objective function + +**Steps:** +1. Define or select problem +2. Choose single-objective algorithm (GA, DE, PSO, CMA-ES) +3. Configure termination criteria +4. Run optimization +5. Extract best solution + +**Example:** +```python +from pymoo.algorithms.soo.nonconvex.ga import GA +from pymoo.problems import get_problem +from pymoo.optimize import minimize + +# Built-in problem +problem = get_problem("rastrigin", n_var=10) + +# Configure Genetic Algorithm +algorithm = GA( + pop_size=100, + eliminate_duplicates=True +) + +# Optimize +result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=True +) + +print(f"Best solution: {result.X}") +print(f"Best objective: {result.F[0]}") +``` + +**See:** `scripts/single_objective_example.py` for complete example + +### Workflow 2: Multi-Objective Optimization (2-3 objectives) + +**When:** Optimizing 2-3 conflicting objectives, need Pareto front + +**Algorithm choice:** NSGA-II (standard for bi/tri-objective) + +**Steps:** +1. Define multi-objective problem +2. Configure NSGA-II +3. Run optimization to obtain Pareto front +4. Visualize trade-offs +5. 
Apply decision making (optional) + +**Example:** +```python +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.visualization.scatter import Scatter + +# Bi-objective benchmark problem +problem = get_problem("zdt1") + +# NSGA-II algorithm +algorithm = NSGA2(pop_size=100) + +# Optimize +result = minimize(problem, algorithm, ('n_gen', 200), seed=1) + +# Visualize Pareto front +plot = Scatter() +plot.add(result.F, label="Obtained Front") +plot.add(problem.pareto_front(), label="True Front", alpha=0.3) +plot.show() + +print(f"Found {len(result.F)} Pareto-optimal solutions") +``` + +**See:** `scripts/multi_objective_example.py` for complete example + +### Workflow 3: Many-Objective Optimization (4+ objectives) + +**When:** Optimizing 4 or more objectives + +**Algorithm choice:** NSGA-III (designed for many objectives) + +**Key difference:** Must provide reference directions for population guidance + +**Steps:** +1. Define many-objective problem +2. Generate reference directions +3. Configure NSGA-III with reference directions +4. Run optimization +5. Visualize using Parallel Coordinate Plot + +**Example:** +```python +from pymoo.algorithms.moo.nsga3 import NSGA3 +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.util.ref_dirs import get_reference_directions +from pymoo.visualization.pcp import PCP + +# Many-objective problem (5 objectives) +problem = get_problem("dtlz2", n_obj=5) + +# Generate reference directions (required for NSGA-III) +ref_dirs = get_reference_directions("das-dennis", n_dim=5, n_partitions=12) + +# Configure NSGA-III +algorithm = NSGA3(ref_dirs=ref_dirs) + +# Optimize +result = minimize(problem, algorithm, ('n_gen', 300), seed=1) + +# Visualize with Parallel Coordinates +plot = PCP(labels=[f"f{i+1}" for i in range(5)]) +plot.add(result.F, alpha=0.3) +plot.show() +``` + +**See:** `scripts/many_objective_example.py` for complete example + +### Workflow 4: Custom Problem Definition + +**When:** Solving domain-specific optimization problem + +**Steps:** +1. Extend `ElementwiseProblem` class +2. Define `__init__` with problem dimensions and bounds +3. Implement `_evaluate` method for objectives (and constraints) +4. 
Use with any algorithm + +**Unconstrained example:** +```python +from pymoo.core.problem import ElementwiseProblem +import numpy as np + +class MyProblem(ElementwiseProblem): + def __init__(self): + super().__init__( + n_var=2, # Number of variables + n_obj=2, # Number of objectives + xl=np.array([0, 0]), # Lower bounds + xu=np.array([5, 5]) # Upper bounds + ) + + def _evaluate(self, x, out, *args, **kwargs): + # Define objectives + f1 = x[0]**2 + x[1]**2 + f2 = (x[0]-1)**2 + (x[1]-1)**2 + + out["F"] = [f1, f2] +``` + +**Constrained example:** +```python +class ConstrainedProblem(ElementwiseProblem): + def __init__(self): + super().__init__( + n_var=2, + n_obj=2, + n_ieq_constr=2, # Inequality constraints + n_eq_constr=1, # Equality constraints + xl=np.array([0, 0]), + xu=np.array([5, 5]) + ) + + def _evaluate(self, x, out, *args, **kwargs): + # Objectives + out["F"] = [f1, f2] + + # Inequality constraints (g <= 0) + out["G"] = [g1, g2] + + # Equality constraints (h = 0) + out["H"] = [h1] +``` + +**Constraint formulation rules:** +- Inequality: Express as `g(x) <= 0` (feasible when ≤ 0) +- Equality: Express as `h(x) = 0` (feasible when = 0) +- Convert `g(x) >= b` to `-(g(x) - b) <= 0` + +**See:** `scripts/custom_problem_example.py` for complete examples + +### Workflow 5: Constraint Handling + +**When:** Problem has feasibility constraints + +**Approach options:** + +**1. Feasibility First (Default - Recommended)** +```python +from pymoo.algorithms.moo.nsga2 import NSGA2 + +# Works automatically with constrained problems +algorithm = NSGA2(pop_size=100) +result = minimize(problem, algorithm, termination) + +# Check feasibility +feasible = result.CV[:, 0] == 0 # CV = constraint violation +print(f"Feasible solutions: {np.sum(feasible)}") +``` + +**2. Penalty Method** +```python +from pymoo.constraints.as_penalty import ConstraintsAsPenalty + +# Wrap problem to convert constraints to penalties +problem_penalized = ConstraintsAsPenalty(problem, penalty=1e6) +``` + +**3. Constraint as Objective** +```python +from pymoo.constraints.as_obj import ConstraintsAsObjective + +# Treat constraint violation as additional objective +problem_with_cv = ConstraintsAsObjective(problem) +``` + +**4. Specialized Algorithms** +```python +from pymoo.algorithms.soo.nonconvex.sres import SRES + +# SRES has built-in constraint handling +algorithm = SRES() +``` + +**See:** `references/constraints_mcdm.md` for comprehensive constraint handling guide + +### Workflow 6: Decision Making from Pareto Front + +**When:** Have Pareto front, need to select preferred solution(s) + +**Steps:** +1. Run multi-objective optimization +2. Normalize objectives to [0, 1] +3. Define preference weights +4. Apply MCDM method +5. 
Visualize selected solution + +**Example using Pseudo-Weights:** +```python +from pymoo.mcdm.pseudo_weights import PseudoWeights +import numpy as np + +# After obtaining result from multi-objective optimization +# Normalize objectives +F_norm = (result.F - result.F.min(axis=0)) / (result.F.max(axis=0) - result.F.min(axis=0)) + +# Define preferences (must sum to 1) +weights = np.array([0.3, 0.7]) # 30% f1, 70% f2 + +# Apply decision making +dm = PseudoWeights(weights) +selected_idx = dm.do(F_norm) + +# Get selected solution +best_solution = result.X[selected_idx] +best_objectives = result.F[selected_idx] + +print(f"Selected solution: {best_solution}") +print(f"Objective values: {best_objectives}") +``` + +**Other MCDM methods:** +- Compromise Programming: Select closest to ideal point +- Knee Point: Find balanced trade-off solutions +- Hypervolume Contribution: Select most diverse subset + +**See:** +- `scripts/decision_making_example.py` for complete example +- `references/constraints_mcdm.md` for detailed MCDM methods + +### Workflow 7: Visualization + +**Choose visualization based on number of objectives:** + +**2 objectives: Scatter Plot** +```python +from pymoo.visualization.scatter import Scatter + +plot = Scatter(title="Bi-objective Results") +plot.add(result.F, color="blue", alpha=0.7) +plot.show() +``` + +**3 objectives: 3D Scatter** +```python +plot = Scatter(title="Tri-objective Results") +plot.add(result.F) # Automatically renders in 3D +plot.show() +``` + +**4+ objectives: Parallel Coordinate Plot** +```python +from pymoo.visualization.pcp import PCP + +plot = PCP( + labels=[f"f{i+1}" for i in range(n_obj)], + normalize_each_axis=True +) +plot.add(result.F, alpha=0.3) +plot.show() +``` + +**Solution comparison: Petal Diagram** +```python +from pymoo.visualization.petal import Petal + +plot = Petal( + bounds=[result.F.min(axis=0), result.F.max(axis=0)], + labels=["Cost", "Weight", "Efficiency"] +) +plot.add(solution_A, label="Design A") +plot.add(solution_B, label="Design B") +plot.show() +``` + +**See:** `references/visualization.md` for all visualization types and usage + +## Algorithm Selection Guide + +### Single-Objective Problems + +| Algorithm | Best For | Key Features | +|-----------|----------|--------------| +| **GA** | General-purpose | Flexible, customizable operators | +| **DE** | Continuous optimization | Good global search | +| **PSO** | Smooth landscapes | Fast convergence | +| **CMA-ES** | Difficult/noisy problems | Self-adapting | + +### Multi-Objective Problems (2-3 objectives) + +| Algorithm | Best For | Key Features | +|-----------|----------|--------------| +| **NSGA-II** | Standard benchmark | Fast, reliable, well-tested | +| **R-NSGA-II** | Preference regions | Reference point guidance | +| **MOEA/D** | Decomposable problems | Scalarization approach | + +### Many-Objective Problems (4+ objectives) + +| Algorithm | Best For | Key Features | +|-----------|----------|--------------| +| **NSGA-III** | 4-15 objectives | Reference direction-based | +| **RVEA** | Adaptive search | Reference vector evolution | +| **AGE-MOEA** | Complex landscapes | Adaptive geometry | + +### Constrained Problems + +| Approach | Algorithm | When to Use | +|----------|-----------|-------------| +| Feasibility-first | Any algorithm | Large feasible region | +| Specialized | SRES, ISRES | Heavy constraints | +| Penalty | GA + penalty | Algorithm compatibility | + +**See:** `references/algorithms.md` for comprehensive algorithm reference + +## Benchmark Problems + +### Quick problem 
access: +```python +from pymoo.problems import get_problem + +# Single-objective +problem = get_problem("rastrigin", n_var=10) +problem = get_problem("rosenbrock", n_var=10) + +# Multi-objective +problem = get_problem("zdt1") # Convex front +problem = get_problem("zdt2") # Non-convex front +problem = get_problem("zdt3") # Disconnected front + +# Many-objective +problem = get_problem("dtlz2", n_obj=5, n_var=12) +problem = get_problem("dtlz7", n_obj=4) +``` + +**See:** `references/problems.md` for complete test problem reference + +## Genetic Operator Customization + +### Standard operator configuration: +```python +from pymoo.algorithms.soo.nonconvex.ga import GA +from pymoo.operators.crossover.sbx import SBX +from pymoo.operators.mutation.pm import PM + +algorithm = GA( + pop_size=100, + crossover=SBX(prob=0.9, eta=15), + mutation=PM(eta=20), + eliminate_duplicates=True +) +``` + +### Operator selection by variable type: + +**Continuous variables:** +- Crossover: SBX (Simulated Binary Crossover) +- Mutation: PM (Polynomial Mutation) + +**Binary variables:** +- Crossover: TwoPointCrossover, UniformCrossover +- Mutation: BitflipMutation + +**Permutations (TSP, scheduling):** +- Crossover: OrderCrossover (OX) +- Mutation: InversionMutation + +**See:** `references/operators.md` for comprehensive operator reference + +## Performance and Troubleshooting + +### Common issues and solutions: + +**Problem: Algorithm not converging** +- Increase population size +- Increase number of generations +- Check if problem is multimodal (try different algorithms) +- Verify constraints are correctly formulated + +**Problem: Poor Pareto front distribution** +- For NSGA-III: Adjust reference directions +- Increase population size +- Check for duplicate elimination +- Verify problem scaling + +**Problem: Few feasible solutions** +- Use constraint-as-objective approach +- Apply repair operators +- Try SRES/ISRES for constrained problems +- Check constraint formulation (should be g <= 0) + +**Problem: High computational cost** +- Reduce population size +- Decrease number of generations +- Use simpler operators +- Enable parallelization (if problem supports) + +### Best practices: + +1. **Normalize objectives** when scales differ significantly +2. **Set random seed** for reproducibility +3. **Save history** to analyze convergence: `save_history=True` +4. **Visualize results** to understand solution quality +5. **Compare with true Pareto front** when available +6. **Use appropriate termination criteria** (generations, evaluations, tolerance) +7. 
**Tune operator parameters** for problem characteristics + +## Resources + +This skill includes comprehensive reference documentation and executable examples: + +### references/ +Detailed documentation for in-depth understanding: + +- **algorithms.md**: Complete algorithm reference with parameters, usage, and selection guidelines +- **problems.md**: Benchmark test problems (ZDT, DTLZ, WFG) with characteristics +- **operators.md**: Genetic operators (sampling, selection, crossover, mutation) with configuration +- **visualization.md**: All visualization types with examples and selection guide +- **constraints_mcdm.md**: Constraint handling techniques and multi-criteria decision making methods + +**Search patterns for references:** +- Algorithm details: `grep -r "NSGA-II\|NSGA-III\|MOEA/D" references/` +- Constraint methods: `grep -r "Feasibility First\|Penalty\|Repair" references/` +- Visualization types: `grep -r "Scatter\|PCP\|Petal" references/` + +### scripts/ +Executable examples demonstrating common workflows: + +- **single_objective_example.py**: Basic single-objective optimization with GA +- **multi_objective_example.py**: Multi-objective optimization with NSGA-II, visualization +- **many_objective_example.py**: Many-objective optimization with NSGA-III, reference directions +- **custom_problem_example.py**: Defining custom problems (constrained and unconstrained) +- **decision_making_example.py**: Multi-criteria decision making with different preferences + +**Run examples:** +```bash +python3 scripts/single_objective_example.py +python3 scripts/multi_objective_example.py +python3 scripts/many_objective_example.py +python3 scripts/custom_problem_example.py +python3 scripts/decision_making_example.py +``` + +## Additional Notes + +**Installation:** +```bash +pip install pymoo +``` + +**Dependencies:** NumPy, SciPy, matplotlib, autograd (optional for gradient-based) + +**Documentation:** https://pymoo.org/ + +**Version:** This skill is based on pymoo 0.6.x + +**Common patterns:** +- Always use `ElementwiseProblem` for custom problems +- Constraints formulated as `g(x) <= 0` and `h(x) = 0` +- Reference directions required for NSGA-III +- Normalize objectives before MCDM +- Use appropriate termination: `('n_gen', N)` or `get_termination("f_tol", tol=0.001)` diff --git a/scientific-packages/pymoo/references/algorithms.md b/scientific-packages/pymoo/references/algorithms.md new file mode 100644 index 0000000..ca888c3 --- /dev/null +++ b/scientific-packages/pymoo/references/algorithms.md @@ -0,0 +1,180 @@ +# Pymoo Algorithms Reference + +Comprehensive reference for optimization algorithms available in pymoo. 
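+
+All of the algorithms below are used through the same `minimize()` call, and passing `save_history=True` records a per-generation snapshot of the algorithm that can be used to study convergence afterwards. The following is a minimal sketch, not taken verbatim from the pymoo documentation; the attribute names (`history`, `evaluator.n_eval`, `opt`) follow pymoo's `Result`/`Algorithm` API:
+
+```python
+from pymoo.algorithms.soo.nonconvex.ga import GA
+from pymoo.problems import get_problem
+from pymoo.optimize import minimize
+
+problem = get_problem("rastrigin", n_var=10)
+res = minimize(problem, GA(pop_size=100), ('n_gen', 200),
+               seed=1, save_history=True, verbose=False)
+
+# res.history holds one algorithm snapshot per generation
+n_evals = [algo.evaluator.n_eval for algo in res.history]    # evaluations so far
+best_f = [algo.opt.get("F").min() for algo in res.history]   # best objective so far
+print(f"best after {n_evals[-1]} evaluations: {best_f[-1]:.4f}")
+```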
+ +## Single-Objective Optimization Algorithms + +### Genetic Algorithm (GA) +**Purpose:** General-purpose single-objective evolutionary optimization +**Best for:** Continuous, discrete, or mixed-variable problems +**Algorithm type:** (μ+λ) genetic algorithm + +**Key parameters:** +- `pop_size`: Population size (default: 100) +- `sampling`: Initial population generation strategy +- `selection`: Parent selection mechanism (default: Tournament) +- `crossover`: Recombination operator (default: SBX) +- `mutation`: Variation operator (default: Polynomial) +- `eliminate_duplicates`: Remove redundant solutions (default: True) +- `n_offsprings`: Offspring per generation + +**Usage:** +```python +from pymoo.algorithms.soo.nonconvex.ga import GA +algorithm = GA(pop_size=100, eliminate_duplicates=True) +``` + +### Differential Evolution (DE) +**Purpose:** Single-objective continuous optimization +**Best for:** Continuous parameter optimization with good global search +**Algorithm type:** Population-based differential evolution + +**Variants:** Multiple DE strategies available (rand/1/bin, best/1/bin, etc.) + +### Particle Swarm Optimization (PSO) +**Purpose:** Single-objective optimization through swarm intelligence +**Best for:** Continuous problems, fast convergence on smooth landscapes + +### CMA-ES +**Purpose:** Covariance Matrix Adaptation Evolution Strategy +**Best for:** Continuous optimization, particularly for noisy or ill-conditioned problems + +### Pattern Search +**Purpose:** Direct search method +**Best for:** Problems where gradient information is unavailable + +### Nelder-Mead +**Purpose:** Simplex-based optimization +**Best for:** Local optimization of continuous functions + +## Multi-Objective Optimization Algorithms + +### NSGA-II (Non-dominated Sorting Genetic Algorithm II) +**Purpose:** Multi-objective optimization with 2-3 objectives +**Best for:** Bi- and tri-objective problems requiring well-distributed Pareto fronts +**Selection strategy:** Non-dominated sorting + crowding distance + +**Key features:** +- Fast non-dominated sorting +- Crowding distance for diversity +- Elitist approach +- Binary tournament mating selection + +**Key parameters:** +- `pop_size`: Population size (default: 100) +- `sampling`: Initial population strategy +- `crossover`: Default SBX for continuous +- `mutation`: Default Polynomial Mutation +- `survival`: RankAndCrowding + +**Usage:** +```python +from pymoo.algorithms.moo.nsga2 import NSGA2 +algorithm = NSGA2(pop_size=100) +``` + +**When to use:** +- 2-3 objectives +- Need for distributed solutions across Pareto front +- Standard multi-objective benchmark + +### NSGA-III +**Purpose:** Many-objective optimization (4+ objectives) +**Best for:** Problems with 4 or more objectives requiring uniform Pareto front coverage +**Selection strategy:** Reference direction-based diversity maintenance + +**Key features:** +- Reference directions guide population +- Maintains diversity in high-dimensional objective spaces +- Niche preservation through reference points +- Underrepresented reference direction selection + +**Key parameters:** +- `ref_dirs`: Reference directions (REQUIRED) +- `pop_size`: Defaults to number of reference directions +- `crossover`: Default SBX +- `mutation`: Default Polynomial Mutation + +**Usage:** +```python +from pymoo.algorithms.moo.nsga3 import NSGA3 +from pymoo.util.ref_dirs import get_reference_directions + +ref_dirs = get_reference_directions("das-dennis", n_dim=4, n_partitions=12) +algorithm = NSGA3(ref_dirs=ref_dirs) +``` + 
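+
+The size of the Das-Dennis set grows quickly with the number of objectives: for `M` objectives and `H` partitions it contains `C(H + M - 1, M - 1)` directions (a standard property of the simplex-lattice design), and since NSGA-III's population size defaults to the number of reference directions, it pays to check the count before launching a run. A minimal sketch:
+
+```python
+import math
+from pymoo.util.ref_dirs import get_reference_directions
+
+# 4 objectives, 12 partitions, as in the usage example above
+ref_dirs = get_reference_directions("das-dennis", n_dim=4, n_partitions=12)
+print(len(ref_dirs))                  # 455 reference directions
+print(math.comb(12 + 4 - 1, 4 - 1))   # 455, the same count from C(H + M - 1, M - 1)
+```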
+**NSGA-II vs NSGA-III:** +- Use NSGA-II for 2-3 objectives +- Use NSGA-III for 4+ objectives +- NSGA-III provides more uniform distribution +- NSGA-II has lower computational overhead + +### R-NSGA-II (Reference Point Based NSGA-II) +**Purpose:** Multi-objective optimization with preference articulation +**Best for:** When decision maker has preferred regions of Pareto front + +### U-NSGA-III (Unified NSGA-III) +**Purpose:** Improved version handling various scenarios +**Best for:** Many-objective problems with additional robustness + +### MOEA/D (Multi-Objective Evolutionary Algorithm based on Decomposition) +**Purpose:** Decomposition-based multi-objective optimization +**Best for:** Problems where decomposition into scalar subproblems is effective + +### AGE-MOEA +**Purpose:** Adaptive geometry estimation +**Best for:** Multi and many-objective problems with adaptive mechanisms + +### RVEA (Reference Vector guided Evolutionary Algorithm) +**Purpose:** Reference vector-based many-objective optimization +**Best for:** Many-objective problems with adaptive reference vectors + +### SMS-EMOA +**Purpose:** S-Metric Selection Evolutionary Multi-objective Algorithm +**Best for:** Problems where hypervolume indicator is critical +**Selection:** Uses dominated hypervolume contribution + +## Dynamic Multi-Objective Algorithms + +### D-NSGA-II +**Purpose:** Dynamic multi-objective problems +**Best for:** Time-varying objective functions or constraints + +### KGB-DMOEA +**Purpose:** Knowledge-guided dynamic multi-objective optimization +**Best for:** Dynamic problems leveraging historical information + +## Constrained Optimization + +### SRES (Stochastic Ranking Evolution Strategy) +**Purpose:** Single-objective constrained optimization +**Best for:** Heavily constrained problems + +### ISRES (Improved SRES) +**Purpose:** Enhanced constrained optimization +**Best for:** Complex constraint landscapes + +## Algorithm Selection Guidelines + +**For single-objective problems:** +- Start with GA for general problems +- Use DE for continuous optimization +- Try PSO for faster convergence on smooth problems +- Use CMA-ES for difficult/noisy landscapes + +**For multi-objective problems:** +- 2-3 objectives: NSGA-II +- 4+ objectives: NSGA-III +- Preference articulation: R-NSGA-II +- Decomposition-friendly: MOEA/D +- Hypervolume focus: SMS-EMOA + +**For constrained problems:** +- Feasibility-based survival selection (works with most algorithms) +- Heavy constraints: SRES/ISRES +- Penalty methods for algorithm compatibility + +**For dynamic problems:** +- Time-varying: D-NSGA-II +- Historical knowledge useful: KGB-DMOEA diff --git a/scientific-packages/pymoo/references/constraints_mcdm.md b/scientific-packages/pymoo/references/constraints_mcdm.md new file mode 100644 index 0000000..1cc967c --- /dev/null +++ b/scientific-packages/pymoo/references/constraints_mcdm.md @@ -0,0 +1,417 @@ +# Pymoo Constraints and Decision Making Reference + +Reference for constraint handling and multi-criteria decision making in pymoo. 
+
+## Constraint Handling
+
+### Defining Constraints
+
+Constraints are specified in the Problem definition:
+
+```python
+from pymoo.core.problem import ElementwiseProblem
+import numpy as np
+
+class ConstrainedProblem(ElementwiseProblem):
+    def __init__(self):
+        super().__init__(
+            n_var=2,
+            n_obj=2,
+            n_ieq_constr=2,  # Number of inequality constraints
+            n_eq_constr=1,   # Number of equality constraints
+            xl=np.array([0, 0]),
+            xu=np.array([5, 5])
+        )
+
+    def _evaluate(self, x, out, *args, **kwargs):
+        # Objectives
+        f1 = x[0]**2 + x[1]**2
+        f2 = (x[0]-1)**2 + (x[1]-1)**2
+
+        out["F"] = [f1, f2]
+
+        # Inequality constraints (formulated as g(x) <= 0)
+        g1 = x[0] + x[1] - 5          # x[0] + x[1] <= 5
+        g2 = x[0]**2 + x[1]**2 - 25   # x[0]^2 + x[1]^2 <= 25
+
+        out["G"] = [g1, g2]
+
+        # Equality constraints (formulated as h(x) = 0)
+        h1 = x[0] - 2*x[1]
+
+        out["H"] = [h1]
+```
+
+**Constraint formulation rules:**
+- Inequality: `g(x) <= 0` (feasible when negative or zero)
+- Equality: `h(x) = 0` (feasible when zero)
+- Convert `g(x) >= 0` to `-g(x) <= 0`
+
+### Constraint Handling Techniques
+
+#### 1. Feasibility First (Default)
+**Mechanism:** Always prefer feasible over infeasible solutions
+**Comparison:**
+1. Both feasible → compare by objective values
+2. One feasible, one infeasible → feasible wins
+3. Both infeasible → compare by constraint violation
+
+**Usage:**
+```python
+from pymoo.algorithms.moo.nsga2 import NSGA2
+
+# Feasibility first is default for most algorithms
+algorithm = NSGA2(pop_size=100)
+```
+
+**Advantages:**
+- Works with any sorting-based algorithm
+- Simple and effective
+- No parameter tuning
+
+**Disadvantages:**
+- May struggle with small feasible regions
+- Can ignore good infeasible solutions
+
+#### 2. Penalty Methods
+**Mechanism:** Add penalty to objective based on constraint violation
+**Formula:** `F_penalized = F + penalty_factor * violation`
+
+**Usage:**
+```python
+from pymoo.algorithms.soo.nonconvex.ga import GA
+from pymoo.constraints.as_penalty import ConstraintsAsPenalty
+
+# Wrap problem with penalty
+problem_with_penalty = ConstraintsAsPenalty(problem, penalty=1e6)
+
+algorithm = GA(pop_size=100)
+```
+
+**Parameters:**
+- `penalty`: Penalty coefficient (tune based on problem scale)
+
+**Advantages:**
+- Converts constrained to unconstrained problem
+- Works with any optimization algorithm
+
+**Disadvantages:**
+- Penalty parameter sensitive
+- May need problem-specific tuning
+
+#### 3. Constraint as Objective
+**Mechanism:** Treat constraint violation as additional objective
+**Result:** Multi-objective problem with M+1 objectives (M original + constraint)
+
+**Usage:**
+```python
+from pymoo.algorithms.moo.nsga2 import NSGA2
+from pymoo.constraints.as_obj import ConstraintsAsObjective
+
+# Add constraint violation as objective
+problem_with_cv_obj = ConstraintsAsObjective(problem)
+
+algorithm = NSGA2(pop_size=100)
+```
+
+**Advantages:**
+- No parameter tuning
+- Maintains infeasible solutions that may be useful
+- Works well when feasible region is small
+
+**Disadvantages:**
+- Increases problem dimensionality
+- More complex Pareto front analysis
+
+#### 4. Epsilon-Constraint Handling
+**Mechanism:** Dynamic feasibility threshold
+**Concept:** Gradually tighten constraint tolerance over generations
+
+**Advantages:**
+- Smooth transition to feasible region
+- Helps with difficult constraint landscapes
+
+**Disadvantages:**
+- Algorithm-specific implementation
+- Requires parameter tuning
+
+#### 5.
Repair Operators +**Mechanism:** Modify infeasible solutions to satisfy constraints +**Application:** After crossover/mutation, repair offspring + +**Usage:** +```python +from pymoo.core.repair import Repair + +class MyRepair(Repair): + def _do(self, problem, X, **kwargs): + # Project X onto feasible region + # Example: clip to bounds + X = np.clip(X, problem.xl, problem.xu) + return X + +from pymoo.algorithms.soo.nonconvex.ga import GA + +algorithm = GA(pop_size=100, repair=MyRepair()) +``` + +**Advantages:** +- Maintains feasibility throughout optimization +- Can encode domain knowledge + +**Disadvantages:** +- Requires problem-specific implementation +- May restrict search + +### Constraint-Handling Algorithms + +Some algorithms have built-in constraint handling: + +#### SRES (Stochastic Ranking Evolution Strategy) +**Purpose:** Single-objective constrained optimization +**Mechanism:** Stochastic ranking balances objectives and constraints + +**Usage:** +```python +from pymoo.algorithms.soo.nonconvex.sres import SRES + +algorithm = SRES() +``` + +#### ISRES (Improved SRES) +**Purpose:** Enhanced constrained optimization +**Improvements:** Better parameter adaptation + +**Usage:** +```python +from pymoo.algorithms.soo.nonconvex.isres import ISRES + +algorithm = ISRES() +``` + +### Constraint Handling Guidelines + +**Choose technique based on:** + +| Problem Characteristic | Recommended Technique | +|------------------------|----------------------| +| Large feasible region | Feasibility First | +| Small feasible region | Constraint as Objective, Repair | +| Heavily constrained | SRES/ISRES, Epsilon-constraint | +| Linear constraints | Repair (projection) | +| Nonlinear constraints | Feasibility First, Penalty | +| Known feasible solutions | Biased initialization | + +## Multi-Criteria Decision Making (MCDM) + +After obtaining a Pareto front, MCDM helps select preferred solution(s). + +### Decision Making Context + +**Pareto front characteristics:** +- Multiple non-dominated solutions +- Each represents different trade-off +- No objectively "best" solution +- Requires decision maker preferences + +### MCDM Methods in Pymoo + +#### 1. Pseudo-Weights +**Concept:** Weight each objective, select solution minimizing weighted sum +**Formula:** `score = w1*f1 + w2*f2 + ... + wM*fM` + +**Usage:** +```python +from pymoo.mcdm.pseudo_weights import PseudoWeights + +# Define weights (must sum to 1) +weights = np.array([0.3, 0.7]) # 30% weight on f1, 70% on f2 + +dm = PseudoWeights(weights) +best_idx = dm.do(result.F) +best_solution = result.X[best_idx] +``` + +**When to use:** +- Clear preference articulation available +- Objectives commensurable +- Linear trade-offs acceptable + +**Limitations:** +- Requires weight specification +- Linear assumption may not capture preferences +- Sensitive to objective scaling + +#### 2. Compromise Programming +**Concept:** Select solution closest to ideal point +**Metric:** Distance to ideal (e.g., Euclidean, Tchebycheff) + +**Usage:** +```python +from pymoo.mcdm.compromise_programming import CompromiseProgramming + +dm = CompromiseProgramming() +best_idx = dm.do(result.F, ideal=ideal_point, nadir=nadir_point) +``` + +**When to use:** +- Ideal objective values known or estimable +- Balanced consideration of all objectives +- No clear weight preferences + +#### 3. Interactive Decision Making +**Concept:** Iterative preference refinement +**Process:** +1. Show representative solutions to decision maker +2. Gather feedback on preferences +3. 
Focus search on preferred regions +4. Repeat until satisfactory solution found + +**Approaches:** +- Reference point methods +- Trade-off analysis +- Progressive preference articulation + +### Decision Making Workflow + +**Step 1: Normalize objectives** +```python +# Normalize to [0, 1] for fair comparison +F_norm = (result.F - result.F.min(axis=0)) / (result.F.max(axis=0) - result.F.min(axis=0)) +``` + +**Step 2: Analyze trade-offs** +```python +from pymoo.visualization.scatter import Scatter + +plot = Scatter() +plot.add(result.F) +plot.show() + +# Identify knee points, extreme solutions +``` + +**Step 3: Apply MCDM method** +```python +from pymoo.mcdm.pseudo_weights import PseudoWeights + +weights = np.array([0.4, 0.6]) # Based on preferences +dm = PseudoWeights(weights) +selected = dm.do(F_norm) +``` + +**Step 4: Validate selection** +```python +# Visualize selected solution +from pymoo.visualization.petal import Petal + +plot = Petal() +plot.add(result.F[selected], label="Selected") +# Add other candidates for comparison +plot.show() +``` + +### Advanced MCDM Techniques + +#### Knee Point Detection +**Concept:** Solutions where small improvement in one objective causes large degradation in others + +**Usage:** +```python +from pymoo.mcdm.knee import KneePoint + +km = KneePoint() +knee_idx = km.do(result.F) +knee_solutions = result.X[knee_idx] +``` + +**When to use:** +- No clear preferences +- Balanced trade-offs desired +- Convex Pareto fronts + +#### Hypervolume Contribution +**Concept:** Select solutions contributing most to hypervolume +**Use case:** Maintain diverse subset of solutions + +**Usage:** +```python +from pymoo.indicators.hv import HV + +hv = HV(ref_point=reference_point) +hv_contributions = hv.calc_contributions(result.F) + +# Select top contributors +top_k = 5 +top_indices = np.argsort(hv_contributions)[-top_k:] +selected_solutions = result.X[top_indices] +``` + +### Decision Making Guidelines + +**When decision maker has:** + +| Preference Information | Recommended Method | +|------------------------|-------------------| +| Clear objective weights | Pseudo-Weights | +| Ideal target values | Compromise Programming | +| No prior preferences | Knee Point, Visual inspection | +| Conflicting criteria | Interactive methods | +| Need diverse subset | Hypervolume contribution | + +**Best practices:** +1. **Normalize objectives** before MCDM +2. **Visualize Pareto front** to understand trade-offs +3. **Consider multiple methods** for robust selection +4. **Validate results** with domain experts +5. **Document assumptions** and preference sources +6. 
**Perform sensitivity analysis** on weights/parameters + +### Integration Example + +Complete workflow with constraint handling and decision making: + +```python +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.optimize import minimize +from pymoo.mcdm.pseudo_weights import PseudoWeights +import numpy as np + +# Define constrained problem +problem = MyConstrainedProblem() + +# Setup algorithm with feasibility-first constraint handling +algorithm = NSGA2( + pop_size=100, + eliminate_duplicates=True +) + +# Optimize +result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=True +) + +# Filter feasible solutions only +feasible_mask = result.CV[:, 0] == 0 # Constraint violation = 0 +F_feasible = result.F[feasible_mask] +X_feasible = result.X[feasible_mask] + +# Normalize objectives +F_norm = (F_feasible - F_feasible.min(axis=0)) / (F_feasible.max(axis=0) - F_feasible.min(axis=0)) + +# Apply MCDM +weights = np.array([0.5, 0.5]) +dm = PseudoWeights(weights) +best_idx = dm.do(F_norm) + +# Get final solution +best_solution = X_feasible[best_idx] +best_objectives = F_feasible[best_idx] + +print(f"Selected solution: {best_solution}") +print(f"Objective values: {best_objectives}") +``` diff --git a/scientific-packages/pymoo/references/operators.md b/scientific-packages/pymoo/references/operators.md new file mode 100644 index 0000000..a4152f1 --- /dev/null +++ b/scientific-packages/pymoo/references/operators.md @@ -0,0 +1,345 @@ +# Pymoo Genetic Operators Reference + +Comprehensive reference for genetic operators in pymoo. + +## Sampling Operators + +Sampling operators initialize populations at the start of optimization. + +### Random Sampling +**Purpose:** Generate random initial solutions +**Types:** +- `FloatRandomSampling`: Continuous variables +- `BinaryRandomSampling`: Binary variables +- `IntegerRandomSampling`: Integer variables +- `PermutationRandomSampling`: Permutation-based problems + +**Usage:** +```python +from pymoo.operators.sampling.rnd import FloatRandomSampling +sampling = FloatRandomSampling() +``` + +### Latin Hypercube Sampling (LHS) +**Purpose:** Space-filling initial population +**Benefit:** Better coverage of search space than random +**Types:** +- `LHS`: Standard Latin Hypercube + +**Usage:** +```python +from pymoo.operators.sampling.lhs import LHS +sampling = LHS() +``` + +### Custom Sampling +Provide initial population through Population object or NumPy array + +## Selection Operators + +Selection operators choose parents for reproduction. + +### Tournament Selection +**Purpose:** Select parents through tournament competition +**Mechanism:** Randomly select k individuals, choose best +**Parameters:** +- `pressure`: Tournament size (default: 2) +- `func_comp`: Comparison function + +**Usage:** +```python +from pymoo.operators.selection.tournament import TournamentSelection +selection = TournamentSelection(pressure=2) +``` + +### Random Selection +**Purpose:** Uniform random parent selection +**Use case:** Baseline or exploration-focused algorithms + +**Usage:** +```python +from pymoo.operators.selection.rnd import RandomSelection +selection = RandomSelection() +``` + +## Crossover Operators + +Crossover operators recombine parent solutions to create offspring. 
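+
+Before the catalog below, a minimal sketch of how any crossover operator from this reference is wired into an algorithm. It reuses SBX, GA, and the `sphere` test problem that appear elsewhere in this skill; any other operator listed below can be swapped in the same way.
+
+```python
+from pymoo.algorithms.soo.nonconvex.ga import GA
+from pymoo.operators.crossover.sbx import SBX
+from pymoo.problems import get_problem
+from pymoo.optimize import minimize
+
+# Pass the crossover operator to the algorithm constructor; sampling,
+# selection, and mutation operators are passed the same way.
+algorithm = GA(pop_size=50, crossover=SBX(prob=0.9, eta=15))
+
+result = minimize(get_problem("sphere", n_var=10), algorithm, ("n_gen", 50), seed=1, verbose=False)
+print(result.F)
+```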
+ +### For Continuous Variables + +#### Simulated Binary Crossover (SBX) +**Purpose:** Primary crossover for continuous optimization +**Mechanism:** Simulates single-point crossover of binary-encoded variables +**Parameters:** +- `prob`: Crossover probability (default: 0.9) +- `eta`: Distribution index (default: 15) + - Higher eta → offspring closer to parents + - Lower eta → more exploration + +**Usage:** +```python +from pymoo.operators.crossover.sbx import SBX +crossover = SBX(prob=0.9, eta=15) +``` + +**String shorthand:** `"real_sbx"` + +#### Differential Evolution Crossover +**Purpose:** DE-specific recombination +**Variants:** +- `DE/rand/1/bin` +- `DE/best/1/bin` +- `DE/current-to-best/1/bin` + +**Parameters:** +- `CR`: Crossover rate +- `F`: Scaling factor + +### For Binary Variables + +#### Single Point Crossover +**Purpose:** Cut and swap at one point +**Usage:** +```python +from pymoo.operators.crossover.pntx import SinglePointCrossover +crossover = SinglePointCrossover() +``` + +#### Two Point Crossover +**Purpose:** Cut and swap between two points +**Usage:** +```python +from pymoo.operators.crossover.pntx import TwoPointCrossover +crossover = TwoPointCrossover() +``` + +#### K-Point Crossover +**Purpose:** Multiple cut points +**Parameters:** +- `n_points`: Number of crossover points + +#### Uniform Crossover +**Purpose:** Each gene independently from either parent +**Parameters:** +- `prob`: Per-gene swap probability (default: 0.5) + +**Usage:** +```python +from pymoo.operators.crossover.ux import UniformCrossover +crossover = UniformCrossover(prob=0.5) +``` + +#### Half Uniform Crossover (HUX) +**Purpose:** Exchange exactly half of differing genes +**Benefit:** Maintains genetic diversity + +### For Permutations + +#### Order Crossover (OX) +**Purpose:** Preserve relative order from parents +**Use case:** Traveling salesman, scheduling problems + +**Usage:** +```python +from pymoo.operators.crossover.ox import OrderCrossover +crossover = OrderCrossover() +``` + +#### Edge Recombination Crossover (ERX) +**Purpose:** Preserve edge information from parents +**Use case:** Routing problems where edge connectivity matters + +#### Partially Mapped Crossover (PMX) +**Purpose:** Exchange segments while maintaining permutation validity + +## Mutation Operators + +Mutation operators introduce variation to maintain diversity. 
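+
+All of the mutation operators below share one interface, which is also what the "Custom Mutation" note further down refers to. As a rough sketch (assuming continuous variables; the class name and reset logic here are illustrative, not part of pymoo), a custom operator subclasses `Mutation` and overrides `_do`, mirroring the `Repair` example used elsewhere in this skill:
+
+```python
+import numpy as np
+from pymoo.core.mutation import Mutation
+
+class RandomResetMutation(Mutation):
+    """Illustrative sketch: reset each gene to a uniform random value with probability `prob`."""
+
+    def __init__(self, prob=0.1):
+        super().__init__()
+        self.prob = prob
+
+    def _do(self, problem, X, **kwargs):
+        X = X.copy()
+        mask = np.random.random(X.shape) < self.prob
+        rand = np.random.uniform(problem.xl, problem.xu, size=X.shape)
+        X[mask] = rand[mask]
+        return X
+```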
+ +### For Continuous Variables + +#### Polynomial Mutation (PM) +**Purpose:** Primary mutation for continuous optimization +**Mechanism:** Polynomial probability distribution +**Parameters:** +- `prob`: Per-variable mutation probability +- `eta`: Distribution index (default: 20) + - Higher eta → smaller perturbations + - Lower eta → larger perturbations + +**Usage:** +```python +from pymoo.operators.mutation.pm import PM +mutation = PM(prob=None, eta=20) # prob=None means 1/n_var +``` + +**String shorthand:** `"real_pm"` + +**Probability guidelines:** +- `None` or `1/n_var`: Standard recommendation +- Higher for more exploration +- Lower for more exploitation + +### For Binary Variables + +#### Bitflip Mutation +**Purpose:** Flip bits with specified probability +**Parameters:** +- `prob`: Per-bit flip probability + +**Usage:** +```python +from pymoo.operators.mutation.bitflip import BitflipMutation +mutation = BitflipMutation(prob=0.05) +``` + +### For Integer Variables + +#### Integer Polynomial Mutation +**Purpose:** PM adapted for integers +**Ensures:** Valid integer values after mutation + +### For Permutations + +#### Inversion Mutation +**Purpose:** Reverse a segment of the permutation +**Use case:** Maintains some order structure + +**Usage:** +```python +from pymoo.operators.mutation.inversion import InversionMutation +mutation = InversionMutation() +``` + +#### Scramble Mutation +**Purpose:** Randomly shuffle a segment + +### Custom Mutation +Define custom mutation by extending `Mutation` class + +## Repair Operators + +Repair operators fix constraint violations or ensure solution feasibility. + +### Rounding Repair +**Purpose:** Round to nearest valid value +**Use case:** Integer/discrete variables with bound constraints + +### Bounce Back Repair +**Purpose:** Reflect out-of-bounds values back into feasible region +**Use case:** Box-constrained continuous problems + +### Projection Repair +**Purpose:** Project infeasible solutions onto feasible region +**Use case:** Linear constraints + +### Custom Repair +**Purpose:** Domain-specific constraint handling +**Implementation:** Extend `Repair` class + +**Example:** +```python +from pymoo.core.repair import Repair + +class MyRepair(Repair): + def _do(self, problem, X, **kwargs): + # Modify X to satisfy constraints + # Return repaired X + return X +``` + +## Operator Configuration Guidelines + +### Parameter Tuning + +**Crossover probability:** +- High (0.8-0.95): Standard for most problems +- Lower: More emphasis on mutation + +**Mutation probability:** +- `1/n_var`: Standard recommendation +- Higher: More exploration, slower convergence +- Lower: Faster convergence, risk of premature convergence + +**Distribution indices (eta):** +- Crossover eta (15-30): Higher for local search +- Mutation eta (20-50): Higher for exploitation + +### Problem-Specific Selection + +**Continuous problems:** +- Crossover: SBX +- Mutation: Polynomial Mutation +- Selection: Tournament + +**Binary problems:** +- Crossover: Two-point or Uniform +- Mutation: Bitflip +- Selection: Tournament + +**Permutation problems:** +- Crossover: Order Crossover (OX) +- Mutation: Inversion or Scramble +- Selection: Tournament + +**Mixed-variable problems:** +- Use appropriate operators per variable type +- Ensure operator compatibility + +### String-Based Configuration + +Pymoo supports convenient string-based operator specification: + +```python +from pymoo.algorithms.soo.nonconvex.ga import GA + +algorithm = GA( + pop_size=100, + sampling="real_random", + 
crossover="real_sbx", + mutation="real_pm" +) +``` + +**Available strings:** +- Sampling: `"real_random"`, `"real_lhs"`, `"bin_random"`, `"perm_random"` +- Crossover: `"real_sbx"`, `"real_de"`, `"int_sbx"`, `"bin_ux"`, `"bin_hux"` +- Mutation: `"real_pm"`, `"int_pm"`, `"bin_bitflip"`, `"perm_inv"` + +## Operator Combination Examples + +### Standard Continuous GA: +```python +from pymoo.operators.sampling.rnd import FloatRandomSampling +from pymoo.operators.crossover.sbx import SBX +from pymoo.operators.mutation.pm import PM +from pymoo.operators.selection.tournament import TournamentSelection + +sampling = FloatRandomSampling() +crossover = SBX(prob=0.9, eta=15) +mutation = PM(eta=20) +selection = TournamentSelection() +``` + +### Binary GA: +```python +from pymoo.operators.sampling.rnd import BinaryRandomSampling +from pymoo.operators.crossover.pntx import TwoPointCrossover +from pymoo.operators.mutation.bitflip import BitflipMutation + +sampling = BinaryRandomSampling() +crossover = TwoPointCrossover() +mutation = BitflipMutation(prob=0.05) +``` + +### Permutation GA (TSP): +```python +from pymoo.operators.sampling.rnd import PermutationRandomSampling +from pymoo.operators.crossover.ox import OrderCrossover +from pymoo.operators.mutation.inversion import InversionMutation + +sampling = PermutationRandomSampling() +crossover = OrderCrossover() +mutation = InversionMutation() +``` diff --git a/scientific-packages/pymoo/references/problems.md b/scientific-packages/pymoo/references/problems.md new file mode 100644 index 0000000..5fc679a --- /dev/null +++ b/scientific-packages/pymoo/references/problems.md @@ -0,0 +1,265 @@ +# Pymoo Test Problems Reference + +Comprehensive reference for benchmark optimization problems in pymoo. + +## Single-Objective Test Problems + +### Ackley Function +**Characteristics:** +- Highly multimodal +- Many local optima +- Tests algorithm's ability to escape local minima +- Continuous variables + +### Griewank Function +**Characteristics:** +- Multimodal with regularly distributed local minima +- Product term introduces interdependencies between variables +- Global minimum at origin + +### Rastrigin Function +**Characteristics:** +- Highly multimodal with regularly spaced local minima +- Challenging for gradient-based methods +- Tests global search capability + +### Rosenbrock Function +**Characteristics:** +- Unimodal but narrow valley to global optimum +- Tests algorithm's convergence in difficult landscape +- Classic benchmark for continuous optimization + +### Zakharov Function +**Characteristics:** +- Unimodal +- Single global minimum +- Tests basic convergence capability + +## Multi-Objective Test Problems (2-3 objectives) + +### ZDT Test Suite +**Purpose:** Standard benchmark for bi-objective optimization +**Construction:** f₂(x) = g(x) · h(f₁(x), g(x)) where g(x) = 1 at Pareto-optimal solutions + +#### ZDT1 +- **Variables:** 30 continuous +- **Bounds:** [0, 1] +- **Pareto front:** Convex +- **Purpose:** Basic convergence and diversity test + +#### ZDT2 +- **Variables:** 30 continuous +- **Bounds:** [0, 1] +- **Pareto front:** Non-convex (concave) +- **Purpose:** Tests handling of non-convex fronts + +#### ZDT3 +- **Variables:** 30 continuous +- **Bounds:** [0, 1] +- **Pareto front:** Disconnected (5 separate regions) +- **Purpose:** Tests diversity maintenance across discontinuous front + +#### ZDT4 +- **Variables:** 10 continuous (x₁ ∈ [0,1], x₂₋₁₀ ∈ [-10,10]) +- **Pareto front:** Convex +- **Difficulty:** 21⁹ local Pareto fronts +- **Purpose:** Tests 
global search with many local optima + +#### ZDT5 +- **Variables:** 11 discrete (bitstring) +- **Encoding:** x₁ uses 30 bits, x₂₋₁₁ use 5 bits each +- **Pareto front:** Convex +- **Purpose:** Tests discrete optimization and deceptive landscapes + +#### ZDT6 +- **Variables:** 10 continuous +- **Bounds:** [0, 1] +- **Pareto front:** Non-convex with non-uniform density +- **Purpose:** Tests handling of biased solution distributions + +**Usage:** +```python +from pymoo.problems.multi import ZDT1, ZDT2, ZDT3, ZDT4, ZDT5, ZDT6 +problem = ZDT1() # or ZDT2(), ZDT3(), etc. +``` + +### BNH (Binh and Korn) +**Characteristics:** +- 2 objectives +- 2 variables +- Constrained problem +- Tests constraint handling in multi-objective context + +### OSY (Osyczka and Kundu) +**Characteristics:** +- 6 objectives +- 6 variables +- Multiple constraints +- Real-world inspired + +### TNK (Tanaka) +**Characteristics:** +- 2 objectives +- 2 variables +- Disconnected feasible region +- Tests handling of disjoint search spaces + +### Truss2D +**Characteristics:** +- Structural engineering problem +- Bi-objective (weight vs displacement) +- Practical application test + +### Welded Beam +**Characteristics:** +- Engineering design problem +- Multiple constraints +- Practical optimization scenario + +### Omni-test +**Characteristics:** +- Configurable test problem +- Various difficulty levels +- Systematic testing + +### SYM-PART +**Characteristics:** +- Symmetric problem structure +- Tests specific algorithmic behaviors + +## Many-Objective Test Problems (4+ objectives) + +### DTLZ Test Suite +**Purpose:** Scalable many-objective benchmarks +**Objectives:** Configurable (typically 3-15) +**Variables:** Scalable + +#### DTLZ1 +- **Pareto front:** Linear (hyperplane) +- **Difficulty:** 11^k local Pareto fronts +- **Purpose:** Tests convergence with many local optima + +#### DTLZ2 +- **Pareto front:** Spherical (concave) +- **Difficulty:** Straightforward convergence +- **Purpose:** Basic many-objective diversity test + +#### DTLZ3 +- **Pareto front:** Spherical +- **Difficulty:** 3^k local Pareto fronts +- **Purpose:** Combines DTLZ1's multimodality with DTLZ2's geometry + +#### DTLZ4 +- **Pareto front:** Spherical with biased density +- **Difficulty:** Non-uniform solution distribution +- **Purpose:** Tests diversity maintenance with bias + +#### DTLZ5 +- **Pareto front:** Degenerate (curve in M-dimensional space) +- **Purpose:** Tests handling of degenerate fronts + +#### DTLZ6 +- **Pareto front:** Degenerate curve +- **Difficulty:** Harder convergence than DTLZ5 +- **Purpose:** Challenging degenerate front + +#### DTLZ7 +- **Pareto front:** Disconnected regions +- **Difficulty:** 2^(M-1) disconnected regions +- **Purpose:** Tests diversity across disconnected fronts + +**Usage:** +```python +from pymoo.problems.many import DTLZ1, DTLZ2 +problem = DTLZ1(n_var=7, n_obj=3) # 7 variables, 3 objectives +``` + +### WFG Test Suite +**Purpose:** Walking Fish Group scalable benchmarks +**Features:** More complex than DTLZ, various front shapes and difficulties + +**Variants:** WFG1-WFG9 with different characteristics +- Non-separable +- Deceptive +- Multimodal +- Biased +- Scaled fronts + +## Constrained Multi-Objective Problems + +### MW Test Suite +**Purpose:** Multi-objective problems with various constraint types +**Features:** Different constraint difficulty levels + +### DAS-CMOP +**Purpose:** Difficulty-adjustable and scalable constrained multi-objective problems +**Features:** Tunable constraint difficulty + +### MODAct 
+**Purpose:** Multi-objective optimization with active constraints +**Features:** Realistic constraint scenarios + +## Dynamic Multi-Objective Problems + +### DF Test Suite +**Purpose:** CEC2018 Competition dynamic multi-objective benchmarks +**Features:** +- Time-varying objectives +- Changing Pareto fronts +- Tests algorithm adaptability + +**Variants:** DF1-DF14 with different dynamics + +## Custom Problem Definition + +Define custom problems by extending base classes: + +```python +from pymoo.core.problem import ElementwiseProblem +import numpy as np + +class MyProblem(ElementwiseProblem): + def __init__(self): + super().__init__( + n_var=2, # number of variables + n_obj=2, # number of objectives + n_ieq_constr=0, # inequality constraints + n_eq_constr=0, # equality constraints + xl=np.array([0, 0]), # lower bounds + xu=np.array([1, 1]) # upper bounds + ) + + def _evaluate(self, x, out, *args, **kwargs): + # Define objectives + f1 = x[0]**2 + x[1]**2 + f2 = (x[0]-1)**2 + x[1]**2 + + out["F"] = [f1, f2] + + # Optional: constraints + # out["G"] = constraint_values # <= 0 + # out["H"] = equality_constraints # == 0 +``` + +## Problem Selection Guidelines + +**For algorithm development:** +- Simple convergence: DTLZ2, ZDT1 +- Multimodal: ZDT4, DTLZ1, DTLZ3 +- Non-convex: ZDT2 +- Disconnected: ZDT3, DTLZ7 + +**For comprehensive testing:** +- ZDT suite for bi-objective +- DTLZ suite for many-objective +- WFG for complex landscapes +- MW/DAS-CMOP for constraints + +**For real-world validation:** +- Engineering problems (Truss2D, Welded Beam) +- Match problem characteristics to application domain + +**Variable types:** +- Continuous: Most problems +- Discrete: ZDT5 +- Mixed: Define custom problem diff --git a/scientific-packages/pymoo/references/visualization.md b/scientific-packages/pymoo/references/visualization.md new file mode 100644 index 0000000..87ad45b --- /dev/null +++ b/scientific-packages/pymoo/references/visualization.md @@ -0,0 +1,353 @@ +# Pymoo Visualization Reference + +Comprehensive reference for visualization capabilities in pymoo. + +## Overview + +Pymoo provides eight visualization types for analyzing multi-objective optimization results. All plots wrap matplotlib and accept standard matplotlib keyword arguments for customization. + +## Core Visualization Types + +### 1. Scatter Plots +**Purpose:** Visualize objective space for 2D, 3D, or higher dimensions +**Best for:** Pareto fronts, solution distributions, algorithm comparisons + +**Usage:** +```python +from pymoo.visualization.scatter import Scatter + +# 2D scatter plot +plot = Scatter() +plot.add(result.F, color="red", label="Algorithm A") +plot.add(ref_pareto_front, color="black", alpha=0.3, label="True PF") +plot.show() + +# 3D scatter plot +plot = Scatter(title="3D Pareto Front") +plot.add(result.F) +plot.show() +``` + +**Parameters:** +- `title`: Plot title +- `figsize`: Figure size tuple (width, height) +- `legend`: Show legend (default: True) +- `labels`: Axis labels list + +**Add method parameters:** +- `color`: Color specification +- `alpha`: Transparency (0-1) +- `s`: Marker size +- `marker`: Marker style +- `label`: Legend label + +**N-dimensional projection:** +For >3 objectives, automatically creates scatter plot matrix + +### 2. 
Parallel Coordinate Plots (PCP) +**Purpose:** Compare multiple solutions across many objectives +**Best for:** Many-objective problems, comparing algorithm performance + +**Mechanism:** Each vertical axis represents one objective, lines connect objective values for each solution + +**Usage:** +```python +from pymoo.visualization.pcp import PCP + +plot = PCP() +plot.add(result.F, color="blue", alpha=0.5) +plot.add(reference_set, color="red", alpha=0.8) +plot.show() +``` + +**Parameters:** +- `title`: Plot title +- `figsize`: Figure size +- `labels`: Objective labels +- `bounds`: Normalization bounds (min, max) per objective +- `normalize_each_axis`: Normalize to [0,1] per axis (default: True) + +**Best practices:** +- Normalize for different objective scales +- Use transparency for overlapping lines +- Limit number of solutions for clarity (<1000) + +### 3. Heatmap +**Purpose:** Show solution density and distribution patterns +**Best for:** Understanding solution clustering, identifying gaps + +**Usage:** +```python +from pymoo.visualization.heatmap import Heatmap + +plot = Heatmap(title="Solution Density") +plot.add(result.F) +plot.show() +``` + +**Parameters:** +- `bins`: Number of bins per dimension (default: 20) +- `cmap`: Colormap name (e.g., "viridis", "plasma", "hot") +- `norm`: Normalization method + +**Interpretation:** +- Bright regions: High solution density +- Dark regions: Few or no solutions +- Reveals distribution uniformity + +### 4. Petal Diagram +**Purpose:** Radial representation of multiple objectives +**Best for:** Comparing individual solutions across objectives + +**Structure:** Each "petal" represents one objective, length indicates objective value + +**Usage:** +```python +from pymoo.visualization.petal import Petal + +plot = Petal(title="Solution Comparison", bounds=[min_vals, max_vals]) +plot.add(result.F[0], color="blue", label="Solution 1") +plot.add(result.F[1], color="red", label="Solution 2") +plot.show() +``` + +**Parameters:** +- `bounds`: [min, max] per objective for normalization +- `labels`: Objective names +- `reverse`: Reverse specific objectives (for minimization display) + +**Use cases:** +- Decision making between few solutions +- Presenting trade-offs to stakeholders + +### 5. Radar Charts +**Purpose:** Multi-criteria performance profiles +**Best for:** Comparing solution characteristics + +**Similar to:** Petal diagram but with connected vertices + +**Usage:** +```python +from pymoo.visualization.radar import Radar + +plot = Radar(bounds=[min_vals, max_vals]) +plot.add(solution_A, label="Design A") +plot.add(solution_B, label="Design B") +plot.show() +``` + +### 6. Radviz +**Purpose:** Dimensional reduction for visualization +**Best for:** High-dimensional data exploration, pattern recognition + +**Mechanism:** Projects high-dimensional points onto 2D circle, dimension anchors on perimeter + +**Usage:** +```python +from pymoo.visualization.radviz import Radviz + +plot = Radviz(title="High-dimensional Solution Space") +plot.add(result.F, color="blue", s=30) +plot.show() +``` + +**Parameters:** +- `endpoint_style`: Anchor point visualization +- `labels`: Dimension labels + +**Interpretation:** +- Points near anchor: High value in that dimension +- Central points: Balanced across dimensions +- Clusters: Similar solutions + +### 7. 
Star Coordinates +**Purpose:** Alternative high-dimensional visualization +**Best for:** Comparing multi-dimensional datasets + +**Mechanism:** Each dimension as axis from origin, points plotted based on values + +**Usage:** +```python +from pymoo.visualization.star_coordinate import StarCoordinate + +plot = StarCoordinate() +plot.add(result.F) +plot.show() +``` + +**Parameters:** +- `axis_style`: Axis appearance +- `axis_extension`: Axis length beyond max value +- `labels`: Dimension labels + +### 8. Video/Animation +**Purpose:** Show optimization progress over time +**Best for:** Understanding convergence behavior, presentations + +**Usage:** +```python +from pymoo.visualization.video import Video + +# Create animation from algorithm history +anim = Video(result.algorithm) +anim.save("optimization_progress.mp4") +``` + +**Requirements:** +- Algorithm must store history (use `save_history=True` in minimize) +- ffmpeg installed for video export + +**Customization:** +- Frame rate +- Plot type per frame +- Overlay information (generation, hypervolume, etc.) + +## Advanced Features + +### Multiple Dataset Overlay + +All plot types support adding multiple datasets: + +```python +plot = Scatter(title="Algorithm Comparison") +plot.add(nsga2_result.F, color="red", alpha=0.5, label="NSGA-II") +plot.add(nsga3_result.F, color="blue", alpha=0.5, label="NSGA-III") +plot.add(true_pareto_front, color="black", linewidth=2, label="True PF") +plot.show() +``` + +### Custom Styling + +Pass matplotlib kwargs directly: + +```python +plot = Scatter( + title="My Results", + figsize=(10, 8), + tight_layout=True +) +plot.add( + result.F, + color="red", + marker="o", + s=50, + alpha=0.7, + edgecolors="black", + linewidth=0.5 +) +``` + +### Normalization + +Normalize objectives to [0,1] for fair comparison: + +```python +plot = PCP(normalize_each_axis=True, bounds=[min_bounds, max_bounds]) +``` + +### Save to File + +Save plots instead of displaying: + +```python +plot = Scatter() +plot.add(result.F) +plot.save("my_plot.png", dpi=300) +``` + +## Visualization Selection Guide + +**Choose visualization based on:** + +| Problem Type | Primary Plot | Secondary Plot | +|--------------|--------------|----------------| +| 2-objective | Scatter | Heatmap | +| 3-objective | 3D Scatter | Parallel Coordinates | +| Many-objective (4-10) | Parallel Coordinates | Radviz | +| Many-objective (>10) | Radviz | Star Coordinates | +| Solution comparison | Petal/Radar | Parallel Coordinates | +| Algorithm convergence | Video | Scatter (final) | +| Distribution analysis | Heatmap | Scatter | + +**Combinations:** +- Scatter + Heatmap: Overall distribution + density +- PCP + Petal: Population overview + individual solutions +- Scatter + Video: Final result + convergence process + +## Common Visualization Workflows + +### 1. Algorithm Comparison +```python +from pymoo.visualization.scatter import Scatter + +plot = Scatter(title="Algorithm Comparison on ZDT1") +plot.add(ga_result.F, color="blue", label="GA", alpha=0.6) +plot.add(nsga2_result.F, color="red", label="NSGA-II", alpha=0.6) +plot.add(zdt1.pareto_front(), color="black", label="True PF") +plot.show() +``` + +### 2. Many-objective Analysis +```python +from pymoo.visualization.pcp import PCP + +plot = PCP( + title="5-objective DTLZ2 Results", + labels=["f1", "f2", "f3", "f4", "f5"], + normalize_each_axis=True +) +plot.add(result.F, alpha=0.3) +plot.show() +``` + +### 3. 
Decision Making +```python +from pymoo.visualization.petal import Petal + +# Compare top 3 solutions +candidates = result.F[:3] + +plot = Petal( + title="Top 3 Solutions", + bounds=[result.F.min(axis=0), result.F.max(axis=0)], + labels=["Cost", "Weight", "Efficiency", "Safety"] +) +for i, sol in enumerate(candidates): + plot.add(sol, label=f"Solution {i+1}") +plot.show() +``` + +### 4. Convergence Visualization +```python +from pymoo.optimize import minimize + +# Enable history +result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + save_history=True, + verbose=False +) + +# Create convergence plot +from pymoo.visualization.scatter import Scatter + +plot = Scatter(title="Convergence Over Generations") +for gen in [0, 50, 100, 150, 200]: + F = result.history[gen].opt.get("F") + plot.add(F, alpha=0.5, label=f"Gen {gen}") +plot.show() +``` + +## Tips and Best Practices + +1. **Use appropriate alpha:** For overlapping points, use `alpha=0.3-0.7` +2. **Normalize objectives:** Different scales? Normalize for fair visualization +3. **Label clearly:** Always provide meaningful labels and legends +4. **Limit data points:** >10000 points? Sample or use heatmap +5. **Color schemes:** Use colorblind-friendly palettes +6. **Save high-res:** Use `dpi=300` for publications +7. **Interactive exploration:** Consider plotly for interactive plots +8. **Combine views:** Show multiple perspectives for comprehensive analysis diff --git a/scientific-packages/pymoo/scripts/custom_problem_example.py b/scientific-packages/pymoo/scripts/custom_problem_example.py new file mode 100644 index 0000000..dce80bf --- /dev/null +++ b/scientific-packages/pymoo/scripts/custom_problem_example.py @@ -0,0 +1,181 @@ +""" +Custom problem definition example using pymoo. + +This script demonstrates how to define a custom optimization problem +and solve it using pymoo. +""" + +from pymoo.core.problem import ElementwiseProblem +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.optimize import minimize +from pymoo.visualization.scatter import Scatter +import numpy as np + + +class MyBiObjectiveProblem(ElementwiseProblem): + """ + Custom bi-objective optimization problem. + + Minimize: + f1(x) = x1^2 + x2^2 + f2(x) = (x1-1)^2 + (x2-1)^2 + + Subject to: + 0 <= x1 <= 5 + 0 <= x2 <= 5 + """ + + def __init__(self): + super().__init__( + n_var=2, # Number of decision variables + n_obj=2, # Number of objectives + n_ieq_constr=0, # Number of inequality constraints + n_eq_constr=0, # Number of equality constraints + xl=np.array([0, 0]), # Lower bounds + xu=np.array([5, 5]) # Upper bounds + ) + + def _evaluate(self, x, out, *args, **kwargs): + """Evaluate objectives for a single solution.""" + # Objective 1: Distance from origin + f1 = x[0]**2 + x[1]**2 + + # Objective 2: Distance from (1, 1) + f2 = (x[0] - 1)**2 + (x[1] - 1)**2 + + # Return objectives + out["F"] = [f1, f2] + + +class ConstrainedProblem(ElementwiseProblem): + """ + Custom constrained bi-objective problem. 
+ + Minimize: + f1(x) = x1 + f2(x) = (1 + x2) / x1 + + Subject to: + x2 + 9*x1 >= 6 (g1 <= 0) + -x2 + 9*x1 >= 1 (g2 <= 0) + 0.1 <= x1 <= 1 + 0 <= x2 <= 5 + """ + + def __init__(self): + super().__init__( + n_var=2, + n_obj=2, + n_ieq_constr=2, # Two inequality constraints + xl=np.array([0.1, 0.0]), + xu=np.array([1.0, 5.0]) + ) + + def _evaluate(self, x, out, *args, **kwargs): + """Evaluate objectives and constraints.""" + # Objectives + f1 = x[0] + f2 = (1 + x[1]) / x[0] + + out["F"] = [f1, f2] + + # Inequality constraints (g <= 0) + # Convert g1: x2 + 9*x1 >= 6 → -(x2 + 9*x1 - 6) <= 0 + g1 = -(x[1] + 9 * x[0] - 6) + + # Convert g2: -x2 + 9*x1 >= 1 → -(-x2 + 9*x1 - 1) <= 0 + g2 = -(-x[1] + 9 * x[0] - 1) + + out["G"] = [g1, g2] + + +def solve_custom_problem(): + """Solve custom bi-objective problem.""" + + print("="*60) + print("CUSTOM PROBLEM - UNCONSTRAINED") + print("="*60) + + # Define custom problem + problem = MyBiObjectiveProblem() + + # Configure algorithm + algorithm = NSGA2(pop_size=100) + + # Solve + result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=False + ) + + print(f"Number of solutions: {len(result.F)}") + print(f"Objective space range:") + print(f" f1: [{result.F[:, 0].min():.3f}, {result.F[:, 0].max():.3f}]") + print(f" f2: [{result.F[:, 1].min():.3f}, {result.F[:, 1].max():.3f}]") + + # Visualize + plot = Scatter(title="Custom Bi-Objective Problem") + plot.add(result.F, color="blue", alpha=0.7) + plot.show() + + return result + + +def solve_constrained_problem(): + """Solve custom constrained problem.""" + + print("\n" + "="*60) + print("CUSTOM PROBLEM - CONSTRAINED") + print("="*60) + + # Define constrained problem + problem = ConstrainedProblem() + + # Configure algorithm + algorithm = NSGA2(pop_size=100) + + # Solve + result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=False + ) + + # Check feasibility + feasible = result.CV[:, 0] == 0 # Constraint violation = 0 + + print(f"Total solutions: {len(result.F)}") + print(f"Feasible solutions: {np.sum(feasible)}") + print(f"Infeasible solutions: {np.sum(~feasible)}") + + if np.any(feasible): + F_feasible = result.F[feasible] + print(f"\nFeasible objective space range:") + print(f" f1: [{F_feasible[:, 0].min():.3f}, {F_feasible[:, 0].max():.3f}]") + print(f" f2: [{F_feasible[:, 1].min():.3f}, {F_feasible[:, 1].max():.3f}]") + + # Visualize feasible solutions + plot = Scatter(title="Constrained Problem - Feasible Solutions") + plot.add(F_feasible, color="green", alpha=0.7, label="Feasible") + + if np.any(~feasible): + plot.add(result.F[~feasible], color="red", alpha=0.3, s=10, label="Infeasible") + + plot.show() + + return result + + +if __name__ == "__main__": + # Run both examples + result1 = solve_custom_problem() + result2 = solve_constrained_problem() + + print("\n" + "="*60) + print("EXAMPLES COMPLETED") + print("="*60) diff --git a/scientific-packages/pymoo/scripts/decision_making_example.py b/scientific-packages/pymoo/scripts/decision_making_example.py new file mode 100644 index 0000000..e906d3a --- /dev/null +++ b/scientific-packages/pymoo/scripts/decision_making_example.py @@ -0,0 +1,161 @@ +""" +Multi-criteria decision making example using pymoo. + +This script demonstrates how to select preferred solutions from +a Pareto front using various MCDM methods. 
+""" + +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.mcdm.pseudo_weights import PseudoWeights +from pymoo.visualization.scatter import Scatter +from pymoo.visualization.petal import Petal +import numpy as np + + +def run_optimization_for_decision_making(): + """Run optimization to obtain Pareto front.""" + + print("Running optimization to obtain Pareto front...") + + # Solve ZDT1 problem + problem = get_problem("zdt1") + algorithm = NSGA2(pop_size=100) + + result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=False + ) + + print(f"Obtained {len(result.F)} solutions in Pareto front\n") + + return problem, result + + +def apply_pseudo_weights(result, weights): + """Apply pseudo-weights MCDM method.""" + + print(f"Applying Pseudo-Weights with weights: {weights}") + + # Normalize objectives to [0, 1] + F_norm = (result.F - result.F.min(axis=0)) / (result.F.max(axis=0) - result.F.min(axis=0)) + + # Apply MCDM + dm = PseudoWeights(weights) + selected_idx = dm.do(F_norm) + + selected_x = result.X[selected_idx] + selected_f = result.F[selected_idx] + + print(f"Selected solution (decision variables): {selected_x}") + print(f"Selected solution (objectives): {selected_f}") + print() + + return selected_idx, selected_x, selected_f + + +def compare_different_preferences(result): + """Compare selections with different preference weights.""" + + print("="*60) + print("COMPARING DIFFERENT PREFERENCE WEIGHTS") + print("="*60 + "\n") + + # Define different preference scenarios + scenarios = [ + ("Equal preference", np.array([0.5, 0.5])), + ("Prefer f1", np.array([0.8, 0.2])), + ("Prefer f2", np.array([0.2, 0.8])), + ] + + selections = {} + + for name, weights in scenarios: + print(f"Scenario: {name}") + idx, x, f = apply_pseudo_weights(result, weights) + selections[name] = (idx, f) + + # Visualize all selections + plot = Scatter(title="Decision Making - Different Preferences") + plot.add(result.F, color="lightgray", alpha=0.5, s=20, label="Pareto Front") + + colors = ["red", "blue", "green"] + for (name, (idx, f)), color in zip(selections.items(), colors): + plot.add(f, color=color, s=100, marker="*", label=name) + + plot.show() + + return selections + + +def visualize_selected_solutions(result, selections): + """Visualize selected solutions using petal diagram.""" + + # Get objective bounds for normalization + f_min = result.F.min(axis=0) + f_max = result.F.max(axis=0) + + plot = Petal( + title="Selected Solutions Comparison", + bounds=[f_min, f_max], + labels=["f1", "f2"] + ) + + colors = ["red", "blue", "green"] + for (name, (idx, f)), color in zip(selections.items(), colors): + plot.add(f, color=color, label=name) + + plot.show() + + +def find_extreme_solutions(result): + """Find extreme solutions (best in each objective).""" + + print("\n" + "="*60) + print("EXTREME SOLUTIONS") + print("="*60 + "\n") + + # Best f1 (minimize f1) + best_f1_idx = np.argmin(result.F[:, 0]) + print(f"Best f1 solution: {result.F[best_f1_idx]}") + print(f" Decision variables: {result.X[best_f1_idx]}\n") + + # Best f2 (minimize f2) + best_f2_idx = np.argmin(result.F[:, 1]) + print(f"Best f2 solution: {result.F[best_f2_idx]}") + print(f" Decision variables: {result.X[best_f2_idx]}\n") + + return best_f1_idx, best_f2_idx + + +def main(): + """Main execution function.""" + + # Step 1: Run optimization + problem, result = run_optimization_for_decision_making() + + # Step 2: Find extreme solutions + best_f1_idx, 
best_f2_idx = find_extreme_solutions(result) + + # Step 3: Compare different preference weights + selections = compare_different_preferences(result) + + # Step 4: Visualize selections with petal diagram + visualize_selected_solutions(result, selections) + + print("="*60) + print("DECISION MAKING EXAMPLE COMPLETED") + print("="*60) + print("\nKey Takeaways:") + print("1. Different weights lead to different selected solutions") + print("2. Higher weight on an objective selects solutions better in that objective") + print("3. Visualization helps understand trade-offs") + print("4. MCDM methods help formalize decision maker preferences") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pymoo/scripts/many_objective_example.py b/scientific-packages/pymoo/scripts/many_objective_example.py new file mode 100644 index 0000000..2f3cb24 --- /dev/null +++ b/scientific-packages/pymoo/scripts/many_objective_example.py @@ -0,0 +1,72 @@ +""" +Many-objective optimization example using pymoo. + +This script demonstrates many-objective optimization (4+ objectives) +using NSGA-III on the DTLZ2 benchmark problem. +""" + +from pymoo.algorithms.moo.nsga3 import NSGA3 +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.util.ref_dirs import get_reference_directions +from pymoo.visualization.pcp import PCP +import numpy as np + + +def run_many_objective_optimization(): + """Run many-objective optimization example.""" + + # Define the problem - DTLZ2 with 5 objectives + n_obj = 5 + problem = get_problem("dtlz2", n_obj=n_obj) + + # Generate reference directions for NSGA-III + # Das-Dennis method for uniform distribution + ref_dirs = get_reference_directions("das-dennis", n_obj, n_partitions=12) + + print(f"Number of reference directions: {len(ref_dirs)}") + + # Configure NSGA-III algorithm + algorithm = NSGA3( + ref_dirs=ref_dirs, + eliminate_duplicates=True + ) + + # Run optimization + result = minimize( + problem, + algorithm, + ('n_gen', 300), + seed=1, + verbose=True + ) + + # Print results summary + print("\n" + "="*60) + print("MANY-OBJECTIVE OPTIMIZATION RESULTS") + print("="*60) + print(f"Number of objectives: {n_obj}") + print(f"Number of solutions: {len(result.F)}") + print(f"Number of generations: {result.algorithm.n_gen}") + print(f"Number of function evaluations: {result.algorithm.evaluator.n_eval}") + + # Show objective space statistics + print("\nObjective space statistics:") + print(f"Minimum values per objective: {result.F.min(axis=0)}") + print(f"Maximum values per objective: {result.F.max(axis=0)}") + print("="*60) + + # Visualize using Parallel Coordinate Plot + plot = PCP( + title=f"DTLZ2 ({n_obj} objectives) - NSGA-III Results", + labels=[f"f{i+1}" for i in range(n_obj)], + normalize_each_axis=True + ) + plot.add(result.F, alpha=0.3, color="blue") + plot.show() + + return result + + +if __name__ == "__main__": + result = run_many_objective_optimization() diff --git a/scientific-packages/pymoo/scripts/multi_objective_example.py b/scientific-packages/pymoo/scripts/multi_objective_example.py new file mode 100644 index 0000000..0f5dfd6 --- /dev/null +++ b/scientific-packages/pymoo/scripts/multi_objective_example.py @@ -0,0 +1,63 @@ +""" +Multi-objective optimization example using pymoo. + +This script demonstrates multi-objective optimization using +NSGA-II on the ZDT1 benchmark problem. 
+""" + +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.visualization.scatter import Scatter +import matplotlib.pyplot as plt + + +def run_multi_objective_optimization(): + """Run multi-objective optimization example.""" + + # Define the problem - ZDT1 (bi-objective) + problem = get_problem("zdt1") + + # Configure NSGA-II algorithm + algorithm = NSGA2( + pop_size=100, + eliminate_duplicates=True + ) + + # Run optimization + result = minimize( + problem, + algorithm, + ('n_gen', 200), + seed=1, + verbose=True + ) + + # Print results summary + print("\n" + "="*60) + print("MULTI-OBJECTIVE OPTIMIZATION RESULTS") + print("="*60) + print(f"Number of solutions in Pareto front: {len(result.F)}") + print(f"Number of generations: {result.algorithm.n_gen}") + print(f"Number of function evaluations: {result.algorithm.evaluator.n_eval}") + print("\nFirst 5 solutions (decision variables):") + print(result.X[:5]) + print("\nFirst 5 solutions (objective values):") + print(result.F[:5]) + print("="*60) + + # Visualize results + plot = Scatter(title="ZDT1 - NSGA-II Results") + plot.add(result.F, color="red", alpha=0.7, s=30, label="Obtained Pareto Front") + + # Add true Pareto front for comparison + pf = problem.pareto_front() + plot.add(pf, color="black", alpha=0.3, label="True Pareto Front") + + plot.show() + + return result + + +if __name__ == "__main__": + result = run_multi_objective_optimization() diff --git a/scientific-packages/pymoo/scripts/single_objective_example.py b/scientific-packages/pymoo/scripts/single_objective_example.py new file mode 100644 index 0000000..dd9bea1 --- /dev/null +++ b/scientific-packages/pymoo/scripts/single_objective_example.py @@ -0,0 +1,59 @@ +""" +Single-objective optimization example using pymoo. + +This script demonstrates basic single-objective optimization +using the Genetic Algorithm on the Sphere function. 
+""" + +from pymoo.algorithms.soo.nonconvex.ga import GA +from pymoo.problems import get_problem +from pymoo.optimize import minimize +from pymoo.operators.crossover.sbx import SBX +from pymoo.operators.mutation.pm import PM +from pymoo.operators.sampling.rnd import FloatRandomSampling +from pymoo.termination import get_termination +import numpy as np + + +def run_single_objective_optimization(): + """Run single-objective optimization example.""" + + # Define the problem - Sphere function (sum of squares) + problem = get_problem("sphere", n_var=10) + + # Configure the algorithm + algorithm = GA( + pop_size=100, + sampling=FloatRandomSampling(), + crossover=SBX(prob=0.9, eta=15), + mutation=PM(eta=20), + eliminate_duplicates=True + ) + + # Define termination criteria + termination = get_termination("n_gen", 100) + + # Run optimization + result = minimize( + problem, + algorithm, + termination, + seed=1, + verbose=True + ) + + # Print results + print("\n" + "="*60) + print("OPTIMIZATION RESULTS") + print("="*60) + print(f"Best solution: {result.X}") + print(f"Best objective value: {result.F[0]:.6f}") + print(f"Number of generations: {result.algorithm.n_gen}") + print(f"Number of function evaluations: {result.algorithm.evaluator.n_eval}") + print("="*60) + + return result + + +if __name__ == "__main__": + result = run_single_objective_optimization() diff --git a/scientific-packages/pytdc/SKILL.md b/scientific-packages/pytdc/SKILL.md new file mode 100644 index 0000000..fd9803e --- /dev/null +++ b/scientific-packages/pytdc/SKILL.md @@ -0,0 +1,445 @@ +--- +name: pytdc +description: Comprehensive toolkit for therapeutic science and drug discovery using PyTDC (Therapeutics Data Commons). Use this skill when working with drug discovery datasets, ADME/toxicity prediction, drug-target interactions, molecular generation, retrosynthesis, or benchmark evaluations. Applies to tasks involving therapeutic machine learning, pharmacological property prediction, or accessing curated drug discovery datasets. +--- + +# PyTDC (Therapeutics Data Commons) + +## Overview + +PyTDC is an open-science platform providing AI-ready datasets and benchmarks for drug discovery and development. It offers curated datasets spanning the entire therapeutics pipeline, from target discovery through clinical development, with standardized evaluation metrics and meaningful data splits. + +The platform organizes therapeutic tasks into three major categories: single-instance prediction for properties of individual biomedical entities, multi-instance prediction for relationships between multiple entities, and generation for creating new therapeutic molecules. + +## Installation & Setup + +Install PyTDC using pip: + +```bash +pip install PyTDC +``` + +To upgrade to the latest version: + +```bash +pip install PyTDC --upgrade +``` + +Core dependencies (automatically installed): +- numpy, pandas, tqdm, seaborn, scikit_learn, fuzzywuzzy + +Additional packages are installed automatically as needed for specific features. + +## Quick Start + +The basic pattern for accessing any TDC dataset follows this structure: + +```python +from tdc. 
<problem_type> import <task_name>
+data = <task_name>(name='<dataset_name>')
+split = data.get_split(method='scaffold', seed=1, frac=[0.7, 0.1, 0.2])
+df = data.get_data(format='df')
+```
+
+Where:
+- `<problem_type>`: One of `single_pred`, `multi_pred`, or `generation`
+- `<task_name>`: Specific task category (e.g., ADME, DTI, MolGen)
+- `<dataset_name>`: Dataset name within that task
+
+**Example - Loading ADME data:**
+
+```python
+from tdc.single_pred import ADME
+data = ADME(name='Caco2_Wang')
+split = data.get_split(method='scaffold')
+# Returns dict with 'train', 'valid', 'test' DataFrames
+```
+
+## Single-Instance Prediction Tasks
+
+Single-instance prediction involves forecasting properties of individual biomedical entities (molecules, proteins, etc.).
+
+### Available Task Categories
+
+#### 1. ADME (Absorption, Distribution, Metabolism, Excretion)
+
+Predict pharmacokinetic properties of drug molecules.
+
+```python
+from tdc.single_pred import ADME
+data = ADME(name='Caco2_Wang')  # Intestinal permeability
+# Other datasets: HIA_Hou, Bioavailability_Ma, Lipophilicity_AstraZeneca, etc.
+```
+
+**Common ADME datasets:**
+- Caco2 - Intestinal permeability
+- HIA - Human intestinal absorption
+- Bioavailability - Oral bioavailability
+- Lipophilicity - Octanol-water partition coefficient
+- Solubility - Aqueous solubility
+- BBB - Blood-brain barrier penetration
+- CYP - Cytochrome P450 metabolism
+
+#### 2. Toxicity (Tox)
+
+Predict toxicity and adverse effects of compounds.
+
+```python
+from tdc.single_pred import Tox
+data = Tox(name='hERG')  # Cardiotoxicity
+# Other datasets: AMES, DILI, Carcinogens_Lagunin, etc.
+```
+
+**Common toxicity datasets:**
+- hERG - Cardiac toxicity
+- AMES - Mutagenicity
+- DILI - Drug-induced liver injury
+- Carcinogens - Carcinogenicity
+- ClinTox - Clinical trial toxicity
+
+#### 3. HTS (High-Throughput Screening)
+
+Bioactivity predictions from screening data.
+
+```python
+from tdc.single_pred import HTS
+data = HTS(name='SARSCoV2_Vitro_Touret')
+```
+
+#### 4. QM (Quantum Mechanics)
+
+Quantum mechanical properties of molecules.
+
+```python
+from tdc.single_pred import QM
+data = QM(name='QM7')
+```
+
+#### 5. Other Single Prediction Tasks
+
+- **Yields**: Chemical reaction yield prediction
+- **Epitope**: Epitope prediction for biologics
+- **Develop**: Development-stage predictions
+- **CRISPROutcome**: Gene editing outcome prediction
+
+### Data Format
+
+Single prediction datasets typically return DataFrames with columns:
+- `Drug_ID` or `Compound_ID`: Unique identifier
+- `Drug` or `X`: SMILES string or molecular representation
+- `Y`: Target label (continuous or binary)
+
+## Multi-Instance Prediction Tasks
+
+Multi-instance prediction involves forecasting properties of interactions between multiple biomedical entities.
+
+### Available Task Categories
+
+#### 1. DTI (Drug-Target Interaction)
+
+Predict binding affinity between drugs and protein targets.
+
+```python
+from tdc.multi_pred import DTI
+data = DTI(name='BindingDB_Kd')
+split = data.get_split()
+```
+
+**Available datasets:**
+- BindingDB_Kd - Dissociation constant (52,284 pairs)
+- BindingDB_IC50 - Half-maximal inhibitory concentration (991,486 pairs)
+- BindingDB_Ki - Inhibition constant (375,032 pairs)
+- DAVIS, KIBA - Kinase binding datasets
+
+**Data format:** Drug_ID, Target_ID, Drug (SMILES), Target (sequence), Y (binding affinity)
+
+#### 2. DDI (Drug-Drug Interaction)
+
+Predict interactions between drug pairs. 
+ +```python +from tdc.multi_pred import DDI +data = DDI(name='DrugBank') +split = data.get_split() +``` + +Multi-class classification task predicting interaction types. Dataset contains 191,808 DDI pairs with 1,706 drugs. + +#### 3. PPI (Protein-Protein Interaction) + +Predict protein-protein interactions. + +```python +from tdc.multi_pred import PPI +data = PPI(name='HuRI') +``` + +#### 4. Other Multi-Prediction Tasks + +- **GDA**: Gene-disease associations +- **DrugRes**: Drug resistance prediction +- **DrugSyn**: Drug synergy prediction +- **PeptideMHC**: Peptide-MHC binding +- **AntibodyAff**: Antibody affinity prediction +- **MTI**: miRNA-target interactions +- **Catalyst**: Catalyst prediction +- **TrialOutcome**: Clinical trial outcome prediction + +## Generation Tasks + +Generation tasks involve creating novel biomedical entities with desired properties. + +### 1. Molecular Generation (MolGen) + +Generate diverse, novel molecules with desirable chemical properties. + +```python +from tdc.generation import MolGen +data = MolGen(name='ChEMBL_V29') +split = data.get_split() +``` + +Use with oracles to optimize for specific properties: + +```python +from tdc import Oracle +oracle = Oracle(name='GSK3B') +score = oracle('CC(C)Cc1ccc(cc1)C(C)C(O)=O') # Evaluate SMILES +``` + +See `references/oracles.md` for all available oracle functions. + +### 2. Retrosynthesis (RetroSyn) + +Predict reactants needed to synthesize a target molecule. + +```python +from tdc.generation import RetroSyn +data = RetroSyn(name='USPTO') +split = data.get_split() +``` + +Dataset contains 1,939,253 reactions from USPTO database. + +### 3. Paired Molecule Generation + +Generate molecule pairs (e.g., prodrug-drug pairs). + +```python +from tdc.generation import PairMolGen +data = PairMolGen(name='Prodrug') +``` + +For detailed oracle documentation and molecular generation workflows, refer to `references/oracles.md` and `scripts/molecular_generation.py`. + +## Benchmark Groups + +Benchmark groups provide curated collections of related datasets for systematic model evaluation. + +### ADMET Benchmark Group + +```python +from tdc.benchmark_group import admet_group +group = admet_group(path='data/') + +# Get benchmark datasets +benchmark = group.get('Caco2_Wang') +predictions = {} + +for seed in [1, 2, 3, 4, 5]: + train, valid = benchmark['train'], benchmark['valid'] + # Train model here + predictions[seed] = model.predict(benchmark['test']) + +# Evaluate with required 5 seeds +results = group.evaluate(predictions) +``` + +**ADMET Group includes 22 datasets** covering absorption, distribution, metabolism, excretion, and toxicity. + +### Other Benchmark Groups + +Available benchmark groups include collections for: +- ADMET properties +- Drug-target interactions +- Drug combination prediction +- And more specialized therapeutic tasks + +For benchmark evaluation workflows, see `scripts/benchmark_evaluation.py`. + +## Data Functions + +TDC provides comprehensive data processing utilities organized into four categories. + +### 1. 
Dataset Splits + +Retrieve train/validation/test partitions with various strategies: + +```python +# Scaffold split (default for most tasks) +split = data.get_split(method='scaffold', seed=1, frac=[0.7, 0.1, 0.2]) + +# Random split +split = data.get_split(method='random', seed=42, frac=[0.8, 0.1, 0.1]) + +# Cold split (for DTI/DDI tasks) +split = data.get_split(method='cold_drug', seed=1) # Unseen drugs in test +split = data.get_split(method='cold_target', seed=1) # Unseen targets in test +``` + +**Available split strategies:** +- `random`: Random shuffling +- `scaffold`: Scaffold-based (for chemical diversity) +- `cold_drug`, `cold_target`, `cold_drug_target`: For DTI tasks +- `temporal`: Time-based splits for temporal datasets + +### 2. Model Evaluation + +Use standardized metrics for evaluation: + +```python +from tdc import Evaluator + +# For binary classification +evaluator = Evaluator(name='ROC-AUC') +score = evaluator(y_true, y_pred) + +# For regression +evaluator = Evaluator(name='RMSE') +score = evaluator(y_true, y_pred) +``` + +**Available metrics:** ROC-AUC, PR-AUC, F1, Accuracy, RMSE, MAE, R2, Spearman, Pearson, and more. + +### 3. Data Processing + +TDC provides 11 key processing utilities: + +```python +from tdc.chem_utils import MolConvert + +# Molecule format conversion +converter = MolConvert(src='SMILES', dst='PyG') +pyg_graph = converter('CC(C)Cc1ccc(cc1)C(C)C(O)=O') +``` + +**Processing utilities include:** +- Molecule format conversion (SMILES, SELFIES, PyG, DGL, ECFP, etc.) +- Molecule filters (PAINS, drug-likeness) +- Label binarization and unit conversion +- Data balancing (over/under-sampling) +- Negative sampling for pair data +- Graph transformation +- Entity retrieval (CID to SMILES, UniProt to sequence) + +For comprehensive utilities documentation, see `references/utilities.md`. + +### 4. Molecule Generation Oracles + +TDC provides 17+ oracle functions for molecular optimization: + +```python +from tdc import Oracle + +# Single oracle +oracle = Oracle(name='DRD2') +score = oracle('CC(C)Cc1ccc(cc1)C(C)C(O)=O') + +# Multiple oracles +oracle = Oracle(name='JNK3') +scores = oracle(['SMILES1', 'SMILES2', 'SMILES3']) +``` + +For complete oracle documentation, see `references/oracles.md`. 
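+
+The split and evaluation utilities above compose directly. As a minimal sketch (using the same `Caco2_Wang` dataset and `MAE` metric as the workflow templates below), score a trivial mean-value baseline on a scaffold split; any real model simply replaces the baseline prediction:
+
+```python
+import numpy as np
+from tdc.single_pred import ADME
+from tdc import Evaluator
+
+# Scaffold split of an ADME regression dataset
+data = ADME(name='Caco2_Wang')
+split = data.get_split(method='scaffold', seed=42)
+train, test = split['train'], split['test']
+
+# Trivial baseline: predict the training-set mean for every test molecule
+y_pred = np.full(len(test), train['Y'].mean())
+
+evaluator = Evaluator(name='MAE')
+print('Baseline MAE:', evaluator(test['Y'], y_pred))
+```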
+ +## Advanced Features + +### Retrieve Available Datasets + +```python +from tdc.utils import retrieve_dataset_names + +# Get all ADME datasets +adme_datasets = retrieve_dataset_names('ADME') + +# Get all DTI datasets +dti_datasets = retrieve_dataset_names('DTI') +``` + +### Label Transformations + +```python +# Get label mapping +label_map = data.get_label_map(name='DrugBank') + +# Convert labels +from tdc.chem_utils import label_transform +transformed = label_transform(y, from_unit='nM', to_unit='p') +``` + +### Database Queries + +```python +from tdc.utils import cid2smiles, uniprot2seq + +# Convert PubChem CID to SMILES +smiles = cid2smiles(2244) + +# Convert UniProt ID to amino acid sequence +sequence = uniprot2seq('P12345') +``` + +## Common Workflows + +### Workflow 1: Train a Single Prediction Model + +See `scripts/load_and_split_data.py` for a complete example: + +```python +from tdc.single_pred import ADME +from tdc import Evaluator + +# Load data +data = ADME(name='Caco2_Wang') +split = data.get_split(method='scaffold', seed=42) + +train, valid, test = split['train'], split['valid'], split['test'] + +# Train model (user implements) +# model.fit(train['Drug'], train['Y']) + +# Evaluate +evaluator = Evaluator(name='MAE') +# score = evaluator(test['Y'], predictions) +``` + +### Workflow 2: Benchmark Evaluation + +See `scripts/benchmark_evaluation.py` for a complete example with multiple seeds and proper evaluation protocol. + +### Workflow 3: Molecular Generation with Oracles + +See `scripts/molecular_generation.py` for an example of goal-directed generation using oracle functions. + +## Resources + +This skill includes bundled resources for common TDC workflows: + +### scripts/ + +- `load_and_split_data.py`: Template for loading and splitting TDC datasets with various strategies +- `benchmark_evaluation.py`: Template for running benchmark group evaluations with proper 5-seed protocol +- `molecular_generation.py`: Template for molecular generation using oracle functions + +### references/ + +- `datasets.md`: Comprehensive catalog of all available datasets organized by task type +- `oracles.md`: Complete documentation of all 17+ molecule generation oracles +- `utilities.md`: Detailed guide to data processing, splitting, and evaluation utilities + +## Additional Resources + +- **Official Website**: https://tdcommons.ai +- **Documentation**: https://tdc.readthedocs.io +- **GitHub**: https://github.com/mims-harvard/TDC +- **Paper**: NeurIPS 2021 - "Therapeutics Data Commons: Machine Learning Datasets and Tasks for Drug Discovery and Development" diff --git a/scientific-packages/pytdc/references/datasets.md b/scientific-packages/pytdc/references/datasets.md new file mode 100644 index 0000000..e936f0a --- /dev/null +++ b/scientific-packages/pytdc/references/datasets.md @@ -0,0 +1,246 @@ +# TDC Datasets Comprehensive Catalog + +This document provides a comprehensive catalog of all available datasets in the Therapeutics Data Commons, organized by task category. 
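+
+Every dataset below is loaded with the same pattern documented in the main SKILL.md; for example, as a sketch, list the names available for a task group and load one of them:
+
+```python
+from tdc.single_pred import ADME
+from tdc.utils import retrieve_dataset_names
+
+# Discover the dataset names available for a task group
+print(retrieve_dataset_names('ADME'))
+
+# Load one of the datasets listed in this catalog
+data = ADME(name='Caco2_Wang')
+split = data.get_split(method='scaffold', seed=1)
+print(split['train'].head())
+```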
+ +## Single-Instance Prediction Datasets + +### ADME (Absorption, Distribution, Metabolism, Excretion) + +**Absorption:** +- `Caco2_Wang` - Caco-2 cell permeability (906 compounds) +- `Caco2_AstraZeneca` - Caco-2 permeability from AstraZeneca (700 compounds) +- `HIA_Hou` - Human intestinal absorption (578 compounds) +- `Pgp_Broccatelli` - P-glycoprotein inhibition (1,212 compounds) +- `Bioavailability_Ma` - Oral bioavailability (640 compounds) +- `F20_edrug3d` - Oral bioavailability F>=20% (1,017 compounds) +- `F30_edrug3d` - Oral bioavailability F>=30% (1,017 compounds) + +**Distribution:** +- `BBB_Martins` - Blood-brain barrier penetration (1,975 compounds) +- `PPBR_AZ` - Plasma protein binding rate (1,797 compounds) +- `VDss_Lombardo` - Volume of distribution at steady state (1,130 compounds) + +**Metabolism:** +- `CYP2C19_Veith` - CYP2C19 inhibition (12,665 compounds) +- `CYP2D6_Veith` - CYP2D6 inhibition (13,130 compounds) +- `CYP3A4_Veith` - CYP3A4 inhibition (12,328 compounds) +- `CYP1A2_Veith` - CYP1A2 inhibition (12,579 compounds) +- `CYP2C9_Veith` - CYP2C9 inhibition (12,092 compounds) +- `CYP2C9_Substrate_CarbonMangels` - CYP2C9 substrate (666 compounds) +- `CYP2D6_Substrate_CarbonMangels` - CYP2D6 substrate (664 compounds) +- `CYP3A4_Substrate_CarbonMangels` - CYP3A4 substrate (667 compounds) + +**Excretion:** +- `Half_Life_Obach` - Half-life (667 compounds) +- `Clearance_Hepatocyte_AZ` - Hepatocyte clearance (1,020 compounds) +- `Clearance_Microsome_AZ` - Microsome clearance (1,102 compounds) + +**Solubility & Lipophilicity:** +- `Solubility_AqSolDB` - Aqueous solubility (9,982 compounds) +- `Lipophilicity_AstraZeneca` - Lipophilicity (logD) (4,200 compounds) +- `HydrationFreeEnergy_FreeSolv` - Hydration free energy (642 compounds) + +### Toxicity + +**Organ Toxicity:** +- `hERG` - hERG channel inhibition/cardiotoxicity (648 compounds) +- `hERG_Karim` - hERG blockers extended dataset (13,445 compounds) +- `DILI` - Drug-induced liver injury (475 compounds) +- `Skin_Reaction` - Skin reaction (404 compounds) +- `Carcinogens_Lagunin` - Carcinogenicity (278 compounds) +- `Respiratory_Toxicity` - Respiratory toxicity (278 compounds) + +**General Toxicity:** +- `AMES` - Ames mutagenicity (7,255 compounds) +- `LD50_Zhu` - Acute toxicity LD50 (7,385 compounds) +- `ClinTox` - Clinical trial toxicity (1,478 compounds) +- `SkinSensitization` - Skin sensitization (278 compounds) +- `EyeCorrosion` - Eye corrosion (278 compounds) +- `EyeIrritation` - Eye irritation (278 compounds) + +**Environmental Toxicity:** +- `Tox21-AhR` - Nuclear receptor signaling (8,169 compounds) +- `Tox21-AR` - Androgen receptor (9,362 compounds) +- `Tox21-AR-LBD` - Androgen receptor ligand binding (8,343 compounds) +- `Tox21-ARE` - Antioxidant response element (6,475 compounds) +- `Tox21-aromatase` - Aromatase inhibition (6,733 compounds) +- `Tox21-ATAD5` - DNA damage (8,163 compounds) +- `Tox21-ER` - Estrogen receptor (7,257 compounds) +- `Tox21-ER-LBD` - Estrogen receptor ligand binding (8,163 compounds) +- `Tox21-HSE` - Heat shock response (8,162 compounds) +- `Tox21-MMP` - Mitochondrial membrane potential (7,394 compounds) +- `Tox21-p53` - p53 pathway (8,163 compounds) +- `Tox21-PPAR-gamma` - PPAR gamma activation (7,396 compounds) + +### HTS (High-Throughput Screening) + +**SARS-CoV-2:** +- `SARSCoV2_Vitro_Touret` - In vitro antiviral activity (1,484 compounds) +- `SARSCoV2_3CLPro_Diamond` - 3CL protease inhibition (879 compounds) +- `SARSCoV2_Vitro_AlabdulKareem` - In vitro screening (5,953 compounds) + 
+**Other Targets:** +- `Orexin1_Receptor_Butkiewicz` - Orexin receptor screening (4,675 compounds) +- `M1_Receptor_Agonist_Butkiewicz` - M1 receptor agonist (1,700 compounds) +- `M1_Receptor_Antagonist_Butkiewicz` - M1 receptor antagonist (1,700 compounds) +- `HIV_Butkiewicz` - HIV inhibition (40,000+ compounds) +- `ToxCast` - Environmental chemical screening (8,597 compounds) + +### QM (Quantum Mechanics) + +- `QM7` - Quantum mechanics properties (7,160 molecules) +- `QM8` - Electronic spectra and excited states (21,786 molecules) +- `QM9` - Geometric, energetic, electronic, thermodynamic properties (133,885 molecules) + +### Yields + +- `Buchwald-Hartwig` - Reaction yield prediction (3,955 reactions) +- `USPTO_Yields` - Yield prediction from USPTO (853,879 reactions) + +### Epitope + +- `IEDBpep-DiseaseBinder` - Disease-associated epitope binding (6,080 peptides) +- `IEDBpep-NonBinder` - Non-binding peptides (24,320 peptides) + +### Develop (Development) + +- `Manufacturing` - Manufacturing success prediction +- `Formulation` - Formulation stability + +### CRISPROutcome + +- `CRISPROutcome_Doench` - Gene editing efficiency prediction (5,310 guide RNAs) + +## Multi-Instance Prediction Datasets + +### DTI (Drug-Target Interaction) + +**Binding Affinity:** +- `BindingDB_Kd` - Dissociation constant (52,284 pairs, 10,665 drugs, 1,413 proteins) +- `BindingDB_IC50` - Half-maximal inhibitory concentration (991,486 pairs, 549,205 drugs, 5,078 proteins) +- `BindingDB_Ki` - Inhibition constant (375,032 pairs, 174,662 drugs, 3,070 proteins) + +**Kinase Binding:** +- `DAVIS` - Davis kinase binding dataset (30,056 pairs, 68 drugs, 442 proteins) +- `KIBA` - KIBA kinase binding dataset (118,254 pairs, 2,111 drugs, 229 proteins) + +**Binary Interaction:** +- `BindingDB_Patent` - Patent-derived DTI (8,503 pairs) +- `BindingDB_Approval` - FDA-approved drug DTI (1,649 pairs) + +### DDI (Drug-Drug Interaction) + +- `DrugBank` - Drug-drug interactions (191,808 pairs, 1,706 drugs) +- `TWOSIDES` - Side effect-based DDI (4,649,441 pairs, 645 drugs) + +### PPI (Protein-Protein Interaction) + +- `HuRI` - Human reference protein interactome (52,569 interactions) +- `STRING` - Protein functional associations (19,247 interactions) + +### GDA (Gene-Disease Association) + +- `DisGeNET` - Gene-disease associations (81,746 pairs) +- `PrimeKG_GDA` - Gene-disease from PrimeKG knowledge graph + +### DrugRes (Drug Response/Resistance) + +- `GDSC1` - Genomics of Drug Sensitivity in Cancer v1 (178,000 pairs) +- `GDSC2` - Genomics of Drug Sensitivity in Cancer v2 (125,000 pairs) + +### DrugSyn (Drug Synergy) + +- `DrugComb` - Drug combination synergy (345,502 combinations) +- `DrugCombDB` - Drug combination database (448,555 combinations) +- `OncoPolyPharmacology` - Oncology drug combinations (22,737 combinations) + +### PeptideMHC + +- `MHC1_NetMHCpan` - MHC class I binding (184,983 pairs) +- `MHC2_NetMHCIIpan` - MHC class II binding (134,281 pairs) + +### AntibodyAff (Antibody Affinity) + +- `Protein_SAbDab` - Antibody-antigen affinity (1,500+ pairs) + +### MTI (miRNA-Target Interaction) + +- `miRTarBase` - Experimentally validated miRNA-target interactions (380,639 pairs) + +### Catalyst + +- `USPTO_Catalyst` - Catalyst prediction for reactions (11,000+ reactions) + +### TrialOutcome + +- `TrialOutcome_WuXi` - Clinical trial outcome prediction (3,769 trials) + +## Generation Datasets + +### MolGen (Molecular Generation) + +- `ChEMBL_V29` - Drug-like molecules from ChEMBL (1,941,410 molecules) +- `ZINC` - ZINC database subset 
(100,000+ molecules) +- `GuacaMol` - Goal-directed benchmark molecules +- `Moses` - Molecular sets benchmark (1,936,962 molecules) + +### RetroSyn (Retrosynthesis) + +- `USPTO` - Retrosynthesis from USPTO patents (1,939,253 reactions) +- `USPTO-50K` - Curated USPTO subset (50,000 reactions) + +### PairMolGen (Paired Molecule Generation) + +- `Prodrug` - Prodrug to drug transformations (1,000+ pairs) +- `Metabolite` - Drug to metabolite transformations + +## Using retrieve_dataset_names + +To programmatically access all available datasets for a specific task: + +```python +from tdc.utils import retrieve_dataset_names + +# Get all datasets for a specific task +adme_datasets = retrieve_dataset_names('ADME') +tox_datasets = retrieve_dataset_names('Tox') +dti_datasets = retrieve_dataset_names('DTI') +hts_datasets = retrieve_dataset_names('HTS') +``` + +## Dataset Statistics + +Access dataset statistics directly: + +```python +from tdc.single_pred import ADME +data = ADME(name='Caco2_Wang') + +# Print basic statistics +data.print_stats() + +# Get label distribution +data.label_distribution() +``` + +## Loading Datasets + +All datasets follow the same loading pattern: + +```python +from tdc. import +data = (name='') + +# Get full dataset +df = data.get_data(format='df') # or 'dict', 'DeepPurpose', etc. + +# Get train/valid/test split +split = data.get_split(method='scaffold', seed=1, frac=[0.7, 0.1, 0.2]) +``` + +## Notes + +- Dataset sizes and statistics are approximate and may be updated +- New datasets are regularly added to TDC +- Some datasets may require additional dependencies +- Check the official TDC website for the most up-to-date dataset list: https://tdcommons.ai/overview/ diff --git a/scientific-packages/pytdc/references/oracles.md b/scientific-packages/pytdc/references/oracles.md new file mode 100644 index 0000000..e12f157 --- /dev/null +++ b/scientific-packages/pytdc/references/oracles.md @@ -0,0 +1,400 @@ +# TDC Molecule Generation Oracles + +Oracles are functions that evaluate the quality of generated molecules across specific dimensions. TDC provides 17+ oracle functions for molecular optimization tasks in de novo drug design. + +## Overview + +Oracles measure molecular properties and serve two main purposes: + +1. **Goal-Directed Generation**: Optimize molecules to maximize/minimize specific properties +2. **Distribution Learning**: Evaluate whether generated molecules match desired property distributions + +## Using Oracles + +### Basic Usage + +```python +from tdc import Oracle + +# Initialize oracle +oracle = Oracle(name='GSK3B') + +# Evaluate single molecule (SMILES string) +score = oracle('CC(C)Cc1ccc(cc1)C(C)C(O)=O') + +# Evaluate multiple molecules +scores = oracle(['SMILES1', 'SMILES2', 'SMILES3']) +``` + +### Oracle Categories + +TDC oracles are organized into several categories based on the molecular property being evaluated. + +## Biochemical Oracles + +Predict binding affinity or activity against biological targets. 
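+
+Each target-specific oracle below follows the same call pattern, so a candidate set can be screened with one batched call plus a cutoff. A minimal sketch (the 0.5 cutoff is illustrative, not a recommended threshold):
+
+```python
+from tdc import Oracle
+
+drd2 = Oracle(name='DRD2')
+candidates = ['CC(C)Cc1ccc(cc1)C(C)C(O)=O', 'CC(=O)Oc1ccccc1C(=O)O']
+
+scores = drd2(candidates)
+# Keep molecules the oracle scores as likely binders (higher = stronger predicted binding)
+hits = [smi for smi, s in zip(candidates, scores) if s >= 0.5]
+print(f'{len(hits)} of {len(candidates)} candidates pass the DRD2 screen')
+```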
+ +### Target-Specific Oracles + +**DRD2 - Dopamine Receptor D2** +```python +oracle = Oracle(name='DRD2') +score = oracle(smiles) +``` +- Measures binding affinity to DRD2 receptor +- Important for neurological and psychiatric drug development +- Higher scores indicate stronger binding + +**GSK3B - Glycogen Synthase Kinase-3 Beta** +```python +oracle = Oracle(name='GSK3B') +score = oracle(smiles) +``` +- Predicts GSK3β inhibition +- Relevant for Alzheimer's, diabetes, and cancer research +- Higher scores indicate better inhibition + +**JNK3 - c-Jun N-terminal Kinase 3** +```python +oracle = Oracle(name='JNK3') +score = oracle(smiles) +``` +- Measures JNK3 kinase inhibition +- Target for neurodegenerative diseases +- Higher scores indicate stronger inhibition + +**5HT2A - Serotonin 2A Receptor** +```python +oracle = Oracle(name='5HT2A') +score = oracle(smiles) +``` +- Predicts serotonin receptor binding +- Important for psychiatric medications +- Higher scores indicate stronger binding + +**ACE - Angiotensin-Converting Enzyme** +```python +oracle = Oracle(name='ACE') +score = oracle(smiles) +``` +- Measures ACE inhibition +- Target for hypertension treatment +- Higher scores indicate better inhibition + +**MAPK - Mitogen-Activated Protein Kinase** +```python +oracle = Oracle(name='MAPK') +score = oracle(smiles) +``` +- Predicts MAPK inhibition +- Target for cancer and inflammatory diseases + +**CDK - Cyclin-Dependent Kinase** +```python +oracle = Oracle(name='CDK') +score = oracle(smiles) +``` +- Measures CDK inhibition +- Important for cancer drug development + +**P38 - p38 MAP Kinase** +```python +oracle = Oracle(name='P38') +score = oracle(smiles) +``` +- Predicts p38 MAPK inhibition +- Target for inflammatory diseases + +**PARP1 - Poly (ADP-ribose) Polymerase 1** +```python +oracle = Oracle(name='PARP1') +score = oracle(smiles) +``` +- Measures PARP1 inhibition +- Target for cancer treatment (DNA repair mechanism) + +**PIK3CA - Phosphatidylinositol-4,5-Bisphosphate 3-Kinase** +```python +oracle = Oracle(name='PIK3CA') +score = oracle(smiles) +``` +- Predicts PIK3CA inhibition +- Important target in oncology + +## Physicochemical Oracles + +Evaluate drug-like properties and ADME characteristics. + +### Drug-Likeness Oracles + +**QED - Quantitative Estimate of Drug-likeness** +```python +oracle = Oracle(name='QED') +score = oracle(smiles) +``` +- Combines multiple physicochemical properties +- Score ranges from 0 (non-drug-like) to 1 (drug-like) +- Based on Bickerton et al. criteria + +**Lipinski - Rule of Five** +```python +oracle = Oracle(name='Lipinski') +score = oracle(smiles) +``` +- Number of Lipinski rule violations +- Rules: MW ≤ 500, logP ≤ 5, HBD ≤ 5, HBA ≤ 10 +- Score of 0 means fully compliant + +### Molecular Properties + +**SA - Synthetic Accessibility** +```python +oracle = Oracle(name='SA') +score = oracle(smiles) +``` +- Estimates ease of synthesis +- Score ranges from 1 (easy) to 10 (difficult) +- Lower scores indicate easier synthesis + +**LogP - Octanol-Water Partition Coefficient** +```python +oracle = Oracle(name='LogP') +score = oracle(smiles) +``` +- Measures lipophilicity +- Important for membrane permeability +- Typical drug-like range: 0-5 + +**MW - Molecular Weight** +```python +oracle = Oracle(name='MW') +score = oracle(smiles) +``` +- Returns molecular weight in Daltons +- Drug-like range typically 150-500 Da + +## Composite Oracles + +Combine multiple properties for multi-objective optimization. 
+ +**Isomer Meta** +```python +oracle = Oracle(name='Isomer_Meta') +score = oracle(smiles) +``` +- Evaluates specific isomeric properties +- Used for stereochemistry optimization + +**Median Molecules** +```python +oracle = Oracle(name='Median1', 'Median2') +score = oracle(smiles) +``` +- Tests ability to generate molecules with median properties +- Useful for distribution learning benchmarks + +**Rediscovery** +```python +oracle = Oracle(name='Rediscovery') +score = oracle(smiles) +``` +- Measures similarity to known reference molecules +- Tests ability to regenerate existing drugs + +**Similarity** +```python +oracle = Oracle(name='Similarity') +score = oracle(smiles) +``` +- Computes structural similarity to target molecules +- Based on molecular fingerprints (typically Tanimoto similarity) + +**Uniqueness** +```python +oracle = Oracle(name='Uniqueness') +scores = oracle(smiles_list) +``` +- Measures diversity in generated molecule set +- Returns fraction of unique molecules + +**Novelty** +```python +oracle = Oracle(name='Novelty') +scores = oracle(smiles_list, training_set) +``` +- Measures how different generated molecules are from training set +- Higher scores indicate more novel structures + +## Specialized Oracles + +**ASKCOS - Retrosynthesis Scoring** +```python +oracle = Oracle(name='ASKCOS') +score = oracle(smiles) +``` +- Evaluates synthetic feasibility using retrosynthesis +- Requires ASKCOS backend (IBM RXN) +- Scores based on retrosynthetic route availability + +**Docking Score** +```python +oracle = Oracle(name='Docking') +score = oracle(smiles) +``` +- Molecular docking score against target protein +- Requires protein structure and docking software +- Lower scores typically indicate better binding + +**Vina - AutoDock Vina Score** +```python +oracle = Oracle(name='Vina') +score = oracle(smiles) +``` +- Uses AutoDock Vina for protein-ligand docking +- Predicts binding affinity in kcal/mol +- More negative scores indicate stronger binding + +## Multi-Objective Optimization + +Combine multiple oracles for multi-property optimization: + +```python +from tdc import Oracle + +# Initialize multiple oracles +qed_oracle = Oracle(name='QED') +sa_oracle = Oracle(name='SA') +drd2_oracle = Oracle(name='DRD2') + +# Define custom scoring function +def multi_objective_score(smiles): + qed = qed_oracle(smiles) + sa = 1 / (1 + sa_oracle(smiles)) # Invert SA (lower is better) + drd2 = drd2_oracle(smiles) + + # Weighted combination + return 0.3 * qed + 0.3 * sa + 0.4 * drd2 + +# Evaluate molecule +score = multi_objective_score('CC(C)Cc1ccc(cc1)C(C)C(O)=O') +``` + +## Oracle Performance Considerations + +### Speed +- **Fast**: QED, SA, LogP, MW, Lipinski (rule-based calculations) +- **Medium**: Target-specific ML models (DRD2, GSK3B, etc.) 
+- **Slow**: Docking-based oracles (Vina, ASKCOS) + +### Reliability +- Oracles are ML models trained on specific datasets +- May not generalize to all chemical spaces +- Use multiple oracles to validate results + +### Batch Processing +```python +# Efficient batch evaluation +oracle = Oracle(name='GSK3B') +smiles_list = ['SMILES1', 'SMILES2', ..., 'SMILES1000'] +scores = oracle(smiles_list) # Faster than individual calls +``` + +## Common Workflows + +### Goal-Directed Generation +```python +from tdc import Oracle +from tdc.generation import MolGen + +# Load training data +data = MolGen(name='ChEMBL_V29') +train_smiles = data.get_data()['Drug'].tolist() + +# Initialize oracle +oracle = Oracle(name='GSK3B') + +# Generate molecules (user implements generative model) +# generated_smiles = generator.generate(n=1000) + +# Evaluate generated molecules +scores = oracle(generated_smiles) +best_molecules = [(s, score) for s, score in zip(generated_smiles, scores)] +best_molecules.sort(key=lambda x: x[1], reverse=True) + +print(f"Top 10 molecules:") +for smiles, score in best_molecules[:10]: + print(f"{smiles}: {score:.3f}") +``` + +### Distribution Learning +```python +from tdc import Oracle +import numpy as np + +# Initialize oracle +oracle = Oracle(name='QED') + +# Evaluate training set +train_scores = oracle(train_smiles) +train_mean = np.mean(train_scores) +train_std = np.std(train_scores) + +# Evaluate generated set +gen_scores = oracle(generated_smiles) +gen_mean = np.mean(gen_scores) +gen_std = np.std(gen_scores) + +# Compare distributions +print(f"Training: μ={train_mean:.3f}, σ={train_std:.3f}") +print(f"Generated: μ={gen_mean:.3f}, σ={gen_std:.3f}") +``` + +## Integration with TDC Benchmarks + +```python +from tdc.generation import MolGen + +# Use with GuacaMol benchmark +data = MolGen(name='GuacaMol') + +# Oracles are automatically integrated +# Each GuacaMol task has associated oracle +benchmark_results = data.evaluate_guacamol( + generated_molecules=your_molecules, + oracle_name='GSK3B' +) +``` + +## Notes + +- Oracle scores are predictions, not experimental measurements +- Always validate top candidates experimentally +- Different oracles may have different score ranges and interpretations +- Some oracles require additional dependencies or API access +- Check oracle documentation for specific details: https://tdcommons.ai/functions/oracles/ + +## Adding Custom Oracles + +To create custom oracle functions: + +```python +class CustomOracle: + def __init__(self): + # Initialize your model/method + pass + + def __call__(self, smiles): + # Implement your scoring logic + # Return score or list of scores + pass + +# Use like built-in oracles +custom_oracle = CustomOracle() +score = custom_oracle('CC(C)Cc1ccc(cc1)C(C)C(O)=O') +``` + +## References + +- TDC Oracles Documentation: https://tdcommons.ai/functions/oracles/ +- GuacaMol Paper: "GuacaMol: Benchmarking Models for de Novo Molecular Design" +- MOSES Paper: "Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models" diff --git a/scientific-packages/pytdc/references/utilities.md b/scientific-packages/pytdc/references/utilities.md new file mode 100644 index 0000000..c9e029f --- /dev/null +++ b/scientific-packages/pytdc/references/utilities.md @@ -0,0 +1,684 @@ +# TDC Utilities and Data Functions + +This document provides comprehensive documentation for TDC's data processing, evaluation, and utility functions. + +## Overview + +TDC provides utilities organized into four main categories: +1. 
**Dataset Splits** - Train/validation/test partitioning strategies +2. **Model Evaluation** - Standardized performance metrics +3. **Data Processing** - Molecule conversion, filtering, and transformation +4. **Entity Retrieval** - Database queries and conversions + +## 1. Dataset Splits + +Dataset splitting is crucial for evaluating model generalization. TDC provides multiple splitting strategies designed for therapeutic ML. + +### Basic Split Usage + +```python +from tdc.single_pred import ADME + +data = ADME(name='Caco2_Wang') + +# Get split with default parameters +split = data.get_split() +# Returns: {'train': DataFrame, 'valid': DataFrame, 'test': DataFrame} + +# Customize split parameters +split = data.get_split( + method='scaffold', + seed=42, + frac=[0.7, 0.1, 0.2] +) +``` + +### Split Methods + +#### Random Split +Random shuffling of data - suitable for general ML tasks. + +```python +split = data.get_split(method='random', seed=1) +``` + +**When to use:** +- Baseline model evaluation +- When chemical/temporal structure is not important +- Quick prototyping + +**Not recommended for:** +- Realistic drug discovery scenarios +- Evaluating generalization to new chemical matter + +#### Scaffold Split +Splits based on molecular scaffolds (Bemis-Murcko scaffolds) - ensures test molecules are structurally distinct from training. + +```python +split = data.get_split(method='scaffold', seed=1) +``` + +**When to use:** +- Default for most single prediction tasks +- Evaluating generalization to new chemical series +- Realistic drug discovery scenarios + +**How it works:** +1. Extract Bemis-Murcko scaffold from each molecule +2. Group molecules by scaffold +3. Assign scaffolds to train/valid/test sets +4. Ensures test molecules have unseen scaffolds + +#### Cold Splits (DTI/DDI Tasks) +For multi-instance prediction, cold splits ensure test set contains unseen drugs, targets, or both. + +**Cold Drug Split:** +```python +from tdc.multi_pred import DTI +data = DTI(name='BindingDB_Kd') +split = data.get_split(method='cold_drug', seed=1) +``` +- Test set contains drugs not seen during training +- Evaluates generalization to new compounds + +**Cold Target Split:** +```python +split = data.get_split(method='cold_target', seed=1) +``` +- Test set contains targets not seen during training +- Evaluates generalization to new proteins + +**Cold Drug-Target Split:** +```python +split = data.get_split(method='cold_drug_target', seed=1) +``` +- Test set contains novel drug-target pairs +- Most challenging evaluation scenario + +#### Temporal Split +For datasets with temporal information - ensures test data is from later time points. + +```python +split = data.get_split(method='temporal', seed=1) +``` + +**When to use:** +- Datasets with time stamps +- Simulating prospective prediction +- Clinical trial outcome prediction + +### Custom Split Fractions + +```python +# 80% train, 10% valid, 10% test +split = data.get_split(method='scaffold', frac=[0.8, 0.1, 0.1]) + +# 70% train, 15% valid, 15% test +split = data.get_split(method='scaffold', frac=[0.7, 0.15, 0.15]) +``` + +### Stratified Splits + +For classification tasks with imbalanced labels: + +```python +split = data.get_split(method='scaffold', stratified=True) +``` + +Maintains label distribution across train/valid/test sets. + +## 2. Model Evaluation + +TDC provides standardized evaluation metrics for different task types. 
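+
+As a quick end-to-end illustration before the individual metrics, the snippet below scores a trivial baseline on a scaffold-split test set (the mean-value baseline stands in for a real model):
+
+```python
+from tdc.single_pred import ADME
+from tdc import Evaluator
+import numpy as np
+
+data = ADME(name='Caco2_Wang')
+split = data.get_split(method='scaffold', seed=1)
+train, test = split['train'], split['test']
+
+# Predict the training-set mean for every test molecule
+y_pred = np.full(len(test), train['Y'].mean())
+
+evaluator = Evaluator(name='MAE')
+print(f"Baseline MAE: {evaluator(test['Y'], y_pred):.3f}")
+```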
+ +### Basic Evaluator Usage + +```python +from tdc import Evaluator + +# Initialize evaluator +evaluator = Evaluator(name='ROC-AUC') + +# Evaluate predictions +score = evaluator(y_true, y_pred) +``` + +### Classification Metrics + +#### ROC-AUC +Receiver Operating Characteristic - Area Under Curve + +```python +evaluator = Evaluator(name='ROC-AUC') +score = evaluator(y_true, y_pred_proba) +``` + +**Best for:** +- Binary classification +- Imbalanced datasets +- Overall discriminative ability + +**Range:** 0-1 (higher is better, 0.5 is random) + +#### PR-AUC +Precision-Recall Area Under Curve + +```python +evaluator = Evaluator(name='PR-AUC') +score = evaluator(y_true, y_pred_proba) +``` + +**Best for:** +- Highly imbalanced datasets +- When positive class is rare +- Complements ROC-AUC + +**Range:** 0-1 (higher is better) + +#### F1 Score +Harmonic mean of precision and recall + +```python +evaluator = Evaluator(name='F1') +score = evaluator(y_true, y_pred_binary) +``` + +**Best for:** +- Balance between precision and recall +- Multi-class classification + +**Range:** 0-1 (higher is better) + +#### Accuracy +Fraction of correct predictions + +```python +evaluator = Evaluator(name='Accuracy') +score = evaluator(y_true, y_pred_binary) +``` + +**Best for:** +- Balanced datasets +- Simple baseline metric + +**Not recommended for:** Imbalanced datasets + +#### Cohen's Kappa +Agreement between predictions and ground truth, accounting for chance + +```python +evaluator = Evaluator(name='Kappa') +score = evaluator(y_true, y_pred_binary) +``` + +**Range:** -1 to 1 (higher is better, 0 is random) + +### Regression Metrics + +#### RMSE - Root Mean Squared Error +```python +evaluator = Evaluator(name='RMSE') +score = evaluator(y_true, y_pred) +``` + +**Best for:** +- Continuous predictions +- Penalizes large errors heavily + +**Range:** 0-∞ (lower is better) + +#### MAE - Mean Absolute Error +```python +evaluator = Evaluator(name='MAE') +score = evaluator(y_true, y_pred) +``` + +**Best for:** +- Continuous predictions +- More robust to outliers than RMSE + +**Range:** 0-∞ (lower is better) + +#### R² - Coefficient of Determination +```python +evaluator = Evaluator(name='R2') +score = evaluator(y_true, y_pred) +``` + +**Best for:** +- Variance explained by model +- Comparing different models + +**Range:** -∞ to 1 (higher is better, 1 is perfect) + +#### MSE - Mean Squared Error +```python +evaluator = Evaluator(name='MSE') +score = evaluator(y_true, y_pred) +``` + +**Range:** 0-∞ (lower is better) + +### Ranking Metrics + +#### Spearman Correlation +Rank correlation coefficient + +```python +evaluator = Evaluator(name='Spearman') +score = evaluator(y_true, y_pred) +``` + +**Best for:** +- Ranking tasks +- Non-linear relationships +- Ordinal data + +**Range:** -1 to 1 (higher is better) + +#### Pearson Correlation +Linear correlation coefficient + +```python +evaluator = Evaluator(name='Pearson') +score = evaluator(y_true, y_pred) +``` + +**Best for:** +- Linear relationships +- Continuous data + +**Range:** -1 to 1 (higher is better) + +### Multi-Label Classification + +```python +evaluator = Evaluator(name='Micro-F1') +score = evaluator(y_true_multilabel, y_pred_multilabel) +``` + +Available: `Micro-F1`, `Macro-F1`, `Micro-AUPR`, `Macro-AUPR` + +### Benchmark Group Evaluation + +For benchmark groups, evaluation requires multiple seeds: + +```python +from tdc.benchmark_group import admet_group + +group = admet_group(path='data/') +benchmark = group.get('Caco2_Wang') + +# Predictions must be dict with 
seeds as keys +predictions = {} +for seed in [1, 2, 3, 4, 5]: + # Train model and predict + predictions[seed] = model_predictions + +# Evaluate with mean and std across seeds +results = group.evaluate(predictions) +print(results) # {'Caco2_Wang': [mean_score, std_score]} +``` + +## 3. Data Processing + +TDC provides 11 comprehensive data processing utilities. + +### Molecule Format Conversion + +Convert between ~15 molecular representations. + +```python +from tdc.chem_utils import MolConvert + +# SMILES to PyTorch Geometric +converter = MolConvert(src='SMILES', dst='PyG') +pyg_graph = converter('CC(C)Cc1ccc(cc1)C(C)C(O)=O') + +# SMILES to DGL +converter = MolConvert(src='SMILES', dst='DGL') +dgl_graph = converter('CC(C)Cc1ccc(cc1)C(C)C(O)=O') + +# SMILES to Morgan Fingerprint (ECFP) +converter = MolConvert(src='SMILES', dst='ECFP') +fingerprint = converter('CC(C)Cc1ccc(cc1)C(C)C(O)=O') +``` + +**Available formats:** +- **Text**: SMILES, SELFIES, InChI +- **Fingerprints**: ECFP (Morgan), MACCS, RDKit, AtomPair, TopologicalTorsion +- **Graphs**: PyG (PyTorch Geometric), DGL (Deep Graph Library) +- **3D**: Graph3D, Coulomb Matrix, Distance Matrix + +**Batch conversion:** +```python +converter = MolConvert(src='SMILES', dst='PyG') +graphs = converter(['SMILES1', 'SMILES2', 'SMILES3']) +``` + +### Molecule Filters + +Remove non-drug-like molecules using curated chemical rules. + +```python +from tdc.chem_utils import MolFilter + +# Initialize filter with rules +mol_filter = MolFilter( + rules=['PAINS', 'BMS'], # Chemical filter rules + property_filters_dict={ + 'MW': (150, 500), # Molecular weight range + 'LogP': (-0.4, 5.6), # Lipophilicity range + 'HBD': (0, 5), # H-bond donors + 'HBA': (0, 10) # H-bond acceptors + } +) + +# Filter molecules +filtered_smiles = mol_filter(smiles_list) +``` + +**Available filter rules:** +- `PAINS` - Pan-Assay Interference Compounds +- `BMS` - Bristol-Myers Squibb HTS deck filters +- `Glaxo` - GlaxoSmithKline filters +- `Dundee` - University of Dundee filters +- `Inpharmatica` - Inpharmatica filters +- `LINT` - Pfizer LINT filters + +### Label Distribution Visualization + +```python +# Visualize label distribution +data.label_distribution() + +# Print statistics +data.print_stats() +``` + +Displays histogram and computes mean, median, std for continuous labels. + +### Label Binarization + +Convert continuous labels to binary using threshold. + +```python +from tdc.utils import binarize + +# Binarize with threshold +binary_labels = binarize(y_continuous, threshold=5.0, order='ascending') +# order='ascending': values >= threshold become 1 +# order='descending': values <= threshold become 1 +``` + +### Label Units Conversion + +Transform between measurement units. + +```python +from tdc.chem_utils import label_transform + +# Convert nM to pKd +y_pkd = label_transform(y_nM, from_unit='nM', to_unit='p') + +# Convert μM to nM +y_nM = label_transform(y_uM, from_unit='uM', to_unit='nM') +``` + +**Available conversions:** +- Binding affinity: nM, μM, pKd, pKi, pIC50 +- Log transformations +- Natural log conversions + +### Label Meaning + +Get interpretable descriptions for labels. + +```python +# Get label mapping +label_map = data.get_label_map(name='DrugBank') +print(label_map) +# {0: 'No interaction', 1: 'Increased effect', 2: 'Decreased effect', ...} +``` + +### Data Balancing + +Handle class imbalance via over/under-sampling. 
+ +```python +from tdc.utils import balance + +# Oversample minority class +X_balanced, y_balanced = balance(X, y, method='oversample') + +# Undersample majority class +X_balanced, y_balanced = balance(X, y, method='undersample') +``` + +### Graph Transformation for Pair Data + +Convert paired data to graph representations. + +```python +from tdc.utils import create_graph_from_pairs + +# Create graph from drug-drug pairs +graph = create_graph_from_pairs( + pairs=ddi_pairs, # [(drug1, drug2, label), ...] + format='edge_list' # or 'PyG', 'DGL' +) +``` + +### Negative Sampling + +Generate negative samples for binary tasks. + +```python +from tdc.utils import negative_sample + +# Generate negative samples for DTI +negative_pairs = negative_sample( + positive_pairs=known_interactions, + all_drugs=drug_list, + all_targets=target_list, + ratio=1.0 # Negative:positive ratio +) +``` + +**Use cases:** +- Drug-target interaction prediction +- Drug-drug interaction tasks +- Creating balanced datasets + +### Entity Retrieval + +Convert between database identifiers. + +#### PubChem CID to SMILES +```python +from tdc.utils import cid2smiles + +smiles = cid2smiles(2244) # Aspirin +# Returns: 'CC(=O)Oc1ccccc1C(=O)O' +``` + +#### UniProt ID to Amino Acid Sequence +```python +from tdc.utils import uniprot2seq + +sequence = uniprot2seq('P12345') +# Returns: 'MVKVYAPASS...' +``` + +#### Batch Retrieval +```python +# Multiple CIDs +smiles_list = [cid2smiles(cid) for cid in [2244, 5090, 6323]] + +# Multiple UniProt IDs +sequences = [uniprot2seq(uid) for uid in ['P12345', 'Q9Y5S9']] +``` + +## 4. Advanced Utilities + +### Retrieve Dataset Names + +```python +from tdc.utils import retrieve_dataset_names + +# Get all datasets for a task +adme_datasets = retrieve_dataset_names('ADME') +dti_datasets = retrieve_dataset_names('DTI') +tox_datasets = retrieve_dataset_names('Tox') + +print(f"ADME datasets: {adme_datasets}") +``` + +### Fuzzy Search + +TDC supports fuzzy matching for dataset names: + +```python +from tdc.single_pred import ADME + +# These all work (typo-tolerant) +data = ADME(name='Caco2_Wang') +data = ADME(name='caco2_wang') +data = ADME(name='Caco2') # Partial match +``` + +### Data Format Options + +```python +# Pandas DataFrame (default) +df = data.get_data(format='df') + +# Dictionary +data_dict = data.get_data(format='dict') + +# DeepPurpose format (for DeepPurpose library) +dp_format = data.get_data(format='DeepPurpose') + +# PyG/DGL graphs (if applicable) +graphs = data.get_data(format='PyG') +``` + +### Data Loader Utilities + +```python +from tdc.utils import create_fold + +# Create cross-validation folds +folds = create_fold(data, fold=5, seed=42) +# Returns list of (train_idx, test_idx) tuples + +# Iterate through folds +for i, (train_idx, test_idx) in enumerate(folds): + train_data = data.iloc[train_idx] + test_data = data.iloc[test_idx] + # Train and evaluate +``` + +## Common Workflows + +### Workflow 1: Complete Data Pipeline + +```python +from tdc.single_pred import ADME +from tdc import Evaluator +from tdc.chem_utils import MolConvert, MolFilter + +# 1. Load data +data = ADME(name='Caco2_Wang') + +# 2. Filter molecules +mol_filter = MolFilter(rules=['PAINS']) +filtered_data = data.get_data() +filtered_data = filtered_data[ + filtered_data['Drug'].apply(lambda x: mol_filter([x])) +] + +# 3. Split data +split = data.get_split(method='scaffold', seed=42) +train, valid, test = split['train'], split['valid'], split['test'] + +# 4. 
Convert to graph representations +converter = MolConvert(src='SMILES', dst='PyG') +train_graphs = converter(train['Drug'].tolist()) + +# 5. Train model (user implements) +# model.fit(train_graphs, train['Y']) + +# 6. Evaluate +evaluator = Evaluator(name='MAE') +# score = evaluator(test['Y'], predictions) +``` + +### Workflow 2: Multi-Task Learning Preparation + +```python +from tdc.benchmark_group import admet_group +from tdc.chem_utils import MolConvert + +# Load benchmark group +group = admet_group(path='data/') + +# Get multiple datasets +datasets = ['Caco2_Wang', 'HIA_Hou', 'Bioavailability_Ma'] +all_data = {} + +for dataset_name in datasets: + benchmark = group.get(dataset_name) + all_data[dataset_name] = benchmark + +# Prepare for multi-task learning +converter = MolConvert(src='SMILES', dst='ECFP') +# Process each dataset... +``` + +### Workflow 3: DTI Cold Split Evaluation + +```python +from tdc.multi_pred import DTI +from tdc import Evaluator + +# Load DTI data +data = DTI(name='BindingDB_Kd') + +# Cold drug split +split = data.get_split(method='cold_drug', seed=42) +train, test = split['train'], split['test'] + +# Verify no drug overlap +train_drugs = set(train['Drug_ID']) +test_drugs = set(test['Drug_ID']) +assert len(train_drugs & test_drugs) == 0, "Drug leakage detected!" + +# Train and evaluate +# model.fit(train) +evaluator = Evaluator(name='RMSE') +# score = evaluator(test['Y'], predictions) +``` + +## Best Practices + +1. **Always use meaningful splits** - Use scaffold or cold splits for realistic evaluation +2. **Multiple seeds** - Run experiments with multiple seeds for robust results +3. **Appropriate metrics** - Choose metrics that match your task and dataset characteristics +4. **Data filtering** - Remove PAINS and non-drug-like molecules before training +5. **Format conversion** - Convert molecules to appropriate format for your model +6. **Batch processing** - Use batch operations for efficiency with large datasets + +## Performance Tips + +- Convert molecules in batch mode for faster processing +- Cache converted representations to avoid recomputation +- Use appropriate data formats for your framework (PyG, DGL, etc.) +- Filter data early in the pipeline to reduce computation + +## References + +- TDC Documentation: https://tdc.readthedocs.io +- Data Functions: https://tdcommons.ai/fct_overview/ +- Evaluation Metrics: https://tdcommons.ai/functions/model_eval/ +- Data Splits: https://tdcommons.ai/functions/data_split/ diff --git a/scientific-packages/pytdc/scripts/benchmark_evaluation.py b/scientific-packages/pytdc/scripts/benchmark_evaluation.py new file mode 100644 index 0000000..1568d6b --- /dev/null +++ b/scientific-packages/pytdc/scripts/benchmark_evaluation.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +TDC Benchmark Group Evaluation Template + +This script demonstrates how to use TDC benchmark groups for systematic +model evaluation following the required 5-seed protocol. 
+ +Usage: + python benchmark_evaluation.py +""" + +from tdc.benchmark_group import admet_group +from tdc import Evaluator +import numpy as np +import pandas as pd + + +def load_benchmark_group(): + """ + Load the ADMET benchmark group + """ + print("=" * 60) + print("Loading ADMET Benchmark Group") + print("=" * 60) + + # Initialize benchmark group + group = admet_group(path='data/') + + # Get available benchmarks + print("\nAvailable benchmarks in ADMET group:") + benchmark_names = group.dataset_names + print(f"Total: {len(benchmark_names)} datasets") + + for i, name in enumerate(benchmark_names[:10], 1): + print(f" {i}. {name}") + + if len(benchmark_names) > 10: + print(f" ... and {len(benchmark_names) - 10} more") + + return group + + +def single_dataset_evaluation(group, dataset_name='Caco2_Wang'): + """ + Example: Evaluate on a single dataset with 5-seed protocol + """ + print("\n" + "=" * 60) + print(f"Example 1: Single Dataset Evaluation ({dataset_name})") + print("=" * 60) + + # Get dataset benchmarks + benchmark = group.get(dataset_name) + + print(f"\nBenchmark structure:") + print(f" Seeds: {list(benchmark.keys())}") + + # Required: Evaluate with 5 different seeds + predictions = {} + + for seed in [1, 2, 3, 4, 5]: + print(f"\n--- Seed {seed} ---") + + # Get train/valid data for this seed + train = benchmark[seed]['train'] + valid = benchmark[seed]['valid'] + + print(f"Train size: {len(train)}") + print(f"Valid size: {len(valid)}") + + # TODO: Replace with your model training + # model = YourModel() + # model.fit(train['Drug'], train['Y']) + + # For demonstration, create dummy predictions + # Replace with: predictions[seed] = model.predict(benchmark[seed]['test']) + test = benchmark[seed]['test'] + y_true = test['Y'].values + + # Simulate predictions (add controlled noise) + np.random.seed(seed) + y_pred = y_true + np.random.normal(0, 0.3, len(y_true)) + + predictions[seed] = y_pred + + # Evaluate this seed + evaluator = Evaluator(name='MAE') + score = evaluator(y_true, y_pred) + print(f"MAE for seed {seed}: {score:.4f}") + + # Evaluate across all seeds + print("\n--- Overall Evaluation ---") + results = group.evaluate(predictions) + + print(f"\nResults for {dataset_name}:") + mean_score, std_score = results[dataset_name] + print(f" Mean MAE: {mean_score:.4f}") + print(f" Std MAE: {std_score:.4f}") + + return predictions, results + + +def multiple_datasets_evaluation(group): + """ + Example: Evaluate on multiple datasets + """ + print("\n" + "=" * 60) + print("Example 2: Multiple Datasets Evaluation") + print("=" * 60) + + # Select a subset of datasets for demonstration + selected_datasets = ['Caco2_Wang', 'HIA_Hou', 'Bioavailability_Ma'] + + all_predictions = {} + all_results = {} + + for dataset_name in selected_datasets: + print(f"\n{'='*40}") + print(f"Evaluating: {dataset_name}") + print(f"{'='*40}") + + benchmark = group.get(dataset_name) + predictions = {} + + # Train and predict for each seed + for seed in [1, 2, 3, 4, 5]: + train = benchmark[seed]['train'] + test = benchmark[seed]['test'] + + # TODO: Replace with your model + # model = YourModel() + # model.fit(train['Drug'], train['Y']) + # predictions[seed] = model.predict(test['Drug']) + + # Dummy predictions for demonstration + np.random.seed(seed) + y_true = test['Y'].values + y_pred = y_true + np.random.normal(0, 0.3, len(y_true)) + predictions[seed] = y_pred + + all_predictions[dataset_name] = predictions + + # Evaluate this dataset + results = group.evaluate({dataset_name: predictions}) + 
all_results[dataset_name] = results[dataset_name] + + mean_score, std_score = results[dataset_name] + print(f" {dataset_name}: {mean_score:.4f} ± {std_score:.4f}") + + # Summary + print("\n" + "=" * 60) + print("Summary of Results") + print("=" * 60) + + results_df = pd.DataFrame([ + { + 'Dataset': name, + 'Mean MAE': f"{mean:.4f}", + 'Std MAE': f"{std:.4f}" + } + for name, (mean, std) in all_results.items() + ]) + + print(results_df.to_string(index=False)) + + return all_predictions, all_results + + +def custom_model_template(): + """ + Template for integrating your own model with TDC benchmarks + """ + print("\n" + "=" * 60) + print("Example 3: Custom Model Template") + print("=" * 60) + + code_template = ''' +# Template for using your own model with TDC benchmarks + +from tdc.benchmark_group import admet_group +from your_library import YourModel # Replace with your model + +# Initialize benchmark group +group = admet_group(path='data/') +benchmark = group.get('Caco2_Wang') + +predictions = {} + +for seed in [1, 2, 3, 4, 5]: + # Get data for this seed + train = benchmark[seed]['train'] + valid = benchmark[seed]['valid'] + test = benchmark[seed]['test'] + + # Extract features and labels + X_train, y_train = train['Drug'], train['Y'] + X_valid, y_valid = valid['Drug'], valid['Y'] + X_test = test['Drug'] + + # Initialize and train model + model = YourModel(random_state=seed) + model.fit(X_train, y_train) + + # Optionally use validation set for early stopping + # model.fit(X_train, y_train, validation_data=(X_valid, y_valid)) + + # Make predictions on test set + predictions[seed] = model.predict(X_test) + +# Evaluate with TDC +results = group.evaluate(predictions) +print(f"Results: {results}") +''' + + print("\nCustom Model Integration Template:") + print("=" * 60) + print(code_template) + + return code_template + + +def multi_seed_statistics(predictions_dict): + """ + Example: Analyzing multi-seed prediction statistics + """ + print("\n" + "=" * 60) + print("Example 4: Multi-Seed Statistics Analysis") + print("=" * 60) + + # Analyze prediction variability across seeds + all_preds = np.array([predictions_dict[seed] for seed in [1, 2, 3, 4, 5]]) + + print("\nPrediction statistics across 5 seeds:") + print(f" Shape: {all_preds.shape}") + print(f" Mean prediction: {all_preds.mean():.4f}") + print(f" Std across seeds: {all_preds.std(axis=0).mean():.4f}") + print(f" Min prediction: {all_preds.min():.4f}") + print(f" Max prediction: {all_preds.max():.4f}") + + # Per-sample variance + per_sample_std = all_preds.std(axis=0) + print(f"\nPer-sample prediction std:") + print(f" Mean: {per_sample_std.mean():.4f}") + print(f" Median: {np.median(per_sample_std):.4f}") + print(f" Max: {per_sample_std.max():.4f}") + + +def leaderboard_submission_guide(): + """ + Guide for submitting to TDC leaderboards + """ + print("\n" + "=" * 60) + print("Example 5: Leaderboard Submission Guide") + print("=" * 60) + + guide = """ +To submit results to TDC leaderboards: + +1. Evaluate your model following the 5-seed protocol: + - Use seeds [1, 2, 3, 4, 5] exactly as provided + - Do not modify the train/valid/test splits + - Report mean ± std across all 5 seeds + +2. Format your results: + results = group.evaluate(predictions) + # Returns: {'dataset_name': [mean_score, std_score]} + +3. 
Submit to leaderboard: + - Visit: https://tdcommons.ai/benchmark/admet_group/ + - Click on your dataset of interest + - Submit your results with: + * Model name and description + * Mean score ± standard deviation + * Reference to paper/code (if available) + +4. Best practices: + - Report all datasets in the benchmark group + - Include model hyperparameters + - Share code for reproducibility + - Compare against baseline models + +5. Evaluation metrics: + - ADMET Group uses MAE by default + - Other groups may use different metrics + - Check benchmark-specific requirements +""" + + print(guide) + + +def main(): + """ + Main function to run all benchmark evaluation examples + """ + print("\n" + "=" * 60) + print("TDC Benchmark Group Evaluation Examples") + print("=" * 60) + + # Load benchmark group + group = load_benchmark_group() + + # Example 1: Single dataset evaluation + predictions, results = single_dataset_evaluation(group) + + # Example 2: Multiple datasets evaluation + all_predictions, all_results = multiple_datasets_evaluation(group) + + # Example 3: Custom model template + custom_model_template() + + # Example 4: Multi-seed statistics + multi_seed_statistics(predictions) + + # Example 5: Leaderboard submission guide + leaderboard_submission_guide() + + print("\n" + "=" * 60) + print("Benchmark evaluation examples completed!") + print("=" * 60) + print("\nNext steps:") + print("1. Replace dummy predictions with your model") + print("2. Run full evaluation on all benchmark datasets") + print("3. Submit results to TDC leaderboard") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pytdc/scripts/load_and_split_data.py b/scientific-packages/pytdc/scripts/load_and_split_data.py new file mode 100644 index 0000000..50238c7 --- /dev/null +++ b/scientific-packages/pytdc/scripts/load_and_split_data.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +TDC Data Loading and Splitting Template + +This script demonstrates how to load TDC datasets and apply different +splitting strategies for model training and evaluation. 
+ +Usage: + python load_and_split_data.py +""" + +from tdc.single_pred import ADME +from tdc.multi_pred import DTI +from tdc import Evaluator +import pandas as pd + + +def load_single_pred_example(): + """ + Example: Loading and splitting a single-prediction dataset (ADME) + """ + print("=" * 60) + print("Example 1: Single-Prediction Task (ADME)") + print("=" * 60) + + # Load Caco2 dataset (intestinal permeability) + print("\nLoading Caco2_Wang dataset...") + data = ADME(name='Caco2_Wang') + + # Get basic dataset info + print(f"\nDataset size: {len(data.get_data())} molecules") + data.print_stats() + + # Method 1: Scaffold split (default, recommended) + print("\n--- Scaffold Split ---") + split = data.get_split(method='scaffold', seed=42, frac=[0.7, 0.1, 0.2]) + + train = split['train'] + valid = split['valid'] + test = split['test'] + + print(f"Train: {len(train)} molecules") + print(f"Valid: {len(valid)} molecules") + print(f"Test: {len(test)} molecules") + + # Display sample data + print("\nSample training data:") + print(train.head(3)) + + # Method 2: Random split + print("\n--- Random Split ---") + split_random = data.get_split(method='random', seed=42, frac=[0.8, 0.1, 0.1]) + print(f"Train: {len(split_random['train'])} molecules") + print(f"Valid: {len(split_random['valid'])} molecules") + print(f"Test: {len(split_random['test'])} molecules") + + return split + + +def load_multi_pred_example(): + """ + Example: Loading and splitting a multi-prediction dataset (DTI) + """ + print("\n" + "=" * 60) + print("Example 2: Multi-Prediction Task (DTI)") + print("=" * 60) + + # Load BindingDB Kd dataset (drug-target interactions) + print("\nLoading BindingDB_Kd dataset...") + data = DTI(name='BindingDB_Kd') + + # Get basic dataset info + full_data = data.get_data() + print(f"\nDataset size: {len(full_data)} drug-target pairs") + print(f"Unique drugs: {full_data['Drug_ID'].nunique()}") + print(f"Unique targets: {full_data['Target_ID'].nunique()}") + + # Method 1: Random split + print("\n--- Random Split ---") + split_random = data.get_split(method='random', seed=42) + print(f"Train: {len(split_random['train'])} pairs") + print(f"Valid: {len(split_random['valid'])} pairs") + print(f"Test: {len(split_random['test'])} pairs") + + # Method 2: Cold drug split (unseen drugs in test) + print("\n--- Cold Drug Split ---") + split_cold_drug = data.get_split(method='cold_drug', seed=42) + + train = split_cold_drug['train'] + test = split_cold_drug['test'] + + # Verify no drug overlap + train_drugs = set(train['Drug_ID']) + test_drugs = set(test['Drug_ID']) + overlap = train_drugs & test_drugs + + print(f"Train: {len(train)} pairs, {len(train_drugs)} unique drugs") + print(f"Test: {len(test)} pairs, {len(test_drugs)} unique drugs") + print(f"Drug overlap: {len(overlap)} (should be 0)") + + # Method 3: Cold target split (unseen targets in test) + print("\n--- Cold Target Split ---") + split_cold_target = data.get_split(method='cold_target', seed=42) + + train = split_cold_target['train'] + test = split_cold_target['test'] + + train_targets = set(train['Target_ID']) + test_targets = set(test['Target_ID']) + overlap = train_targets & test_targets + + print(f"Train: {len(train)} pairs, {len(train_targets)} unique targets") + print(f"Test: {len(test)} pairs, {len(test_targets)} unique targets") + print(f"Target overlap: {len(overlap)} (should be 0)") + + # Display sample data + print("\nSample DTI data:") + print(full_data.head(3)) + + return split_cold_drug + + +def evaluation_example(split): + """ + Example: 
Evaluating model predictions with TDC evaluators + """ + print("\n" + "=" * 60) + print("Example 3: Model Evaluation") + print("=" * 60) + + test = split['test'] + + # For demonstration, create dummy predictions + # In practice, replace with your model's predictions + import numpy as np + np.random.seed(42) + + # Simulate predictions (replace with model.predict(test['Drug'])) + y_true = test['Y'].values + y_pred = y_true + np.random.normal(0, 0.5, len(y_true)) # Add noise + + # Evaluate with different metrics + print("\nEvaluating predictions...") + + # Regression metrics + mae_evaluator = Evaluator(name='MAE') + mae = mae_evaluator(y_true, y_pred) + print(f"MAE: {mae:.4f}") + + rmse_evaluator = Evaluator(name='RMSE') + rmse = rmse_evaluator(y_true, y_pred) + print(f"RMSE: {rmse:.4f}") + + r2_evaluator = Evaluator(name='R2') + r2 = r2_evaluator(y_true, y_pred) + print(f"R²: {r2:.4f}") + + spearman_evaluator = Evaluator(name='Spearman') + spearman = spearman_evaluator(y_true, y_pred) + print(f"Spearman: {spearman:.4f}") + + +def custom_split_example(): + """ + Example: Creating custom splits with different fractions + """ + print("\n" + "=" * 60) + print("Example 4: Custom Split Fractions") + print("=" * 60) + + data = ADME(name='HIA_Hou') + + # Custom split fractions + custom_fracs = [ + ([0.6, 0.2, 0.2], "60/20/20 split"), + ([0.8, 0.1, 0.1], "80/10/10 split"), + ([0.7, 0.15, 0.15], "70/15/15 split") + ] + + for frac, description in custom_fracs: + split = data.get_split(method='scaffold', seed=42, frac=frac) + print(f"\n{description}:") + print(f" Train: {len(split['train'])} ({frac[0]*100:.0f}%)") + print(f" Valid: {len(split['valid'])} ({frac[1]*100:.0f}%)") + print(f" Test: {len(split['test'])} ({frac[2]*100:.0f}%)") + + +def main(): + """ + Main function to run all examples + """ + print("\n" + "=" * 60) + print("TDC Data Loading and Splitting Examples") + print("=" * 60) + + # Example 1: Single prediction with scaffold split + split = load_single_pred_example() + + # Example 2: Multi prediction with cold splits + dti_split = load_multi_pred_example() + + # Example 3: Model evaluation + evaluation_example(split) + + # Example 4: Custom split fractions + custom_split_example() + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pytdc/scripts/molecular_generation.py b/scientific-packages/pytdc/scripts/molecular_generation.py new file mode 100644 index 0000000..7392f45 --- /dev/null +++ b/scientific-packages/pytdc/scripts/molecular_generation.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +""" +TDC Molecular Generation with Oracles Template + +This script demonstrates how to use TDC oracles for molecular generation +tasks including goal-directed generation and distribution learning. + +Usage: + python molecular_generation.py +""" + +from tdc.generation import MolGen +from tdc import Oracle +import numpy as np + + +def load_generation_dataset(): + """ + Load molecular generation dataset + """ + print("=" * 60) + print("Loading Molecular Generation Dataset") + print("=" * 60) + + # Load ChEMBL dataset + data = MolGen(name='ChEMBL_V29') + + # Get training molecules + split = data.get_split() + train_smiles = split['train']['Drug'].tolist() + + print(f"\nDataset: ChEMBL_V29") + print(f"Training molecules: {len(train_smiles)}") + + # Display sample molecules + print("\nSample SMILES:") + for i, smiles in enumerate(train_smiles[:5], 1): + print(f" {i}. 
{smiles}") + + return train_smiles + + +def single_oracle_example(): + """ + Example: Using a single oracle for molecular evaluation + """ + print("\n" + "=" * 60) + print("Example 1: Single Oracle Evaluation") + print("=" * 60) + + # Initialize oracle for GSK3B target + oracle = Oracle(name='GSK3B') + + # Test molecules + test_molecules = [ + 'CC(C)Cc1ccc(cc1)C(C)C(O)=O', # Ibuprofen + 'CC(=O)Oc1ccccc1C(=O)O', # Aspirin + 'Cn1c(=O)c2c(ncn2C)n(C)c1=O', # Caffeine + 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C' # Theophylline + ] + + print("\nEvaluating molecules with GSK3B oracle:") + print("-" * 60) + + for smiles in test_molecules: + score = oracle(smiles) + print(f"SMILES: {smiles}") + print(f"GSK3B score: {score:.4f}\n") + + +def multiple_oracles_example(): + """ + Example: Using multiple oracles for multi-objective optimization + """ + print("\n" + "=" * 60) + print("Example 2: Multiple Oracles (Multi-Objective)") + print("=" * 60) + + # Initialize multiple oracles + oracles = { + 'QED': Oracle(name='QED'), # Drug-likeness + 'SA': Oracle(name='SA'), # Synthetic accessibility + 'GSK3B': Oracle(name='GSK3B'), # Target binding + 'LogP': Oracle(name='LogP') # Lipophilicity + } + + # Test molecule + test_smiles = 'CC(C)Cc1ccc(cc1)C(C)C(O)=O' + + print(f"\nEvaluating: {test_smiles}") + print("-" * 60) + + scores = {} + for name, oracle in oracles.items(): + score = oracle(test_smiles) + scores[name] = score + print(f"{name:10s}: {score:.4f}") + + # Multi-objective score (weighted combination) + print("\n--- Multi-Objective Scoring ---") + + # Invert SA (lower is better, so we invert for maximization) + sa_score = 1.0 / (1.0 + scores['SA']) + + # Weighted combination + weights = {'QED': 0.3, 'SA': 0.2, 'GSK3B': 0.4, 'LogP': 0.1} + multi_score = ( + weights['QED'] * scores['QED'] + + weights['SA'] * sa_score + + weights['GSK3B'] * scores['GSK3B'] + + weights['LogP'] * (scores['LogP'] / 5.0) # Normalize LogP + ) + + print(f"Multi-objective score: {multi_score:.4f}") + print(f"Weights: {weights}") + + +def batch_evaluation_example(): + """ + Example: Batch evaluation of multiple molecules + """ + print("\n" + "=" * 60) + print("Example 3: Batch Evaluation") + print("=" * 60) + + # Generate sample molecules + molecules = [ + 'CC(C)Cc1ccc(cc1)C(C)C(O)=O', + 'CC(=O)Oc1ccccc1C(=O)O', + 'Cn1c(=O)c2c(ncn2C)n(C)c1=O', + 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', + 'CC(C)NCC(COc1ccc(cc1)COCCOC(C)C)O' + ] + + # Initialize oracle + oracle = Oracle(name='DRD2') + + print(f"\nBatch evaluating {len(molecules)} molecules with DRD2 oracle...") + + # Batch evaluation (more efficient than individual calls) + scores = oracle(molecules) + + print("\nResults:") + print("-" * 60) + for smiles, score in zip(molecules, scores): + print(f"{smiles[:40]:40s}... Score: {score:.4f}") + + # Statistics + print(f"\nStatistics:") + print(f" Mean score: {np.mean(scores):.4f}") + print(f" Std score: {np.std(scores):.4f}") + print(f" Min score: {np.min(scores):.4f}") + print(f" Max score: {np.max(scores):.4f}") + + +def goal_directed_generation_template(): + """ + Template for goal-directed molecular generation + """ + print("\n" + "=" * 60) + print("Example 4: Goal-Directed Generation Template") + print("=" * 60) + + template = ''' +# Template for goal-directed molecular generation + +from tdc.generation import MolGen +from tdc import Oracle +import numpy as np + +# 1. Load training data +data = MolGen(name='ChEMBL_V29') +train_smiles = data.get_split()['train']['Drug'].tolist() + +# 2. Initialize oracle(s) +oracle = Oracle(name='GSK3B') + +# 3. 
Initialize your generative model +# model = YourGenerativeModel() +# model.fit(train_smiles) + +# 4. Generation loop +num_iterations = 100 +num_molecules_per_iter = 100 +best_molecules = [] + +for iteration in range(num_iterations): + # Generate candidate molecules + # candidates = model.generate(num_molecules_per_iter) + + # Evaluate with oracle + scores = oracle(candidates) + + # Select top molecules + top_indices = np.argsort(scores)[-10:] + top_molecules = [candidates[i] for i in top_indices] + top_scores = [scores[i] for i in top_indices] + + # Store best molecules + best_molecules.extend(zip(top_molecules, top_scores)) + + # Optional: Fine-tune model on top molecules + # model.fine_tune(top_molecules) + + # Print progress + print(f"Iteration {iteration}: Best score = {max(scores):.4f}") + +# Sort and display top molecules +best_molecules.sort(key=lambda x: x[1], reverse=True) +print("\\nTop 10 molecules:") +for smiles, score in best_molecules[:10]: + print(f"{smiles}: {score:.4f}") +''' + + print("\nGoal-Directed Generation Template:") + print("=" * 60) + print(template) + + +def distribution_learning_example(train_smiles): + """ + Example: Distribution learning evaluation + """ + print("\n" + "=" * 60) + print("Example 5: Distribution Learning") + print("=" * 60) + + # Use subset for demonstration + train_subset = train_smiles[:1000] + + # Initialize oracle + oracle = Oracle(name='QED') + + print("\nEvaluating property distribution...") + + # Evaluate training set + print("Computing training set distribution...") + train_scores = oracle(train_subset) + + # Simulate generated molecules (in practice, use your generative model) + # For demo: add noise to training molecules + print("Computing generated set distribution...") + generated_scores = train_scores + np.random.normal(0, 0.1, len(train_scores)) + generated_scores = np.clip(generated_scores, 0, 1) # QED is [0, 1] + + # Compare distributions + print("\n--- Distribution Statistics ---") + print(f"Training set (n={len(train_subset)}):") + print(f" Mean: {np.mean(train_scores):.4f}") + print(f" Std: {np.std(train_scores):.4f}") + print(f" Median: {np.median(train_scores):.4f}") + + print(f"\nGenerated set (n={len(generated_scores)}):") + print(f" Mean: {np.mean(generated_scores):.4f}") + print(f" Std: {np.std(generated_scores):.4f}") + print(f" Median: {np.median(generated_scores):.4f}") + + # Distribution similarity metrics + from scipy.stats import ks_2samp + ks_statistic, p_value = ks_2samp(train_scores, generated_scores) + + print(f"\nKolmogorov-Smirnov Test:") + print(f" KS statistic: {ks_statistic:.4f}") + print(f" P-value: {p_value:.4f}") + + if p_value > 0.05: + print(" → Distributions are similar (p > 0.05)") + else: + print(" → Distributions are significantly different (p < 0.05)") + + +def available_oracles_info(): + """ + Display information about available oracles + """ + print("\n" + "=" * 60) + print("Example 6: Available Oracles") + print("=" * 60) + + oracle_info = { + 'Biochemical Targets': [ + 'DRD2', 'GSK3B', 'JNK3', '5HT2A', 'ACE', + 'MAPK', 'CDK', 'P38', 'PARP1', 'PIK3CA' + ], + 'Physicochemical Properties': [ + 'QED', 'SA', 'LogP', 'MW', 'Lipinski' + ], + 'Composite Metrics': [ + 'Isomer_Meta', 'Median1', 'Median2', + 'Rediscovery', 'Similarity', 'Uniqueness', 'Novelty' + ], + 'Specialized': [ + 'ASKCOS', 'Docking', 'Vina' + ] + } + + print("\nAvailable Oracle Categories:") + print("-" * 60) + + for category, oracles in oracle_info.items(): + print(f"\n{category}:") + for oracle_name in oracles: + print(f" - 
{oracle_name}") + + print("\nFor detailed oracle documentation, see:") + print(" references/oracles.md") + + +def constraint_satisfaction_example(): + """ + Example: Molecular generation with constraints + """ + print("\n" + "=" * 60) + print("Example 7: Constraint Satisfaction") + print("=" * 60) + + # Define constraints + constraints = { + 'QED': (0.5, 1.0), # Drug-likeness >= 0.5 + 'SA': (1.0, 5.0), # Easy to synthesize + 'MW': (200, 500), # Molecular weight 200-500 Da + 'LogP': (0, 3) # Lipophilicity 0-3 + } + + # Initialize oracles + oracles = {name: Oracle(name=name) for name in constraints.keys()} + + # Test molecules + test_molecules = [ + 'CC(C)Cc1ccc(cc1)C(C)C(O)=O', + 'CC(=O)Oc1ccccc1C(=O)O', + 'Cn1c(=O)c2c(ncn2C)n(C)c1=O' + ] + + print("\nConstraints:") + for prop, (min_val, max_val) in constraints.items(): + print(f" {prop}: [{min_val}, {max_val}]") + + print("\n" + "-" * 60) + print("Evaluating molecules against constraints:") + print("-" * 60) + + for smiles in test_molecules: + print(f"\nSMILES: {smiles}") + + satisfies_all = True + for prop, (min_val, max_val) in constraints.items(): + score = oracles[prop](smiles) + satisfies = min_val <= score <= max_val + + status = "✓" if satisfies else "✗" + print(f" {prop:10s}: {score:7.2f} [{min_val:5.1f}, {max_val:5.1f}] {status}") + + satisfies_all = satisfies_all and satisfies + + result = "PASS" if satisfies_all else "FAIL" + print(f" Overall: {result}") + + +def main(): + """ + Main function to run all molecular generation examples + """ + print("\n" + "=" * 60) + print("TDC Molecular Generation with Oracles Examples") + print("=" * 60) + + # Load generation dataset + train_smiles = load_generation_dataset() + + # Example 1: Single oracle + single_oracle_example() + + # Example 2: Multiple oracles + multiple_oracles_example() + + # Example 3: Batch evaluation + batch_evaluation_example() + + # Example 4: Goal-directed generation template + goal_directed_generation_template() + + # Example 5: Distribution learning + distribution_learning_example(train_smiles) + + # Example 6: Available oracles + available_oracles_info() + + # Example 7: Constraint satisfaction + constraint_satisfaction_example() + + print("\n" + "=" * 60) + print("Molecular generation examples completed!") + print("=" * 60) + print("\nNext steps:") + print("1. Implement your generative model") + print("2. Use oracles to guide generation") + print("3. Evaluate generated molecules") + print("4. Iterate and optimize") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/pytorch-lightning/SKILL.md b/scientific-packages/pytorch-lightning/SKILL.md new file mode 100644 index 0000000..dc5428f --- /dev/null +++ b/scientific-packages/pytorch-lightning/SKILL.md @@ -0,0 +1,660 @@ +--- +name: pytorch-lightning +description: Comprehensive toolkit for PyTorch Lightning, a deep learning framework for organizing PyTorch code. Use this skill when working with PyTorch Lightning for training deep learning models, implementing LightningModules, configuring Trainers, setting up distributed training, creating DataModules, or converting existing PyTorch code to Lightning format. The skill provides templates, reference documentation, and best practices for efficient deep learning workflows. +--- + +# PyTorch Lightning + +## Overview + +PyTorch Lightning is a deep learning framework that organizes PyTorch code to decouple research from engineering. 
It automates training loop complexity (multi-GPU, mixed precision, checkpointing, logging) while maintaining full flexibility over model architecture and training logic. + +**Core Philosophy:** Separate concerns +- **LightningModule** - Research code (model architecture, training logic) +- **Trainer** - Engineering automation (hardware, optimization, logging) +- **DataModule** - Data processing (downloading, loading, transforms) +- **Callbacks** - Non-essential functionality (checkpointing, early stopping) + +## When to Use This Skill + +Use this skill when: +- Building or training deep learning models with PyTorch +- Converting existing PyTorch code to Lightning structure +- Setting up distributed training across multiple GPUs or nodes +- Implementing custom training loops with validation and testing +- Organizing data processing pipelines +- Configuring experiment logging and model checkpointing +- Optimizing training performance and memory usage +- Working with large models requiring model parallelism + +## Quick Start + +### Basic Lightning Workflow + +1. **Define a LightningModule** (organize your model) +2. **Create a DataModule or DataLoaders** (organize your data) +3. **Configure a Trainer** (automate training) +4. **Train** with `trainer.fit()` + +### Minimal Example + +```python +import lightning as L +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset + +# 1. Define LightningModule +class SimpleModel(L.LightningModule): + def __init__(self, input_dim, output_dim): + super().__init__() + self.save_hyperparameters() + self.model = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = nn.functional.mse_loss(y_hat, y) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-3) + +# 2. Prepare data +train_data = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1)) +train_loader = DataLoader(train_data, batch_size=32) + +# 3. Create Trainer +trainer = L.Trainer(max_epochs=10, accelerator='auto') + +# 4. Train +model = SimpleModel(input_dim=10, output_dim=1) +trainer.fit(model, train_loader) +``` + +## Core Workflows + +### 1. Creating a LightningModule + +Structure model code by implementing essential hooks: + +**Template:** Use `scripts/template_lightning_module.py` as a starting point. 
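+
+A LightningModule is still a regular `torch.nn.Module`, so a new module can be sanity-checked in isolation before it is handed to a Trainer. A minimal sketch, reusing the `SimpleModel` class from the Quick Start above (the shapes are illustrative only):
+
+```python
+import torch
+
+model = SimpleModel(input_dim=10, output_dim=1)
+
+# save_hyperparameters() captured the init arguments
+print(model.hparams)              # "input_dim": 10, "output_dim": 1
+
+# forward() behaves like any nn.Module forward pass
+out = model(torch.randn(4, 10))
+print(out.shape)                  # torch.Size([4, 1])
+```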
+ +```python +class MyLightningModule(L.LightningModule): + def __init__(self, hyperparameters): + super().__init__() + self.save_hyperparameters() # Save for checkpointing + self.model = YourModel() + + def forward(self, x): + """Inference forward pass.""" + return self.model(x) + + def training_step(self, batch, batch_idx): + """Define training loop logic.""" + x, y = batch + y_hat = self(x) + loss = self.compute_loss(y_hat, y) + self.log('train_loss', loss, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + """Define validation logic.""" + x, y = batch + y_hat = self(x) + loss = self.compute_loss(y_hat, y) + self.log('val_loss', loss, on_epoch=True) + return loss + + def configure_optimizers(self): + """Return optimizer and optional scheduler.""" + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + } + } +``` + +**Key Points:** +- Use `self.save_hyperparameters()` to automatically save init args +- Use `self.log()` to track metrics across loggers +- Return loss from training_step for automatic optimization +- Keep model architecture separate from training logic + +### 2. Creating a DataModule + +Organize all data processing in a reusable module: + +**Template:** Use `scripts/template_datamodule.py` as a starting point. + +```python +class MyDataModule(L.LightningDataModule): + def __init__(self, data_dir, batch_size=32): + super().__init__() + self.save_hyperparameters() + + def prepare_data(self): + """Download data (called once, single process).""" + # Download datasets, tokenize, etc. + pass + + def setup(self, stage=None): + """Create datasets (called on every process).""" + if stage == 'fit' or stage is None: + # Create train/val datasets + self.train_dataset = ... + self.val_dataset = ... + + if stage == 'test' or stage is None: + # Create test dataset + self.test_dataset = ... + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size) +``` + +**Key Points:** +- `prepare_data()` for downloading (single process) +- `setup()` for creating datasets (every process) +- Use `stage` parameter to separate fit/test logic +- Makes data code reusable across projects + +### 3. Configuring the Trainer + +The Trainer automates training complexity: + +**Helper:** Use `scripts/quick_trainer_setup.py` for preset configurations. 
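+
+Extra keyword arguments passed to `create_trainer()` are forwarded to `L.Trainer` unchanged, which makes it easy to smoke-test a preset before committing to a long run. A minimal sketch (here `model` and `dm` stand in for your own LightningModule and DataModule, and `scripts/` is assumed to be importable from the working directory):
+
+```python
+from scripts.quick_trainer_setup import create_trainer
+
+# fast_dev_run is passed through to L.Trainer on top of the preset,
+# so this runs a single train and val batch as a quick wiring check
+trainer = create_trainer(preset="default", fast_dev_run=True)
+trainer.fit(model, datamodule=dm)
+```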
+ +```python +from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping + +trainer = L.Trainer( + # Training duration + max_epochs=100, + + # Hardware + accelerator='auto', # 'cpu', 'gpu', 'tpu' + devices=1, # Number of devices or specific IDs + + # Optimization + precision='16-mixed', # Mixed precision training + gradient_clip_val=1.0, + accumulate_grad_batches=4, # Gradient accumulation + + # Validation + check_val_every_n_epoch=1, + val_check_interval=1.0, # Validate every epoch + + # Logging + log_every_n_steps=50, + logger=TensorBoardLogger('logs/'), + + # Callbacks + callbacks=[ + ModelCheckpoint(monitor='val_loss', mode='min'), + EarlyStopping(monitor='val_loss', patience=10), + ], + + # Debugging + fast_dev_run=False, # Quick test with few batches + enable_progress_bar=True, +) +``` + +**Common Presets:** + +```python +from scripts.quick_trainer_setup import create_trainer + +# Development preset (fast debugging) +trainer = create_trainer(preset='fast_dev', max_epochs=3) + +# Production preset (full features) +trainer = create_trainer(preset='production', max_epochs=100) + +# Distributed preset (multi-GPU) +trainer = create_trainer(preset='distributed', devices=4) +``` + +### 4. Training and Evaluation + +```python +# Training +trainer.fit(model, datamodule=dm) +# Or with dataloaders +trainer.fit(model, train_loader, val_loader) + +# Resume from checkpoint +trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt') + +# Testing +trainer.test(model, datamodule=dm) +# Or load best checkpoint +trainer.test(ckpt_path='best', datamodule=dm) + +# Prediction +predictions = trainer.predict(model, predict_loader) + +# Validation only +trainer.validate(model, datamodule=dm) +``` + +### 5. Distributed Training + +Lightning handles distributed training automatically: + +```python +# Single machine, multiple GPUs (Data Parallel) +trainer = L.Trainer( + accelerator='gpu', + devices=4, + strategy='ddp', # DistributedDataParallel +) + +# Multiple machines, multiple GPUs +trainer = L.Trainer( + accelerator='gpu', + devices=4, # GPUs per node + num_nodes=8, # Number of machines + strategy='ddp', +) + +# Large models (Model Parallel with FSDP) +trainer = L.Trainer( + accelerator='gpu', + devices=4, + strategy='fsdp', # Fully Sharded Data Parallel +) + +# Large models (Model Parallel with DeepSpeed) +trainer = L.Trainer( + accelerator='gpu', + devices=4, + strategy='deepspeed_stage_2', + precision='16-mixed', +) +``` + +**For detailed distributed training guide, see:** `references/distributed_training.md` + +**Strategy Selection:** +- Models < 500M params → Use `ddp` +- Models > 500M params → Use `fsdp` or `deepspeed` +- Maximum memory efficiency → Use DeepSpeed Stage 3 with offloading +- Native PyTorch → Use `fsdp` +- Cutting-edge features → Use `deepspeed` + +### 6. 
Callbacks + +Extend training with modular functionality: + +```python +from lightning.pytorch.callbacks import ( + ModelCheckpoint, + EarlyStopping, + LearningRateMonitor, + RichProgressBar, +) + +callbacks = [ + # Save best models + ModelCheckpoint( + monitor='val_loss', + mode='min', + save_top_k=3, + filename='{epoch}-{val_loss:.2f}', + ), + + # Stop when no improvement + EarlyStopping( + monitor='val_loss', + patience=10, + mode='min', + ), + + # Log learning rate + LearningRateMonitor(logging_interval='epoch'), + + # Rich progress bar + RichProgressBar(), +] + +trainer = L.Trainer(callbacks=callbacks) +``` + +**Custom Callbacks:** + +```python +from lightning.pytorch.callbacks import Callback + +class MyCustomCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + # Custom logic at end of each epoch + print(f"Epoch {trainer.current_epoch} completed") + + def on_validation_end(self, trainer, pl_module): + val_loss = trainer.callback_metrics.get('val_loss') + # Custom validation logic + pass +``` + +### 7. Logging + +Track experiments with various loggers: + +```python +from lightning.pytorch.loggers import ( + TensorBoardLogger, + WandbLogger, + CSVLogger, + MLFlowLogger, +) + +# Single logger +logger = TensorBoardLogger('logs/', name='my_experiment') + +# Multiple loggers +loggers = [ + TensorBoardLogger('logs/'), + WandbLogger(project='my_project'), + CSVLogger('logs/'), +] + +trainer = L.Trainer(logger=loggers) +``` + +**Logging in LightningModule:** + +```python +def training_step(self, batch, batch_idx): + loss = self.compute_loss(batch) + + # Log single metric + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) + + # Log multiple metrics + metrics = {'loss': loss, 'acc': acc, 'f1': f1} + self.log_dict(metrics, on_step=True, on_epoch=True) + + return loss +``` + +## Converting Existing PyTorch Code + +### Standard PyTorch → Lightning + +**Before (PyTorch):** +```python +model = MyModel() +optimizer = torch.optim.Adam(model.parameters()) + +for epoch in range(num_epochs): + for batch in train_loader: + optimizer.zero_grad() + x, y = batch + y_hat = model(x) + loss = F.cross_entropy(y_hat, y) + loss.backward() + optimizer.step() +``` + +**After (Lightning):** +```python +class MyLightningModel(L.LightningModule): + def __init__(self): + super().__init__() + self.model = MyModel() + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters()) + +trainer = L.Trainer(max_epochs=num_epochs) +trainer.fit(model, train_loader) +``` + +**Key Changes:** +1. Wrap model in LightningModule +2. Move training loop logic to `training_step()` +3. Move optimizer setup to `configure_optimizers()` +4. Replace manual loop with `trainer.fit()` +5. 
Lightning handles: `.zero_grad()`, `.backward()`, `.step()`, device placement + +## Common Patterns + +### Reproducibility + +```python +from lightning.pytorch import seed_everything + +# Set seed for reproducibility +seed_everything(42, workers=True) + +trainer = L.Trainer(deterministic=True) +``` + +### Mixed Precision Training + +```python +# 16-bit mixed precision +trainer = L.Trainer(precision='16-mixed') + +# BFloat16 mixed precision (more stable) +trainer = L.Trainer(precision='bf16-mixed') +``` + +### Gradient Accumulation + +```python +# Effective batch size = 4x actual batch size +trainer = L.Trainer(accumulate_grad_batches=4) +``` + +### Learning Rate Finding + +```python +from lightning.pytorch.tuner import Tuner + +trainer = L.Trainer() +tuner = Tuner(trainer) + +# Find optimal learning rate +lr_finder = tuner.lr_find(model, train_dataloader) +model.hparams.learning_rate = lr_finder.suggestion() + +# Find optimal batch size +tuner.scale_batch_size(model, mode="power") +``` + +### Checkpointing and Loading + +```python +# Save checkpoint +trainer.fit(model, datamodule=dm) +# Checkpoint automatically saved to checkpoints/ + +# Load from checkpoint +model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') + +# Resume training +trainer.fit(model, datamodule=dm, ckpt_path='checkpoint.ckpt') + +# Test from checkpoint +trainer.test(ckpt_path='best', datamodule=dm) +``` + +### Debugging + +```python +# Quick test with few batches +trainer = L.Trainer(fast_dev_run=10) + +# Overfit on small data (debug model) +trainer = L.Trainer(overfit_batches=100) + +# Limit batches for quick iteration +trainer = L.Trainer( + limit_train_batches=100, + limit_val_batches=50, +) + +# Profile training +trainer = L.Trainer(profiler='simple') # or 'advanced' +``` + +## Best Practices + +### Code Organization + +1. **Separate concerns:** + - Model architecture in `__init__()` + - Training logic in `training_step()` + - Validation logic in `validation_step()` + - Data processing in DataModule + +2. **Use `save_hyperparameters()`:** + ```python + def __init__(self, lr, hidden_dim, dropout): + super().__init__() + self.save_hyperparameters() # Automatically saves all args + ``` + +3. **Device-agnostic code:** + ```python + # Avoid manual device placement + # BAD: tensor.cuda() + # GOOD: Lightning handles this automatically + + # Create tensors on model's device + new_tensor = torch.zeros(10, device=self.device) + ``` + +4. **Log comprehensively:** + ```python + self.log('metric', value, on_step=True, on_epoch=True, prog_bar=True) + ``` + +### Performance Optimization + +1. **Use DataLoader best practices:** + ```python + DataLoader( + dataset, + batch_size=32, + num_workers=4, # Multiple workers + pin_memory=True, # Faster GPU transfer + persistent_workers=True, # Keep workers alive + ) + ``` + +2. **Enable benchmark mode for fixed input sizes:** + ```python + trainer = L.Trainer(benchmark=True) + ``` + +3. **Use gradient clipping:** + ```python + trainer = L.Trainer(gradient_clip_val=1.0) + ``` + +4. **Enable mixed precision:** + ```python + trainer = L.Trainer(precision='16-mixed') + ``` + +### Distributed Training + +1. **Sync metrics across devices:** + ```python + self.log('metric', value, sync_dist=True) + ``` + +2. **Rank-specific operations:** + ```python + if self.trainer.is_global_zero: + # Only run on main process + self.save_artifacts() + ``` + +3. 
**Use appropriate strategy:** + - Small models → `ddp` + - Large models → `fsdp` or `deepspeed` + +## Resources + +### Scripts + +Executable templates for quick implementation: + +- **`template_lightning_module.py`** - Complete LightningModule template with all hooks, logging, and optimization patterns +- **`template_datamodule.py`** - Complete DataModule template with data loading, splitting, and transformation patterns +- **`quick_trainer_setup.py`** - Helper functions to create Trainers with preset configurations (development, production, distributed) + +### References + +Comprehensive documentation for deep-dive learning: + +- **`api_reference.md`** - Complete API reference covering LightningModule hooks, Trainer parameters, Callbacks, DataModules, Loggers, and common patterns +- **`distributed_training.md`** - In-depth guide for distributed training strategies (DDP, FSDP, DeepSpeed), multi-node setup, memory optimization, and troubleshooting + +Load references when needing detailed information: +```python +# Example: Load distributed training reference +# See references/distributed_training.md for comprehensive distributed training guide +``` + +## Troubleshooting + +### Common Issues + +**Out of Memory:** +- Reduce batch size +- Use gradient accumulation +- Enable mixed precision (`precision='16-mixed'`) +- Use FSDP or DeepSpeed for large models +- Enable activation checkpointing + +**Slow Training:** +- Use multiple DataLoader workers (`num_workers > 0`) +- Enable `pin_memory=True` and `persistent_workers=True` +- Enable `benchmark=True` for fixed input sizes +- Profile with `profiler='simple'` + +**Validation Not Running:** +- Check `check_val_every_n_epoch` setting +- Ensure validation data provided +- Verify `validation_step()` implemented + +**Checkpoints Not Saving:** +- Ensure `enable_checkpointing=True` +- Check `ModelCheckpoint` callback configuration +- Verify `monitor` metric exists in logs + +## Additional Resources + +- Official Documentation: https://lightning.ai/docs/pytorch/stable/ +- GitHub: https://github.com/Lightning-AI/lightning +- Community: https://lightning.ai/community + +When unclear about specific functionality, refer to `references/api_reference.md` for detailed API documentation or `references/distributed_training.md` for distributed training specifics. diff --git a/scientific-packages/pytorch-lightning/references/api_reference.md b/scientific-packages/pytorch-lightning/references/api_reference.md new file mode 100644 index 0000000..6a946ff --- /dev/null +++ b/scientific-packages/pytorch-lightning/references/api_reference.md @@ -0,0 +1,490 @@ +# PyTorch Lightning API Reference + +Comprehensive reference for PyTorch Lightning core APIs, hooks, and components. + +## LightningModule + +The LightningModule is the core abstraction for organizing PyTorch code in Lightning. + +### Essential Hooks + +#### `__init__(self, *args, **kwargs)` +Initialize the model, define layers, and save hyperparameters. + +```python +def __init__(self, learning_rate=1e-3, hidden_dim=128): + super().__init__() + self.save_hyperparameters() # Saves all args to self.hparams + self.model = nn.Sequential(...) +``` + +#### `forward(self, x)` +Define the forward pass for inference. Called by `predict_step` by default. + +```python +def forward(self, x): + return self.model(x) +``` + +#### `training_step(self, batch, batch_idx)` +Define the training loop logic. Return loss for automatic optimization. 
+ +```python +def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('train_loss', loss) + return loss +``` + +#### `validation_step(self, batch, batch_idx)` +Define the validation loop logic. Model automatically in eval mode with no gradients. + +```python +def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('val_loss', loss) + return loss +``` + +#### `test_step(self, batch, batch_idx)` +Define the test loop logic. Only runs when `trainer.test()` is called. + +```python +def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + return loss +``` + +#### `predict_step(self, batch, batch_idx, dataloader_idx=0)` +Define prediction logic for inference. Defaults to calling `forward()`. + +```python +def predict_step(self, batch, batch_idx, dataloader_idx=0): + x, y = batch + return self(x) +``` + +#### `configure_optimizers(self)` +Return optimizer(s) and optional learning rate scheduler(s). + +```python +def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + scheduler = ReduceLROnPlateau(optimizer, mode='min') + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + } + } +``` + +### Lifecycle Hooks + +#### Epoch-Level Hooks +- `on_train_epoch_start()` - Called at the start of each training epoch +- `on_train_epoch_end()` - Called at the end of each training epoch +- `on_validation_epoch_start()` - Called at the start of validation epoch +- `on_validation_epoch_end()` - Called at the end of validation epoch +- `on_test_epoch_start()` - Called at the start of test epoch +- `on_test_epoch_end()` - Called at the end of test epoch + +#### Batch-Level Hooks +- `on_train_batch_start(batch, batch_idx)` - Called before training batch +- `on_train_batch_end(outputs, batch, batch_idx)` - Called after training batch +- `on_validation_batch_start(batch, batch_idx)` - Called before validation batch +- `on_validation_batch_end(outputs, batch, batch_idx)` - Called after validation batch + +#### Training Lifecycle +- `on_fit_start()` - Called at the start of fit +- `on_fit_end()` - Called at the end of fit +- `on_train_start()` - Called at the start of training +- `on_train_end()` - Called at the end of training + +### Logging + +#### `self.log(name, value, **kwargs)` +Log a metric to all configured loggers. + +**Common Parameters:** +- `on_step` (bool) - Log at each batch step +- `on_epoch` (bool) - Log at the end of epoch (automatically aggregated) +- `prog_bar` (bool) - Display in progress bar +- `logger` (bool) - Send to logger +- `sync_dist` (bool) - Synchronize across all distributed processes +- `reduce_fx` (str) - Reduction function for distributed ("mean", "sum", etc.) + +```python +self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) +``` + +#### `self.log_dict(dictionary, **kwargs)` +Log multiple metrics at once. 
+ +```python +metrics = {'loss': loss, 'acc': acc, 'f1': f1} +self.log_dict(metrics, on_step=True, on_epoch=True) +``` + +### Device Management + +- `self.device` - Current device (automatically managed) +- `self.to(device)` - Move model to device (usually handled automatically) + +**Best Practice:** Create tensors on model's device: +```python +new_tensor = torch.zeros(10, device=self.device) +``` + +### Hyperparameter Management + +#### `self.save_hyperparameters(*args, **kwargs)` +Automatically save init arguments to `self.hparams` and checkpoints. + +```python +def __init__(self, learning_rate, hidden_dim): + super().__init__() + self.save_hyperparameters() # Saves all args + # Access via self.hparams.learning_rate, self.hparams.hidden_dim +``` + +--- + +## Trainer + +The Trainer automates the training loop and engineering complexity. + +### Core Parameters + +#### Training Duration +- `max_epochs` (int) - Maximum number of epochs (default: 1000) +- `min_epochs` (int) - Minimum number of epochs +- `max_steps` (int) - Maximum number of optimizer steps +- `min_steps` (int) - Minimum number of optimizer steps +- `max_time` (str/dict) - Maximum training time ("DD:HH:MM:SS" or dict) + +#### Hardware Configuration +- `accelerator` (str) - Hardware to use: "cpu", "gpu", "tpu", "auto" +- `devices` (int/list) - Number or specific device IDs: 1, 4, [0,2], "auto" +- `num_nodes` (int) - Number of GPU nodes for distributed training +- `strategy` (str) - Training strategy: "ddp", "fsdp", "deepspeed", etc. + +#### Data Management +- `limit_train_batches` (int/float) - Limit training batches (0.0-1.0 for %, int for count) +- `limit_val_batches` (int/float) - Limit validation batches +- `limit_test_batches` (int/float) - Limit test batches +- `limit_predict_batches` (int/float) - Limit prediction batches + +#### Validation +- `check_val_every_n_epoch` (int) - Run validation every N epochs +- `val_check_interval` (int/float) - Validate every N batches or fraction +- `num_sanity_val_steps` (int) - Validation steps before training (default: 2) + +#### Optimization +- `gradient_clip_val` (float) - Clip gradients by value +- `gradient_clip_algorithm` (str) - "value" or "norm" +- `accumulate_grad_batches` (int) - Accumulate gradients over K batches +- `precision` (str) - Training precision: "32-true", "16-mixed", "bf16-mixed", "64-true" + +#### Logging and Checkpointing +- `logger` (Logger/list) - Logger instance(s) or True/False +- `log_every_n_steps` (int) - Logging frequency +- `enable_checkpointing` (bool) - Enable automatic checkpointing +- `callbacks` (list) - List of callback instances +- `default_root_dir` (str) - Default path for logs and checkpoints + +#### Debugging +- `fast_dev_run` (bool/int) - Run N batches for quick testing +- `overfit_batches` (int/float) - Overfit on limited data for debugging +- `detect_anomaly` (bool) - Enable PyTorch anomaly detection +- `profiler` (str/Profiler) - Profile training: "simple", "advanced", or custom + +#### Performance +- `benchmark` (bool) - Enable cudnn.benchmark for performance +- `deterministic` (bool) - Enable deterministic training for reproducibility +- `sync_batchnorm` (bool) - Synchronize batch norm across GPUs + +### Training Methods + +#### `trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)` +Run the full training routine. 
+ +```python +trainer.fit(model, train_loader, val_loader) +# Or with DataModule +trainer.fit(model, datamodule=dm) +# Resume from checkpoint +trainer.fit(model, train_loader, val_loader, ckpt_path="path/to/checkpoint.ckpt") +``` + +#### `trainer.validate(model, dataloaders=None, datamodule=None, ckpt_path=None)` +Run validation independently. + +```python +trainer.validate(model, val_loader) +``` + +#### `trainer.test(model, dataloaders=None, datamodule=None, ckpt_path=None)` +Run test evaluation. + +```python +trainer.test(model, test_loader) +# Or load from checkpoint +trainer.test(ckpt_path="best_model.ckpt", datamodule=dm) +``` + +#### `trainer.predict(model, dataloaders=None, datamodule=None, ckpt_path=None)` +Run inference predictions. + +```python +predictions = trainer.predict(model, predict_loader) +``` + +--- + +## LightningDataModule + +Encapsulates all data processing logic in a reusable class. + +### Core Methods + +#### `prepare_data(self)` +Download and prepare data (called once on single process). +Do NOT set state here (no self.x = y). + +```python +def prepare_data(self): + # Download datasets + datasets.MNIST(self.data_dir, train=True, download=True) + datasets.MNIST(self.data_dir, train=False, download=True) +``` + +#### `setup(self, stage=None)` +Load data and create splits (called on every process/GPU). +Setting state is OK here. + +**stage parameter:** "fit", "validate", "test", or "predict" + +```python +def setup(self, stage=None): + if stage == "fit" or stage is None: + full_dataset = datasets.MNIST(self.data_dir, train=True) + self.train_dataset, self.val_dataset = random_split(full_dataset, [55000, 5000]) + + if stage == "test" or stage is None: + self.test_dataset = datasets.MNIST(self.data_dir, train=False) +``` + +#### DataLoader Methods +- `train_dataloader(self)` - Return training DataLoader +- `val_dataloader(self)` - Return validation DataLoader +- `test_dataloader(self)` - Return test DataLoader +- `predict_dataloader(self)` - Return prediction DataLoader + +```python +def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=32, shuffle=True) +``` + +### Optional Methods +- `teardown(stage=None)` - Cleanup after training/testing +- `state_dict()` - Save state for checkpointing +- `load_state_dict(state_dict)` - Load state from checkpoint + +--- + +## Callbacks + +Extend training with modular, reusable functionality. + +### Built-in Callbacks + +#### ModelCheckpoint +Save model checkpoints based on monitored metrics. + +```python +from lightning.pytorch.callbacks import ModelCheckpoint + +checkpoint_callback = ModelCheckpoint( + dirpath='checkpoints/', + filename='{epoch}-{val_loss:.2f}', + monitor='val_loss', + mode='min', + save_top_k=3, + save_last=True, + verbose=True, +) +``` + +**Key Parameters:** +- `monitor` - Metric to monitor +- `mode` - "min" or "max" +- `save_top_k` - Save top K models +- `save_last` - Always save last checkpoint +- `every_n_epochs` - Save every N epochs + +#### EarlyStopping +Stop training when metric stops improving. + +```python +from lightning.pytorch.callbacks import EarlyStopping + +early_stop = EarlyStopping( + monitor='val_loss', + patience=10, + mode='min', + verbose=True, +) +``` + +#### LearningRateMonitor +Log learning rate values. + +```python +from lightning.pytorch.callbacks import LearningRateMonitor + +lr_monitor = LearningRateMonitor(logging_interval='epoch') +``` + +#### RichProgressBar +Display rich progress bar with metrics. 
+ +```python +from lightning.pytorch.callbacks import RichProgressBar + +progress_bar = RichProgressBar() +``` + +### Custom Callbacks + +Create custom callbacks by inheriting from `Callback`. + +```python +from lightning.pytorch.callbacks import Callback + +class MyCallback(Callback): + def on_train_start(self, trainer, pl_module): + print("Training starting!") + + def on_train_epoch_end(self, trainer, pl_module): + print(f"Epoch {trainer.current_epoch} ended") + + def on_validation_end(self, trainer, pl_module): + val_loss = trainer.callback_metrics.get('val_loss') + print(f"Validation loss: {val_loss}") +``` + +**Common Hooks:** +- `on_train_start/end` +- `on_train_epoch_start/end` +- `on_validation_epoch_start/end` +- `on_test_epoch_start/end` +- `on_before_backward/on_after_backward` +- `on_before_optimizer_step` + +--- + +## Loggers + +Track experiments with various logging frameworks. + +### TensorBoardLogger +```python +from lightning.pytorch.loggers import TensorBoardLogger + +logger = TensorBoardLogger(save_dir='logs/', name='my_experiment') +trainer = Trainer(logger=logger) +``` + +### WandbLogger +```python +from lightning.pytorch.loggers import WandbLogger + +logger = WandbLogger(project='my_project', name='experiment_1') +trainer = Trainer(logger=logger) +``` + +### MLFlowLogger +```python +from lightning.pytorch.loggers import MLFlowLogger + +logger = MLFlowLogger(experiment_name='my_exp', tracking_uri='file:./ml-runs') +trainer = Trainer(logger=logger) +``` + +### CSVLogger +```python +from lightning.pytorch.loggers import CSVLogger + +logger = CSVLogger(save_dir='logs/', name='my_experiment') +trainer = Trainer(logger=logger) +``` + +### Multiple Loggers +```python +loggers = [ + TensorBoardLogger('logs/'), + CSVLogger('logs/'), +] +trainer = Trainer(logger=loggers) +``` + +--- + +## Common Patterns + +### Reproducibility +```python +from lightning.pytorch import seed_everything + +seed_everything(42, workers=True) +trainer = Trainer(deterministic=True) +``` + +### Mixed Precision Training +```python +trainer = Trainer(precision='16-mixed') # or 'bf16-mixed' +``` + +### Multi-GPU Training +```python +# Data parallel (DDP) +trainer = Trainer(accelerator='gpu', devices=4, strategy='ddp') + +# Model parallel (FSDP) +trainer = Trainer(accelerator='gpu', devices=4, strategy='fsdp') +``` + +### Gradient Accumulation +```python +trainer = Trainer(accumulate_grad_batches=4) # Effective batch size = 4x +``` + +### Learning Rate Finding +```python +from lightning.pytorch.tuner import Tuner + +trainer = Trainer() +tuner = Tuner(trainer) +lr_finder = tuner.lr_find(model, train_dataloader) +model.hparams.learning_rate = lr_finder.suggestion() +``` + +### Loading from Checkpoint +```python +# Load model +model = MyLightningModule.load_from_checkpoint('checkpoint.ckpt') + +# Resume training +trainer.fit(model, ckpt_path='checkpoint.ckpt') +``` diff --git a/scientific-packages/pytorch-lightning/references/distributed_training.md b/scientific-packages/pytorch-lightning/references/distributed_training.md new file mode 100644 index 0000000..ee159c8 --- /dev/null +++ b/scientific-packages/pytorch-lightning/references/distributed_training.md @@ -0,0 +1,508 @@ +# Distributed and Model Parallel Training + +Comprehensive guide for distributed training strategies in PyTorch Lightning. + +## Overview + +PyTorch Lightning provides seamless distributed training across multiple GPUs, machines, and TPUs with minimal code changes. 
The framework automatically handles the complexity of distributed training while keeping code device-agnostic and readable. + +## Training Strategies + +### Data Parallel (DDP - DistributedDataParallel) + +**Best for:** Most models (< 500M parameters) where the full model fits in GPU memory. + +**How it works:** Each GPU holds a complete copy of the model and trains on a different batch subset. Gradients are synchronized across GPUs during backward pass. + +```python +# Single-node, multi-GPU +trainer = Trainer( + accelerator='gpu', + devices=4, # Use 4 GPUs + strategy='ddp', +) + +# Multi-node, multi-GPU +trainer = Trainer( + accelerator='gpu', + devices=4, # GPUs per node + num_nodes=2, # Number of nodes + strategy='ddp', +) +``` + +**Advantages:** +- Most widely used and tested +- Works with most PyTorch code +- Good scaling efficiency +- No code changes required in LightningModule + +**When to use:** Default choice for most distributed training scenarios. + +### FSDP (Fully Sharded Data Parallel) + +**Best for:** Large models (500M+ parameters) that don't fit in single GPU memory. + +**How it works:** Shards model parameters, gradients, and optimizer states across GPUs. Each GPU only stores a subset of the model. + +```python +trainer = Trainer( + accelerator='gpu', + devices=4, + strategy='fsdp', +) + +# With configuration +from lightning.pytorch.strategies import FSDPStrategy + +strategy = FSDPStrategy( + sharding_strategy="FULL_SHARD", # Full sharding + cpu_offload=False, # Offload to CPU + mixed_precision=torch.float16, +) + +trainer = Trainer( + accelerator='gpu', + devices=4, + strategy=strategy, +) +``` + +**Sharding Strategies:** +- `FULL_SHARD` - Shard parameters, gradients, and optimizer states +- `SHARD_GRAD_OP` - Shard only gradients and optimizer states +- `NO_SHARD` - DDP-like (no sharding) +- `HYBRID_SHARD` - Shard within node, DDP across nodes + +**Advanced FSDP Configuration:** +```python +from lightning.pytorch.strategies import FSDPStrategy + +strategy = FSDPStrategy( + sharding_strategy="FULL_SHARD", + activation_checkpointing=True, # Save memory + cpu_offload=True, # Offload parameters to CPU + backward_prefetch="BACKWARD_PRE", # Prefetch strategy + forward_prefetch=True, + limit_all_gathers=True, +) +``` + +**When to use:** +- Models > 500M parameters +- Limited GPU memory +- Native PyTorch solution preferred +- Migrating from standalone PyTorch FSDP + +### DeepSpeed + +**Best for:** Cutting-edge features, massive models, or existing DeepSpeed users. + +**How it works:** Comprehensive optimization library with multiple stages of memory and compute optimization. 
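+
+DeepSpeed is an optional dependency and is installed separately from Lightning (for example with `pip install deepspeed`). For the common ZeRO configurations, Lightning also registers string shortcuts in its strategy registry, so a stage can be selected without building a strategy object. The alias names below follow the Lightning documentation and may vary slightly between releases:
+
+```python
+from lightning.pytorch import Trainer
+
+trainer = Trainer(
+    accelerator="gpu",
+    devices=4,
+    precision="16-mixed",
+    strategy="deepspeed_stage_2",      # shard optimizer states + gradients
+)
+
+# Other registered aliases include:
+#   "deepspeed_stage_1"          - shard optimizer states only
+#   "deepspeed_stage_2_offload"  - stage 2 with CPU offloading
+#   "deepspeed_stage_3"          - additionally shard parameters
+#   "deepspeed_stage_3_offload"  - stage 3 with CPU offloading
+```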
+ +```python +# Basic DeepSpeed +trainer = Trainer( + accelerator='gpu', + devices=4, + strategy='deepspeed', + precision='16-mixed', +) + +# With configuration +from lightning.pytorch.strategies import DeepSpeedStrategy + +strategy = DeepSpeedStrategy( + stage=2, # ZeRO Stage (1, 2, or 3) + offload_optimizer=True, + offload_parameters=True, +) + +trainer = Trainer( + accelerator='gpu', + devices=4, + strategy=strategy, +) +``` + +**ZeRO Stages:** +- **Stage 1:** Shard optimizer states +- **Stage 2:** Shard optimizer states + gradients +- **Stage 3:** Shard optimizer states + gradients + parameters (like FSDP) + +**With DeepSpeed Config File:** +```python +strategy = DeepSpeedStrategy(config="deepspeed_config.json") +``` + +Example `deepspeed_config.json`: +```json +{ + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "allgather_bucket_size": 2e8, + "reduce_bucket_size": 2e8 + }, + "activation_checkpointing": { + "partition_activations": true, + "cpu_checkpointing": true + }, + "fp16": { + "enabled": true + }, + "gradient_clipping": 1.0 +} +``` + +**When to use:** +- Need specific DeepSpeed features +- Maximum memory efficiency required +- Already familiar with DeepSpeed +- Training extremely large models + +### DDP Spawn + +**Note:** Generally avoid using `ddp_spawn`. Use `ddp` instead. + +```python +trainer = Trainer(strategy='ddp_spawn') # Not recommended +``` + +**Issues with ddp_spawn:** +- Cannot return values from `.fit()` +- Pickling issues with unpicklable objects +- Slower than `ddp` +- More memory overhead + +**When to use:** Only for debugging or if `ddp` doesn't work on your system. + +## Multi-Node Training + +### Basic Multi-Node Setup + +```python +# On each node, run the same command +trainer = Trainer( + accelerator='gpu', + devices=4, # GPUs per node + num_nodes=8, # Total number of nodes + strategy='ddp', +) +``` + +### SLURM Cluster + +Lightning automatically detects SLURM environment: + +```python +trainer = Trainer( + accelerator='gpu', + devices=4, + num_nodes=8, + strategy='ddp', +) +``` + +**SLURM Submit Script:** +```bash +#!/bin/bash +#SBATCH --nodes=8 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --job-name=lightning_training + +python train.py +``` + +### Manual Cluster Setup + +```python +from lightning.pytorch.strategies import DDPStrategy + +strategy = DDPStrategy( + cluster_environment='TorchElastic', # or 'SLURM', 'LSF', 'Kubeflow' +) + +trainer = Trainer( + accelerator='gpu', + devices=4, + num_nodes=8, + strategy=strategy, +) +``` + +## Memory Optimization Techniques + +### Gradient Accumulation + +Simulate larger batch sizes without increasing memory: + +```python +trainer = Trainer( + accumulate_grad_batches=4, # Accumulate 4 batches before optimizer step +) + +# Variable accumulation by epoch +trainer = Trainer( + accumulate_grad_batches={ + 0: 8, # Epochs 0-4: accumulate 8 batches + 5: 4, # Epochs 5+: accumulate 4 batches + } +) +``` + +### Activation Checkpointing + +Trade computation for memory by recomputing activations during backward pass: + +```python +# FSDP +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) + +class MyModule(L.LightningModule): + def configure_model(self): + # Wrap specific layers for activation checkpointing + self.model = MyTransformer() + apply_activation_checkpointing( + self.model, + checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(m, 
CheckpointImpl.NO_REENTRANT), + check_fn=lambda m: isinstance(m, TransformerBlock), + ) +``` + +### Mixed Precision Training + +Reduce memory usage and increase speed with mixed precision: + +```python +# 16-bit mixed precision +trainer = Trainer(precision='16-mixed') + +# BFloat16 mixed precision (more stable, requires newer GPUs) +trainer = Trainer(precision='bf16-mixed') +``` + +### CPU Offloading + +Offload parameters or optimizer states to CPU: + +```python +# FSDP with CPU offload +from lightning.pytorch.strategies import FSDPStrategy + +strategy = FSDPStrategy( + cpu_offload=True, # Offload parameters to CPU +) + +# DeepSpeed with CPU offload +from lightning.pytorch.strategies import DeepSpeedStrategy + +strategy = DeepSpeedStrategy( + stage=3, + offload_optimizer=True, + offload_parameters=True, +) +``` + +## Performance Optimization + +### Synchronize Batch Normalization + +Synchronize batch norm statistics across GPUs: + +```python +trainer = Trainer( + accelerator='gpu', + devices=4, + strategy='ddp', + sync_batchnorm=True, # Sync batch norm across GPUs +) +``` + +### Find Optimal Batch Size + +```python +from lightning.pytorch.tuner import Tuner + +trainer = Trainer() +tuner = Tuner(trainer) + +# Auto-scale batch size +tuner.scale_batch_size(model, mode="power") # or "binsearch" +``` + +### Gradient Clipping + +Prevent gradient explosion in distributed training: + +```python +trainer = Trainer( + gradient_clip_val=1.0, + gradient_clip_algorithm='norm', # or 'value' +) +``` + +### Benchmark Mode + +Enable cudnn.benchmark for consistent input sizes: + +```python +trainer = Trainer( + benchmark=True, # Optimize for consistent input sizes +) +``` + +## Distributed Data Loading + +### Automatic Distributed Sampling + +Lightning automatically handles distributed sampling: + +```python +# No changes needed - Lightning handles this automatically +def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=32, + shuffle=True, # Lightning converts to DistributedSampler + ) +``` + +### Manual Control + +```python +# Disable automatic distributed sampler +trainer = Trainer( + use_distributed_sampler=False, +) + +# Manual distributed sampler +from torch.utils.data.distributed import DistributedSampler + +def train_dataloader(self): + sampler = DistributedSampler(self.train_dataset) + return DataLoader( + self.train_dataset, + batch_size=32, + sampler=sampler, + ) +``` + +### Data Loading Best Practices + +```python +def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=32, + num_workers=4, # Use multiple workers + pin_memory=True, # Faster CPU-GPU transfer + persistent_workers=True, # Keep workers alive between epochs + ) +``` + +## Common Patterns + +### Logging in Distributed Training + +```python +def training_step(self, batch, batch_idx): + loss = self.compute_loss(batch) + + # Automatically syncs across processes + self.log('train_loss', loss, sync_dist=True) + + return loss +``` + +### Rank-Specific Operations + +```python +def training_step(self, batch, batch_idx): + # Run only on rank 0 (main process) + if self.trainer.is_global_zero: + print("This only prints once across all processes") + + # Get current rank + rank = self.trainer.global_rank + world_size = self.trainer.world_size + + return loss +``` + +### Barrier Synchronization + +```python +def on_train_epoch_end(self): + # Wait for all processes + self.trainer.strategy.barrier() + + # Now all processes are synchronized + if self.trainer.is_global_zero: + # Save something 
only once + self.save_artifacts() +``` + +## Troubleshooting + +### Common Issues + +**1. Out of Memory:** +- Reduce batch size +- Enable gradient accumulation +- Use FSDP or DeepSpeed +- Enable activation checkpointing +- Use mixed precision + +**2. Slow Training:** +- Check data loading (use `num_workers > 0`) +- Enable `pin_memory=True` and `persistent_workers=True` +- Use `benchmark=True` for consistent input sizes +- Profile with `profiler='simple'` + +**3. Hanging:** +- Ensure all processes execute same collectives +- Check for `if` statements that differ across ranks +- Use barrier synchronization when needed + +**4. Inconsistent Results:** +- Set `deterministic=True` +- Use `seed_everything()` +- Ensure proper gradient synchronization + +### Debugging Distributed Training + +```python +# Test with single GPU first +trainer = Trainer(accelerator='gpu', devices=1) + +# Then test with 2 GPUs +trainer = Trainer(accelerator='gpu', devices=2, strategy='ddp') + +# Use fast_dev_run for quick testing +trainer = Trainer( + accelerator='gpu', + devices=2, + strategy='ddp', + fast_dev_run=10, # Run 10 batches only +) +``` + +## Strategy Selection Guide + +| Model Size | Available Memory | Recommended Strategy | +|-----------|------------------|---------------------| +| < 500M params | Fits in 1 GPU | Single GPU | +| < 500M params | Fits across GPUs | DDP | +| 500M - 3B params | Limited memory | FSDP or DeepSpeed Stage 2 | +| 3B+ params | Very limited memory | FSDP or DeepSpeed Stage 3 | +| Any size | Maximum efficiency | DeepSpeed with offloading | +| Multiple nodes | Any | DDP (< 500M) or FSDP/DeepSpeed (> 500M) | diff --git a/scientific-packages/pytorch-lightning/scripts/quick_trainer_setup.py b/scientific-packages/pytorch-lightning/scripts/quick_trainer_setup.py new file mode 100644 index 0000000..b07ae84 --- /dev/null +++ b/scientific-packages/pytorch-lightning/scripts/quick_trainer_setup.py @@ -0,0 +1,262 @@ +""" +Helper script to quickly set up a PyTorch Lightning Trainer with common configurations. + +This script provides preset configurations for different training scenarios +and makes it easy to create a Trainer with best practices. +""" + +import lightning as L +from lightning.pytorch.callbacks import ( + ModelCheckpoint, + EarlyStopping, + LearningRateMonitor, + RichProgressBar, + ModelSummary, +) +from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger + + +def create_trainer( + preset: str = "default", + max_epochs: int = 100, + accelerator: str = "auto", + devices: int = 1, + log_dir: str = "./logs", + experiment_name: str = "lightning_experiment", + enable_checkpointing: bool = True, + enable_early_stopping: bool = True, + **kwargs +): + """ + Create a Lightning Trainer with preset configurations. 
+ + Args: + preset: Configuration preset - "default", "fast_dev", "production", "distributed" + max_epochs: Maximum number of training epochs + accelerator: Device to use ("auto", "gpu", "cpu", "tpu") + devices: Number of devices to use + log_dir: Directory for logs and checkpoints + experiment_name: Name for the experiment + enable_checkpointing: Whether to enable model checkpointing + enable_early_stopping: Whether to enable early stopping + **kwargs: Additional arguments to pass to Trainer + + Returns: + Configured Lightning Trainer instance + """ + + callbacks = [] + logger_list = [] + + # Configure based on preset + if preset == "fast_dev": + # Fast development run - minimal epochs, quick debugging + config = { + "fast_dev_run": False, + "max_epochs": 3, + "limit_train_batches": 100, + "limit_val_batches": 50, + "log_every_n_steps": 10, + "enable_progress_bar": True, + "enable_model_summary": True, + } + + elif preset == "production": + # Production-ready configuration with all bells and whistles + config = { + "max_epochs": max_epochs, + "precision": "16-mixed", + "gradient_clip_val": 1.0, + "log_every_n_steps": 50, + "enable_progress_bar": True, + "enable_model_summary": True, + "deterministic": True, + "benchmark": True, + } + + # Add model checkpointing + if enable_checkpointing: + callbacks.append( + ModelCheckpoint( + dirpath=f"{log_dir}/{experiment_name}/checkpoints", + filename="{epoch}-{val_loss:.2f}", + monitor="val_loss", + mode="min", + save_top_k=3, + save_last=True, + verbose=True, + ) + ) + + # Add early stopping + if enable_early_stopping: + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=10, + mode="min", + verbose=True, + ) + ) + + # Add learning rate monitor + callbacks.append(LearningRateMonitor(logging_interval="epoch")) + + # Add TensorBoard logger + logger_list.append( + TensorBoardLogger( + save_dir=log_dir, + name=experiment_name, + version=None, + ) + ) + + elif preset == "distributed": + # Distributed training configuration + config = { + "max_epochs": max_epochs, + "strategy": "ddp", + "precision": "16-mixed", + "sync_batchnorm": True, + "use_distributed_sampler": True, + "log_every_n_steps": 50, + "enable_progress_bar": True, + } + + # Add model checkpointing + if enable_checkpointing: + callbacks.append( + ModelCheckpoint( + dirpath=f"{log_dir}/{experiment_name}/checkpoints", + filename="{epoch}-{val_loss:.2f}", + monitor="val_loss", + mode="min", + save_top_k=3, + save_last=True, + ) + ) + + else: # default + # Default configuration - balanced for most use cases + config = { + "max_epochs": max_epochs, + "log_every_n_steps": 50, + "enable_progress_bar": True, + "enable_model_summary": True, + } + + # Add basic checkpointing + if enable_checkpointing: + callbacks.append( + ModelCheckpoint( + dirpath=f"{log_dir}/{experiment_name}/checkpoints", + filename="{epoch}-{val_loss:.2f}", + monitor="val_loss", + save_last=True, + ) + ) + + # Add CSV logger + logger_list.append( + CSVLogger( + save_dir=log_dir, + name=experiment_name, + ) + ) + + # Add progress bar + if config.get("enable_progress_bar", True): + callbacks.append(RichProgressBar()) + + # Merge with provided kwargs + final_config = { + **config, + "accelerator": accelerator, + "devices": devices, + "callbacks": callbacks, + "logger": logger_list if logger_list else True, + **kwargs, + } + + # Create and return trainer + return L.Trainer(**final_config) + + +def create_debugging_trainer(): + """Create a trainer optimized for debugging.""" + return create_trainer( + 
preset="fast_dev", + max_epochs=1, + limit_train_batches=10, + limit_val_batches=5, + num_sanity_val_steps=2, + ) + + +def create_gpu_trainer(num_gpus: int = 1, precision: str = "16-mixed"): + """Create a trainer optimized for GPU training.""" + return create_trainer( + preset="production", + accelerator="gpu", + devices=num_gpus, + precision=precision, + ) + + +def create_distributed_trainer(num_gpus: int = 2, num_nodes: int = 1): + """Create a trainer for distributed training across multiple GPUs.""" + return create_trainer( + preset="distributed", + accelerator="gpu", + devices=num_gpus, + num_nodes=num_nodes, + strategy="ddp", + ) + + +# Example usage +if __name__ == "__main__": + print("Creating different trainer configurations...\n") + + # 1. Default trainer + print("1. Default trainer:") + trainer_default = create_trainer(preset="default", max_epochs=50) + print(f" Max epochs: {trainer_default.max_epochs}") + print(f" Accelerator: {trainer_default.accelerator}") + print(f" Callbacks: {len(trainer_default.callbacks)}") + print() + + # 2. Fast development trainer + print("2. Fast development trainer:") + trainer_dev = create_trainer(preset="fast_dev") + print(f" Max epochs: {trainer_dev.max_epochs}") + print(f" Train batches limit: {trainer_dev.limit_train_batches}") + print() + + # 3. Production trainer + print("3. Production trainer:") + trainer_prod = create_trainer( + preset="production", + max_epochs=100, + experiment_name="my_experiment" + ) + print(f" Max epochs: {trainer_prod.max_epochs}") + print(f" Precision: {trainer_prod.precision}") + print(f" Callbacks: {len(trainer_prod.callbacks)}") + print() + + # 4. Debugging trainer + print("4. Debugging trainer:") + trainer_debug = create_debugging_trainer() + print(f" Max epochs: {trainer_debug.max_epochs}") + print(f" Train batches: {trainer_debug.limit_train_batches}") + print() + + # 5. GPU trainer + print("5. GPU trainer:") + trainer_gpu = create_gpu_trainer(num_gpus=1) + print(f" Accelerator: {trainer_gpu.accelerator}") + print(f" Precision: {trainer_gpu.precision}") + print() + + print("All trainer configurations created successfully!") diff --git a/scientific-packages/pytorch-lightning/scripts/template_datamodule.py b/scientific-packages/pytorch-lightning/scripts/template_datamodule.py new file mode 100644 index 0000000..4b33027 --- /dev/null +++ b/scientific-packages/pytorch-lightning/scripts/template_datamodule.py @@ -0,0 +1,221 @@ +""" +Template for creating a PyTorch Lightning DataModule. + +This template includes all common hooks and patterns for organizing +data processing workflows with best practices. 
+""" + +import lightning as L +from torch.utils.data import DataLoader, Dataset, random_split +import torch + + +class TemplateDataset(Dataset): + """Example dataset - replace with your actual dataset.""" + + def __init__(self, data, targets, transform=None): + self.data = data + self.targets = targets + self.transform = transform + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + x = self.data[idx] + y = self.targets[idx] + + if self.transform: + x = self.transform(x) + + return x, y + + +class TemplateDataModule(L.LightningDataModule): + """Template DataModule with all common hooks and patterns.""" + + def __init__( + self, + data_dir: str = "./data", + batch_size: int = 32, + num_workers: int = 4, + train_val_split: tuple = (0.8, 0.2), + seed: int = 42, + pin_memory: bool = True, + persistent_workers: bool = True, + ): + super().__init__() + + # Save hyperparameters + self.save_hyperparameters() + + # Initialize attributes + self.data_dir = data_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_split = train_val_split + self.seed = seed + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + # Placeholders for datasets + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + # Placeholder for transforms + self.train_transform = None + self.val_transform = None + self.test_transform = None + + def prepare_data(self): + """ + Download and prepare data (called only on 1 GPU/TPU in distributed settings). + Use this for downloading, tokenizing, etc. Do NOT set state here (no self.x = y). + """ + # Example: Download datasets + # datasets.MNIST(self.data_dir, train=True, download=True) + # datasets.MNIST(self.data_dir, train=False, download=True) + pass + + def setup(self, stage: str = None): + """ + Load data and create train/val/test splits (called on every GPU/TPU in distributed). + Use this for splitting, creating datasets, etc. Setting state is OK here (self.x = y). + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict' + """ + + # Fit stage: setup training and validation datasets + if stage == "fit" or stage is None: + # Load full dataset + # Example: full_dataset = datasets.MNIST(self.data_dir, train=True, transform=self.train_transform) + + # Create dummy data for template + full_data = torch.randn(1000, 784) + full_targets = torch.randint(0, 10, (1000,)) + full_dataset = TemplateDataset(full_data, full_targets, transform=self.train_transform) + + # Split into train and validation + train_size = int(len(full_dataset) * self.train_val_split[0]) + val_size = len(full_dataset) - train_size + + self.train_dataset, self.val_dataset = random_split( + full_dataset, + [train_size, val_size], + generator=torch.Generator().manual_seed(self.seed) + ) + + # Apply validation transform if different from train + if self.val_transform: + self.val_dataset.dataset.transform = self.val_transform + + # Test stage: setup test dataset + if stage == "test" or stage is None: + # Example: self.test_dataset = datasets.MNIST( + # self.data_dir, train=False, transform=self.test_transform + # ) + + # Create dummy test data for template + test_data = torch.randn(200, 784) + test_targets = torch.randint(0, 10, (200,)) + self.test_dataset = TemplateDataset(test_data, test_targets, transform=self.test_transform) + + # Predict stage: setup prediction dataset + if stage == "predict" or stage is None: + # Example: self.predict_dataset = YourCustomDataset(...) 
+ + # Create dummy predict data for template + predict_data = torch.randn(100, 784) + predict_targets = torch.zeros(100, dtype=torch.long) + self.predict_dataset = TemplateDataset(predict_data, predict_targets) + + def train_dataloader(self): + """Return training dataloader.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + ) + + def val_dataloader(self): + """Return validation dataloader.""" + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + ) + + def test_dataloader(self): + """Return test dataloader.""" + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + ) + + def predict_dataloader(self): + """Return prediction dataloader.""" + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers if self.num_workers > 0 else False, + ) + + def teardown(self, stage: str = None): + """Clean up after fit, validate, test, or predict.""" + # Example: close database connections, clear caches, etc. + pass + + def state_dict(self): + """Save state for checkpointing.""" + # Return anything you want to save in the checkpoint + return {} + + def load_state_dict(self, state_dict): + """Load state from checkpoint.""" + # Restore state from checkpoint + pass + + +# Example usage +if __name__ == "__main__": + # Create datamodule + datamodule = TemplateDataModule( + data_dir="./data", + batch_size=32, + num_workers=4, + train_val_split=(0.8, 0.2), + ) + + # Prepare and setup data + datamodule.prepare_data() + datamodule.setup("fit") + + # Get dataloaders + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + + print("Template DataModule created successfully!") + print(f"Train batches: {len(train_loader)}") + print(f"Val batches: {len(val_loader)}") + print(f"Batch size: {datamodule.batch_size}") + + # Test a batch + batch = next(iter(train_loader)) + x, y = batch + print(f"Batch shape: {x.shape}, {y.shape}") diff --git a/scientific-packages/pytorch-lightning/scripts/template_lightning_module.py b/scientific-packages/pytorch-lightning/scripts/template_lightning_module.py new file mode 100644 index 0000000..4bea776 --- /dev/null +++ b/scientific-packages/pytorch-lightning/scripts/template_lightning_module.py @@ -0,0 +1,215 @@ +""" +Template for creating a PyTorch Lightning LightningModule. + +This template includes all common hooks and patterns for building +a Lightning model with best practices. 
+""" + +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import Adam, SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + + +class TemplateLightningModule(L.LightningModule): + """Template LightningModule with all common hooks and patterns.""" + + def __init__( + self, + # Model architecture parameters + input_dim: int = 784, + hidden_dim: int = 128, + output_dim: int = 10, + # Optimization parameters + learning_rate: float = 1e-3, + optimizer_type: str = "adam", + scheduler_type: str = None, + # Other hyperparameters + dropout: float = 0.1, + ): + super().__init__() + + # Save hyperparameters for checkpointing and logging + self.save_hyperparameters() + + # Define model architecture + self.model = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, output_dim) + ) + + # Define loss function + self.criterion = nn.CrossEntropyLoss() + + # For tracking validation outputs (optional) + self.validation_step_outputs = [] + + def forward(self, x): + """Forward pass for inference.""" + return self.model(x) + + def training_step(self, batch, batch_idx): + """Training step - called for each training batch.""" + x, y = batch + + # Forward pass + logits = self(x) + loss = self.criterion(logits, y) + + # Calculate accuracy + preds = torch.argmax(logits, dim=1) + acc = (preds == y).float().mean() + + # Log metrics + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def validation_step(self, batch, batch_idx): + """Validation step - called for each validation batch.""" + x, y = batch + + # Forward pass (model automatically in eval mode) + logits = self(x) + loss = self.criterion(logits, y) + + # Calculate accuracy + preds = torch.argmax(logits, dim=1) + acc = (preds == y).float().mean() + + # Log metrics + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True) + + # Optional: store outputs for epoch-level processing + self.validation_step_outputs.append({"loss": loss, "acc": acc}) + + return loss + + def on_validation_epoch_end(self): + """Called at the end of validation epoch.""" + # Optional: process all validation outputs + if self.validation_step_outputs: + avg_loss = torch.stack([x["loss"] for x in self.validation_step_outputs]).mean() + avg_acc = torch.stack([x["acc"] for x in self.validation_step_outputs]).mean() + + # Log epoch-level metrics if needed + # self.log("val_epoch_loss", avg_loss) + # self.log("val_epoch_acc", avg_acc) + + # Clear outputs + self.validation_step_outputs.clear() + + def test_step(self, batch, batch_idx): + """Test step - called for each test batch.""" + x, y = batch + + # Forward pass + logits = self(x) + loss = self.criterion(logits, y) + + # Calculate accuracy + preds = torch.argmax(logits, dim=1) + acc = (preds == y).float().mean() + + # Log metrics + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.log("test_acc", acc, on_step=False, on_epoch=True) + + return loss + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """Prediction step - called for each prediction batch.""" + x, y = batch + logits = self(x) + preds = torch.argmax(logits, dim=1) + return preds + + def configure_optimizers(self): + """Configure optimizer and learning 
rate scheduler.""" + # Create optimizer + if self.hparams.optimizer_type.lower() == "adam": + optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate) + elif self.hparams.optimizer_type.lower() == "sgd": + optimizer = SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=0.9) + else: + raise ValueError(f"Unknown optimizer: {self.hparams.optimizer_type}") + + # Configure with scheduler if specified + if self.hparams.scheduler_type: + if self.hparams.scheduler_type.lower() == "reduce_on_plateau": + scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + } + } + elif self.hparams.scheduler_type.lower() == "step": + scheduler = StepLR(optimizer, step_size=10, gamma=0.1) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + "frequency": 1, + } + } + + return optimizer + + # Optional: Additional hooks for custom behavior + + def on_train_start(self): + """Called at the beginning of training.""" + pass + + def on_train_epoch_start(self): + """Called at the beginning of each training epoch.""" + pass + + def on_train_epoch_end(self): + """Called at the end of each training epoch.""" + pass + + def on_train_end(self): + """Called at the end of training.""" + pass + + +# Example usage +if __name__ == "__main__": + # Create model + model = TemplateLightningModule( + input_dim=784, + hidden_dim=128, + output_dim=10, + learning_rate=1e-3, + optimizer_type="adam", + scheduler_type="reduce_on_plateau" + ) + + # Create trainer + trainer = L.Trainer( + max_epochs=10, + accelerator="auto", + devices=1, + log_every_n_steps=50, + ) + + # Note: You would need to provide dataloaders + # trainer.fit(model, train_dataloader, val_dataloader) + + print("Template LightningModule created successfully!") + print(f"Model hyperparameters: {model.hparams}") diff --git a/scientific-packages/rdkit/SKILL.md b/scientific-packages/rdkit/SKILL.md new file mode 100644 index 0000000..7afcff9 --- /dev/null +++ b/scientific-packages/rdkit/SKILL.md @@ -0,0 +1,763 @@ +--- +name: rdkit +description: Comprehensive cheminformatics toolkit for molecular manipulation, analysis, and visualization. Use this skill when working with chemical structures (SMILES, MOL files, SDF), calculating molecular descriptors, performing substructure searches, generating fingerprints, visualizing molecules, processing chemical reactions, or conducting drug discovery workflows. +--- + +# RDKit Cheminformatics Toolkit + +## Overview + +RDKit is a comprehensive cheminformatics library providing Python APIs for molecular analysis and manipulation. This skill provides guidance for reading/writing molecular structures, calculating descriptors, fingerprinting, substructure searching, chemical reactions, 2D/3D coordinate generation, and molecular visualization. Use this skill for drug discovery, computational chemistry, and cheminformatics research tasks. + +## Core Capabilities + +### 1. 
Molecular I/O and Creation + +**Reading Molecules:** + +Read molecular structures from various formats: + +```python +from rdkit import Chem + +# From SMILES strings +mol = Chem.MolFromSmiles('Cc1ccccc1') # Returns Mol object or None + +# From MOL files +mol = Chem.MolFromMolFile('path/to/file.mol') + +# From MOL blocks (string data) +mol = Chem.MolFromMolBlock(mol_block_string) + +# From InChI +mol = Chem.MolFromInchi('InChI=1S/C6H6/c1-2-4-6-5-3-1/h1-6H') +``` + +**Writing Molecules:** + +Convert molecules to text representations: + +```python +# To canonical SMILES +smiles = Chem.MolToSmiles(mol) + +# To MOL block +mol_block = Chem.MolToMolBlock(mol) + +# To InChI +inchi = Chem.MolToInchi(mol) +``` + +**Batch Processing:** + +For processing multiple molecules, use Supplier/Writer objects: + +```python +# Read SDF files +suppl = Chem.SDMolSupplier('molecules.sdf') +for mol in suppl: + if mol is not None: # Check for parsing errors + # Process molecule + pass + +# Read SMILES files +suppl = Chem.SmilesMolSupplier('molecules.smi', titleLine=False) + +# For large files or compressed data +with gzip.open('molecules.sdf.gz') as f: + suppl = Chem.ForwardSDMolSupplier(f) + for mol in suppl: + # Process molecule + pass + +# Multithreaded processing for large datasets +suppl = Chem.MultithreadedSDMolSupplier('molecules.sdf') + +# Write molecules to SDF +writer = Chem.SDWriter('output.sdf') +for mol in molecules: + writer.write(mol) +writer.close() +``` + +**Important Notes:** +- All `MolFrom*` functions return `None` on failure with error messages +- Always check for `None` before processing molecules +- Molecules are automatically sanitized on import (validates valence, perceives aromaticity) + +### 2. Molecular Sanitization and Validation + +RDKit automatically sanitizes molecules during parsing, executing 13 steps including valence checking, aromaticity perception, and chirality assignment. + +**Sanitization Control:** + +```python +# Disable automatic sanitization +mol = Chem.MolFromSmiles('C1=CC=CC=C1', sanitize=False) + +# Manual sanitization +Chem.SanitizeMol(mol) + +# Detect problems before sanitization +problems = Chem.DetectChemistryProblems(mol) +for problem in problems: + print(problem.GetType(), problem.Message()) + +# Partial sanitization (skip specific steps) +from rdkit.Chem import rdMolStandardize +Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ALL ^ Chem.SANITIZE_PROPERTIES) +``` + +**Common Sanitization Issues:** +- Atoms with explicit valence exceeding maximum allowed will raise exceptions +- Invalid aromatic rings will cause kekulization errors +- Radical electrons may not be properly assigned without explicit specification + +### 3. 
Molecular Analysis and Properties + +**Accessing Molecular Structure:** + +```python +# Iterate atoms and bonds +for atom in mol.GetAtoms(): + print(atom.GetSymbol(), atom.GetIdx(), atom.GetDegree()) + +for bond in mol.GetBonds(): + print(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType()) + +# Ring information +ring_info = mol.GetRingInfo() +ring_info.NumRings() +ring_info.AtomRings() # Returns tuples of atom indices + +# Check if atom is in ring +atom = mol.GetAtomWithIdx(0) +atom.IsInRing() +atom.IsInRingSize(6) # Check for 6-membered rings + +# Find smallest set of smallest rings (SSSR) +from rdkit.Chem import GetSymmSSSR +rings = GetSymmSSSR(mol) +``` + +**Stereochemistry:** + +```python +# Find chiral centers +from rdkit.Chem import FindMolChiralCenters +chiral_centers = FindMolChiralCenters(mol, includeUnassigned=True) +# Returns list of (atom_idx, chirality) tuples + +# Assign stereochemistry from 3D coordinates +from rdkit.Chem import AssignStereochemistryFrom3D +AssignStereochemistryFrom3D(mol) + +# Check bond stereochemistry +bond = mol.GetBondWithIdx(0) +stereo = bond.GetStereo() # STEREONONE, STEREOZ, STEREOE, etc. +``` + +**Fragment Analysis:** + +```python +# Get disconnected fragments +frags = Chem.GetMolFrags(mol, asMols=True) + +# Fragment on specific bonds +from rdkit.Chem import FragmentOnBonds +frag_mol = FragmentOnBonds(mol, [bond_idx1, bond_idx2]) + +# Count ring systems +from rdkit.Chem.Scaffolds import MurckoScaffold +scaffold = MurckoScaffold.GetScaffoldForMol(mol) +``` + +### 4. Molecular Descriptors and Properties + +**Basic Descriptors:** + +```python +from rdkit.Chem import Descriptors + +# Molecular weight +mw = Descriptors.MolWt(mol) +exact_mw = Descriptors.ExactMolWt(mol) + +# LogP (lipophilicity) +logp = Descriptors.MolLogP(mol) + +# Topological polar surface area +tpsa = Descriptors.TPSA(mol) + +# Number of hydrogen bond donors/acceptors +hbd = Descriptors.NumHDonors(mol) +hba = Descriptors.NumHAcceptors(mol) + +# Number of rotatable bonds +rot_bonds = Descriptors.NumRotatableBonds(mol) + +# Number of aromatic rings +aromatic_rings = Descriptors.NumAromaticRings(mol) +``` + +**Batch Descriptor Calculation:** + +```python +# Calculate all descriptors at once +all_descriptors = Descriptors.CalcMolDescriptors(mol) +# Returns dictionary: {'MolWt': 180.16, 'MolLogP': 1.23, ...} + +# Get list of available descriptor names +descriptor_names = [desc[0] for desc in Descriptors._descList] +``` + +**Lipinski's Rule of Five:** + +```python +# Check drug-likeness +mw = Descriptors.MolWt(mol) <= 500 +logp = Descriptors.MolLogP(mol) <= 5 +hbd = Descriptors.NumHDonors(mol) <= 5 +hba = Descriptors.NumHAcceptors(mol) <= 10 + +is_drug_like = mw and logp and hbd and hba +``` + +### 5. 
Fingerprints and Molecular Similarity + +**Fingerprint Types:** + +```python +from rdkit.Chem import AllChem, RDKFingerprint +from rdkit.Chem.AtomPairs import Pairs, Torsions +from rdkit.Chem import MACCSkeys + +# RDKit topological fingerprint +fp = Chem.RDKFingerprint(mol) + +# Morgan fingerprints (circular fingerprints, similar to ECFP) +fp = AllChem.GetMorganFingerprint(mol, radius=2) +fp_bits = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) + +# MACCS keys (166-bit structural key) +fp = MACCSkeys.GenMACCSKeys(mol) + +# Atom pair fingerprints +fp = Pairs.GetAtomPairFingerprint(mol) + +# Topological torsion fingerprints +fp = Torsions.GetTopologicalTorsionFingerprint(mol) + +# Avalon fingerprints (if available) +from rdkit.Avalon import pyAvalonTools +fp = pyAvalonTools.GetAvalonFP(mol) +``` + +**Similarity Calculation:** + +```python +from rdkit import DataStructs + +# Calculate Tanimoto similarity +fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, radius=2) +fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, radius=2) +similarity = DataStructs.TanimotoSimilarity(fp1, fp2) + +# Calculate similarity for multiple molecules +similarities = DataStructs.BulkTanimotoSimilarity(fp1, [fp2, fp3, fp4]) + +# Other similarity metrics +dice = DataStructs.DiceSimilarity(fp1, fp2) +cosine = DataStructs.CosineSimilarity(fp1, fp2) +``` + +**Clustering and Diversity:** + +```python +# Butina clustering based on fingerprint similarity +from rdkit.ML.Cluster import Butina + +# Calculate distance matrix +dists = [] +fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols] +for i in range(len(fps)): + sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]) + dists.extend([1-sim for sim in sims]) + +# Cluster with distance cutoff +clusters = Butina.ClusterData(dists, len(fps), distThresh=0.3, isDistData=True) +``` + +### 6. Substructure Searching and SMARTS + +**Basic Substructure Matching:** + +```python +# Define query using SMARTS +query = Chem.MolFromSmarts('[#6]1:[#6]:[#6]:[#6]:[#6]:[#6]:1') # Benzene ring + +# Check if molecule contains substructure +has_match = mol.HasSubstructMatch(query) + +# Get all matches (returns tuple of tuples with atom indices) +matches = mol.GetSubstructMatches(query) + +# Get only first match +match = mol.GetSubstructMatch(query) +``` + +**Common SMARTS Patterns:** + +```python +# Primary alcohols +primary_alcohol = Chem.MolFromSmarts('[CH2][OH1]') + +# Carboxylic acids +carboxylic_acid = Chem.MolFromSmarts('C(=O)[OH]') + +# Amides +amide = Chem.MolFromSmarts('C(=O)N') + +# Aromatic heterocycles +aromatic_n = Chem.MolFromSmarts('[nR]') # Aromatic nitrogen in ring + +# Macrocycles (rings > 12 atoms) +macrocycle = Chem.MolFromSmarts('[r{12-}]') +``` + +**Matching Rules:** +- Unspecified properties in query match any value in target +- Hydrogens are ignored unless explicitly specified +- Charged query atom won't match uncharged target atom +- Aromatic query atom won't match aliphatic target atom (unless query is generic) + +### 7. 
Chemical Reactions + +**Reaction SMARTS:** + +```python +from rdkit.Chem import AllChem + +# Define reaction using SMARTS: reactants >> products +rxn = AllChem.ReactionFromSmarts('[C:1]=[O:2]>>[C:1][O:2]') # Ketone reduction + +# Apply reaction to molecules +reactants = (mol1,) +products = rxn.RunReactants(reactants) + +# Products is tuple of tuples (one tuple per product set) +for product_set in products: + for product in product_set: + # Sanitize product + Chem.SanitizeMol(product) +``` + +**Reaction Features:** +- Atom mapping preserves specific atoms between reactants and products +- Dummy atoms in products are replaced by corresponding reactant atoms +- "Any" bonds inherit bond order from reactants +- Chirality preserved unless explicitly changed + +**Reaction Similarity:** + +```python +# Generate reaction fingerprints +fp = AllChem.CreateDifferenceFingerprintForReaction(rxn) + +# Compare reactions +similarity = DataStructs.TanimotoSimilarity(fp1, fp2) +``` + +### 8. 2D and 3D Coordinate Generation + +**2D Coordinate Generation:** + +```python +from rdkit.Chem import AllChem + +# Generate 2D coordinates for depiction +AllChem.Compute2DCoords(mol) + +# Align molecule to template structure +template = Chem.MolFromSmiles('c1ccccc1') +AllChem.Compute2DCoords(template) +AllChem.GenerateDepictionMatching2DStructure(mol, template) +``` + +**3D Coordinate Generation and Conformers:** + +```python +# Generate single 3D conformer using ETKDG +AllChem.EmbedMolecule(mol, randomSeed=42) + +# Generate multiple conformers +conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=10, randomSeed=42) + +# Optimize geometry with force field +AllChem.UFFOptimizeMolecule(mol) # UFF force field +AllChem.MMFFOptimizeMolecule(mol) # MMFF94 force field + +# Optimize all conformers +for conf_id in conf_ids: + AllChem.MMFFOptimizeMolecule(mol, confId=conf_id) + +# Calculate RMSD between conformers +from rdkit.Chem import AllChem +rms = AllChem.GetConformerRMS(mol, conf_id1, conf_id2) + +# Align molecules +AllChem.AlignMol(probe_mol, ref_mol) +``` + +**Constrained Embedding:** + +```python +# Embed with part of molecule constrained to specific coordinates +AllChem.ConstrainedEmbed(mol, core_mol) +``` + +### 9. 
Molecular Visualization + +**Basic Drawing:** + +```python +from rdkit.Chem import Draw + +# Draw single molecule to PIL image +img = Draw.MolToImage(mol, size=(300, 300)) +img.save('molecule.png') + +# Draw to file directly +Draw.MolToFile(mol, 'molecule.png') + +# Draw multiple molecules in grid +mols = [mol1, mol2, mol3, mol4] +img = Draw.MolsToGridImage(mols, molsPerRow=2, subImgSize=(200, 200)) +``` + +**Highlighting Substructures:** + +```python +# Highlight substructure match +query = Chem.MolFromSmarts('c1ccccc1') +match = mol.GetSubstructMatch(query) + +img = Draw.MolToImage(mol, highlightAtoms=match) + +# Custom highlight colors +highlight_colors = {atom_idx: (1, 0, 0) for atom_idx in match} # Red +img = Draw.MolToImage(mol, highlightAtoms=match, + highlightAtomColors=highlight_colors) +``` + +**Customizing Visualization:** + +```python +from rdkit.Chem.Draw import rdMolDraw2D + +# Create drawer with custom options +drawer = rdMolDraw2D.MolDraw2DCairo(300, 300) +opts = drawer.drawOptions() + +# Customize options +opts.addAtomIndices = True +opts.addStereoAnnotation = True +opts.bondLineWidth = 2 + +# Draw molecule +drawer.DrawMolecule(mol) +drawer.FinishDrawing() + +# Save to file +with open('molecule.png', 'wb') as f: + f.write(drawer.GetDrawingText()) +``` + +**Jupyter Notebook Integration:** + +```python +# Enable inline display in Jupyter +from rdkit.Chem.Draw import IPythonConsole + +# Customize default display +IPythonConsole.ipython_useSVG = True # Use SVG instead of PNG +IPythonConsole.molSize = (300, 300) # Default size + +# Molecules now display automatically +mol # Shows molecule image +``` + +**Visualizing Fingerprint Bits:** + +```python +# Show what molecular features a fingerprint bit represents +from rdkit.Chem import Draw + +# For Morgan fingerprints +bit_info = {} +fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, bitInfo=bit_info) + +# Draw environment for specific bit +img = Draw.DrawMorganBit(mol, bit_id, bit_info) +``` + +### 10. Molecular Modification + +**Adding/Removing Hydrogens:** + +```python +# Add explicit hydrogens +mol_h = Chem.AddHs(mol) + +# Remove explicit hydrogens +mol = Chem.RemoveHs(mol_h) +``` + +**Kekulization and Aromaticity:** + +```python +# Convert aromatic bonds to alternating single/double +Chem.Kekulize(mol) + +# Set aromaticity +Chem.SetAromaticity(mol) +``` + +**Replacing Substructures:** + +```python +# Replace substructure with another structure +query = Chem.MolFromSmarts('c1ccccc1') # Benzene +replacement = Chem.MolFromSmiles('C1CCCCC1') # Cyclohexane + +new_mol = Chem.ReplaceSubstructs(mol, query, replacement)[0] +``` + +**Neutralizing Charges:** + +```python +# Remove formal charges by adding/removing hydrogens +from rdkit.Chem.MolStandardize import rdMolStandardize + +# Using Uncharger +uncharger = rdMolStandardize.Uncharger() +mol_neutral = uncharger.uncharge(mol) +``` + +### 11. 
Working with Molecular Hashes and Standardization + +**Molecular Hashing:** + +```python +from rdkit.Chem import rdMolHash + +# Generate Murcko scaffold hash +scaffold_hash = rdMolHash.MolHash(mol, rdMolHash.HashFunction.MurckoScaffold) + +# Canonical SMILES hash +canonical_hash = rdMolHash.MolHash(mol, rdMolHash.HashFunction.CanonicalSmiles) + +# Regioisomer hash (ignores stereochemistry) +regio_hash = rdMolHash.MolHash(mol, rdMolHash.HashFunction.Regioisomer) +``` + +**Randomized SMILES:** + +```python +# Generate random SMILES representations (for data augmentation) +from rdkit.Chem import MolToRandomSmilesVect + +random_smiles = MolToRandomSmilesVect(mol, numSmiles=10, randomSeed=42) +``` + +### 12. Pharmacophore and 3D Features + +**Pharmacophore Features:** + +```python +from rdkit.Chem import ChemicalFeatures +from rdkit import RDConfig +import os + +# Load feature factory +fdef_path = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef') +factory = ChemicalFeatures.BuildFeatureFactory(fdef_path) + +# Get pharmacophore features +features = factory.GetFeaturesForMol(mol) + +for feat in features: + print(feat.GetFamily(), feat.GetType(), feat.GetAtomIds()) +``` + +## Common Workflows + +### Drug-likeness Analysis + +```python +from rdkit import Chem +from rdkit.Chem import Descriptors + +def analyze_druglikeness(smiles): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + + # Calculate Lipinski descriptors + results = { + 'MW': Descriptors.MolWt(mol), + 'LogP': Descriptors.MolLogP(mol), + 'HBD': Descriptors.NumHDonors(mol), + 'HBA': Descriptors.NumHAcceptors(mol), + 'TPSA': Descriptors.TPSA(mol), + 'RotBonds': Descriptors.NumRotatableBonds(mol) + } + + # Check Lipinski's Rule of Five + results['Lipinski'] = ( + results['MW'] <= 500 and + results['LogP'] <= 5 and + results['HBD'] <= 5 and + results['HBA'] <= 10 + ) + + return results +``` + +### Similarity Screening + +```python +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit import DataStructs + +def similarity_screen(query_smiles, database_smiles, threshold=0.7): + query_mol = Chem.MolFromSmiles(query_smiles) + query_fp = AllChem.GetMorganFingerprintAsBitVect(query_mol, 2) + + hits = [] + for idx, smiles in enumerate(database_smiles): + mol = Chem.MolFromSmiles(smiles) + if mol: + fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2) + sim = DataStructs.TanimotoSimilarity(query_fp, fp) + if sim >= threshold: + hits.append((idx, smiles, sim)) + + return sorted(hits, key=lambda x: x[2], reverse=True) +``` + +### Substructure Filtering + +```python +from rdkit import Chem + +def filter_by_substructure(smiles_list, pattern_smarts): + query = Chem.MolFromSmarts(pattern_smarts) + + hits = [] + for smiles in smiles_list: + mol = Chem.MolFromSmiles(smiles) + if mol and mol.HasSubstructMatch(query): + hits.append(smiles) + + return hits +``` + +## Best Practices + +### Error Handling + +Always check for `None` when parsing molecules: + +```python +mol = Chem.MolFromSmiles(smiles) +if mol is None: + print(f"Failed to parse: {smiles}") + continue +``` + +### Performance Optimization + +**Use binary formats for storage:** + +```python +import pickle + +# Pickle molecules for fast loading +with open('molecules.pkl', 'wb') as f: + pickle.dump(mols, f) + +# Load pickled molecules (much faster than reparsing) +with open('molecules.pkl', 'rb') as f: + mols = pickle.load(f) +``` + +**Use bulk operations:** + +```python +# Calculate fingerprints for all molecules at once +fps = 
[AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols] + +# Use bulk similarity calculations +similarities = DataStructs.BulkTanimotoSimilarity(fps[0], fps[1:]) +``` + +### Thread Safety + +RDKit operations are generally thread-safe for: +- Molecule I/O (SMILES, mol blocks) +- Coordinate generation +- Fingerprinting and descriptors +- Substructure searching +- Reactions +- Drawing + +**Not thread-safe:** MolSuppliers when accessed concurrently. + +### Memory Management + +For large datasets: + +```python +# Use ForwardSDMolSupplier to avoid loading entire file +with open('large.sdf') as f: + suppl = Chem.ForwardSDMolSupplier(f) + for mol in suppl: + # Process one molecule at a time + pass + +# Use MultithreadedSDMolSupplier for parallel processing +suppl = Chem.MultithreadedSDMolSupplier('large.sdf', numWriterThreads=4) +``` + +## Common Pitfalls + +1. **Forgetting to check for None:** Always validate molecules after parsing +2. **Sanitization failures:** Use `DetectChemistryProblems()` to debug +3. **Missing hydrogens:** Use `AddHs()` when calculating properties that depend on hydrogen +4. **2D vs 3D:** Generate appropriate coordinates before visualization or 3D analysis +5. **SMARTS matching rules:** Remember that unspecified properties match anything +6. **Thread safety with MolSuppliers:** Don't share supplier objects across threads + +## Resources + +### references/ + +This skill includes detailed API reference documentation: + +- `api_reference.md` - Comprehensive listing of RDKit modules, functions, and classes organized by functionality +- `descriptors_reference.md` - Complete list of available molecular descriptors with descriptions +- `smarts_patterns.md` - Common SMARTS patterns for functional groups and structural features + +Load these references when needing specific API details, parameter information, or pattern examples. + +### scripts/ + +Example scripts for common RDKit workflows: + +- `molecular_properties.py` - Calculate comprehensive molecular properties and descriptors +- `similarity_search.py` - Perform fingerprint-based similarity screening +- `substructure_filter.py` - Filter molecules by substructure patterns + +These scripts can be executed directly or used as templates for custom workflows. diff --git a/scientific-packages/rdkit/references/api_reference.md b/scientific-packages/rdkit/references/api_reference.md new file mode 100644 index 0000000..86fd4fc --- /dev/null +++ b/scientific-packages/rdkit/references/api_reference.md @@ -0,0 +1,432 @@ +# RDKit API Reference + +This document provides a comprehensive reference for RDKit's Python API, organized by functionality. + +## Core Module: rdkit.Chem + +The fundamental module for working with molecules. 
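+
+For orientation, a minimal sketch (assuming RDKit is installed; the aspirin SMILES is just a placeholder input) that exercises a few of the calls documented below:
+
+```python
+from rdkit import Chem
+
+# Parse a SMILES string; MolFrom* functions return None on failure
+mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
+if mol is not None:
+    print(Chem.MolToSmiles(mol))         # canonical SMILES
+    print(mol.GetNumAtoms())             # explicit (heavy) atoms by default
+    print(mol.GetRingInfo().NumRings())  # ring count
+```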
+ +### Molecule I/O + +**Reading Molecules:** + +- `Chem.MolFromSmiles(smiles, sanitize=True)` - Parse SMILES string +- `Chem.MolFromSmarts(smarts)` - Parse SMARTS pattern +- `Chem.MolFromMolFile(filename, sanitize=True, removeHs=True)` - Read MOL file +- `Chem.MolFromMolBlock(molblock, sanitize=True, removeHs=True)` - Parse MOL block string +- `Chem.MolFromMol2File(filename, sanitize=True, removeHs=True)` - Read MOL2 file +- `Chem.MolFromMol2Block(molblock, sanitize=True, removeHs=True)` - Parse MOL2 block +- `Chem.MolFromPDBFile(filename, sanitize=True, removeHs=True)` - Read PDB file +- `Chem.MolFromPDBBlock(pdbblock, sanitize=True, removeHs=True)` - Parse PDB block +- `Chem.MolFromInchi(inchi, sanitize=True, removeHs=True)` - Parse InChI string +- `Chem.MolFromSequence(seq, sanitize=True)` - Create molecule from peptide sequence + +**Writing Molecules:** + +- `Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)` - Convert to SMILES +- `Chem.MolToSmarts(mol, isomericSmarts=False)` - Convert to SMARTS +- `Chem.MolToMolBlock(mol, includeStereo=True, confId=-1)` - Convert to MOL block +- `Chem.MolToMolFile(mol, filename, includeStereo=True, confId=-1)` - Write MOL file +- `Chem.MolToPDBBlock(mol, confId=-1)` - Convert to PDB block +- `Chem.MolToPDBFile(mol, filename, confId=-1)` - Write PDB file +- `Chem.MolToInchi(mol, options='')` - Convert to InChI +- `Chem.MolToInchiKey(mol, options='')` - Generate InChI key +- `Chem.MolToSequence(mol)` - Convert to peptide sequence + +**Batch I/O:** + +- `Chem.SDMolSupplier(filename, sanitize=True, removeHs=True)` - SDF file reader +- `Chem.ForwardSDMolSupplier(fileobj, sanitize=True, removeHs=True)` - Forward-only SDF reader +- `Chem.MultithreadedSDMolSupplier(filename, numWriterThreads=1)` - Parallel SDF reader +- `Chem.SmilesMolSupplier(filename, delimiter=' ', titleLine=True)` - SMILES file reader +- `Chem.SDWriter(filename)` - SDF file writer +- `Chem.SmilesWriter(filename, delimiter=' ', includeHeader=True)` - SMILES file writer + +### Molecular Manipulation + +**Sanitization:** + +- `Chem.SanitizeMol(mol, sanitizeOps=SANITIZE_ALL, catchErrors=False)` - Sanitize molecule +- `Chem.DetectChemistryProblems(mol, sanitizeOps=SANITIZE_ALL)` - Detect sanitization issues +- `Chem.AssignStereochemistry(mol, cleanIt=True, force=False)` - Assign stereochemistry +- `Chem.FindPotentialStereo(mol)` - Find potential stereocenters +- `Chem.AssignStereochemistryFrom3D(mol, confId=-1)` - Assign stereo from 3D coords + +**Hydrogen Management:** + +- `Chem.AddHs(mol, explicitOnly=False, addCoords=False)` - Add explicit hydrogens +- `Chem.RemoveHs(mol, implicitOnly=False, updateExplicitCount=False)` - Remove hydrogens +- `Chem.RemoveAllHs(mol)` - Remove all hydrogens + +**Aromaticity:** + +- `Chem.SetAromaticity(mol, model=AROMATICITY_RDKIT)` - Set aromaticity model +- `Chem.Kekulize(mol, clearAromaticFlags=False)` - Kekulize aromatic bonds +- `Chem.SetConjugation(mol)` - Set conjugation flags + +**Fragments:** + +- `Chem.GetMolFrags(mol, asMols=False, sanitizeFrags=True)` - Get disconnected fragments +- `Chem.FragmentOnBonds(mol, bondIndices, addDummies=True)` - Fragment on specific bonds +- `Chem.ReplaceSubstructs(mol, query, replacement, replaceAll=False)` - Replace substructures +- `Chem.DeleteSubstructs(mol, query, onlyFrags=False)` - Delete substructures + +**Stereochemistry:** + +- `Chem.FindMolChiralCenters(mol, includeUnassigned=False, useLegacyImplementation=False)` - Find chiral centers +- `Chem.FindPotentialStereo(mol, cleanIt=True)` - Find 
potential stereocenters + +### Substructure Searching + +**Basic Matching:** + +- `mol.HasSubstructMatch(query, useChirality=False)` - Check for substructure match +- `mol.GetSubstructMatch(query, useChirality=False)` - Get first match +- `mol.GetSubstructMatches(query, uniquify=True, useChirality=False)` - Get all matches +- `mol.GetSubstructMatches(query, maxMatches=1000)` - Limit number of matches + +### Molecular Properties + +**Atom Methods:** + +- `atom.GetSymbol()` - Atomic symbol +- `atom.GetAtomicNum()` - Atomic number +- `atom.GetDegree()` - Number of bonds +- `atom.GetTotalDegree()` - Including hydrogens +- `atom.GetFormalCharge()` - Formal charge +- `atom.GetNumRadicalElectrons()` - Radical electrons +- `atom.GetIsAromatic()` - Aromaticity flag +- `atom.GetHybridization()` - Hybridization (SP, SP2, SP3, etc.) +- `atom.GetIdx()` - Atom index +- `atom.IsInRing()` - In any ring +- `atom.IsInRingSize(size)` - In ring of specific size +- `atom.GetChiralTag()` - Chirality tag + +**Bond Methods:** + +- `bond.GetBondType()` - Bond type (SINGLE, DOUBLE, TRIPLE, AROMATIC) +- `bond.GetBeginAtomIdx()` - Starting atom index +- `bond.GetEndAtomIdx()` - Ending atom index +- `bond.GetIsConjugated()` - Conjugation flag +- `bond.GetIsAromatic()` - Aromaticity flag +- `bond.IsInRing()` - In any ring +- `bond.GetStereo()` - Stereochemistry (STEREONONE, STEREOZ, STEREOE, etc.) + +**Molecule Methods:** + +- `mol.GetNumAtoms(onlyExplicit=True)` - Number of atoms +- `mol.GetNumHeavyAtoms()` - Number of heavy atoms +- `mol.GetNumBonds()` - Number of bonds +- `mol.GetAtoms()` - Iterator over atoms +- `mol.GetBonds()` - Iterator over bonds +- `mol.GetAtomWithIdx(idx)` - Get specific atom +- `mol.GetBondWithIdx(idx)` - Get specific bond +- `mol.GetRingInfo()` - Ring information object + +**Ring Information:** + +- `Chem.GetSymmSSSR(mol)` - Get smallest set of smallest rings +- `Chem.GetSSSR(mol)` - Alias for GetSymmSSSR +- `ring_info.NumRings()` - Number of rings +- `ring_info.AtomRings()` - Tuples of atom indices in rings +- `ring_info.BondRings()` - Tuples of bond indices in rings + +## rdkit.Chem.AllChem + +Extended chemistry functionality. 
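+
+A short sketch combining the embedding, optimization, and fingerprinting entries below (ethanol is an arbitrary example; the fixed seed is only for reproducibility):
+
+```python
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
+
+# Generate and optimize a single 3D conformer
+AllChem.EmbedMolecule(mol, randomSeed=42)
+AllChem.MMFFOptimizeMolecule(mol)
+
+# Morgan fingerprint as a fixed-length bit vector
+fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
+print(fp.GetNumOnBits())
+```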
+ +### 2D/3D Coordinate Generation + +- `AllChem.Compute2DCoords(mol, canonOrient=True, clearConfs=True)` - Generate 2D coordinates +- `AllChem.EmbedMolecule(mol, maxAttempts=0, randomSeed=-1, useRandomCoords=False)` - Generate 3D conformer +- `AllChem.EmbedMultipleConfs(mol, numConfs=10, maxAttempts=0, randomSeed=-1)` - Generate multiple conformers +- `AllChem.ConstrainedEmbed(mol, core, useTethers=True)` - Constrained embedding +- `AllChem.GenerateDepictionMatching2DStructure(mol, reference, refPattern=None)` - Align to template + +### Force Field Optimization + +- `AllChem.UFFOptimizeMolecule(mol, maxIters=200, confId=-1)` - UFF optimization +- `AllChem.MMFFOptimizeMolecule(mol, maxIters=200, confId=-1, mmffVariant='MMFF94')` - MMFF optimization +- `AllChem.UFFGetMoleculeForceField(mol, confId=-1)` - Get UFF force field object +- `AllChem.MMFFGetMoleculeForceField(mol, pyMMFFMolProperties, confId=-1)` - Get MMFF force field + +### Conformer Analysis + +- `AllChem.GetConformerRMS(mol, confId1, confId2, prealigned=False)` - Calculate RMSD +- `AllChem.GetConformerRMSMatrix(mol, prealigned=False)` - RMSD matrix +- `AllChem.AlignMol(prbMol, refMol, prbCid=-1, refCid=-1)` - Align molecules +- `AllChem.AlignMolConformers(mol)` - Align all conformers + +### Reactions + +- `AllChem.ReactionFromSmarts(smarts, useSmiles=False)` - Create reaction from SMARTS +- `reaction.RunReactants(reactants)` - Apply reaction +- `reaction.RunReactant(reactant, reactionIdx)` - Apply to specific reactant +- `AllChem.CreateDifferenceFingerprintForReaction(reaction)` - Reaction fingerprint + +### Fingerprints + +- `AllChem.GetMorganFingerprint(mol, radius, useFeatures=False)` - Morgan fingerprint +- `AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=2048)` - Morgan bit vector +- `AllChem.GetHashedMorganFingerprint(mol, radius, nBits=2048)` - Hashed Morgan +- `AllChem.GetErGFingerprint(mol)` - ErG fingerprint + +## rdkit.Chem.Descriptors + +Molecular descriptor calculations. + +### Common Descriptors + +- `Descriptors.MolWt(mol)` - Molecular weight +- `Descriptors.ExactMolWt(mol)` - Exact molecular weight +- `Descriptors.HeavyAtomMolWt(mol)` - Heavy atom molecular weight +- `Descriptors.MolLogP(mol)` - LogP (lipophilicity) +- `Descriptors.MolMR(mol)` - Molar refractivity +- `Descriptors.TPSA(mol)` - Topological polar surface area +- `Descriptors.NumHDonors(mol)` - Hydrogen bond donors +- `Descriptors.NumHAcceptors(mol)` - Hydrogen bond acceptors +- `Descriptors.NumRotatableBonds(mol)` - Rotatable bonds +- `Descriptors.NumAromaticRings(mol)` - Aromatic rings +- `Descriptors.NumSaturatedRings(mol)` - Saturated rings +- `Descriptors.NumAliphaticRings(mol)` - Aliphatic rings +- `Descriptors.NumAromaticHeterocycles(mol)` - Aromatic heterocycles +- `Descriptors.NumRadicalElectrons(mol)` - Radical electrons +- `Descriptors.NumValenceElectrons(mol)` - Valence electrons + +### Batch Calculation + +- `Descriptors.CalcMolDescriptors(mol)` - Calculate all descriptors as dictionary + +### Descriptor Lists + +- `Descriptors._descList` - List of (name, function) tuples for all descriptors + +## rdkit.Chem.Draw + +Molecular visualization. 
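+
+A minimal sketch of the drawing entries below (the output file name is arbitrary):
+
+```python
+from rdkit import Chem
+from rdkit.Chem import Draw
+
+mol = Chem.MolFromSmiles('c1ccc2ccccc2c1')  # naphthalene
+
+# Write a depiction straight to disk
+Draw.MolToFile(mol, 'naphthalene.png', size=(300, 300))
+
+# Highlight the atoms matching an aromatic-ring query
+query = Chem.MolFromSmarts('c1ccccc1')
+img = Draw.MolToImage(mol, highlightAtoms=mol.GetSubstructMatch(query))
+```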
+ +### Image Generation + +- `Draw.MolToImage(mol, size=(300,300), kekulize=True, wedgeBonds=True, highlightAtoms=None)` - Generate PIL image +- `Draw.MolToFile(mol, filename, size=(300,300), kekulize=True, wedgeBonds=True)` - Save to file +- `Draw.MolsToGridImage(mols, molsPerRow=3, subImgSize=(200,200), legends=None)` - Grid of molecules +- `Draw.MolsMatrixToGridImage(mols, molsPerRow=3, subImgSize=(200,200), legends=None)` - Nested grid +- `Draw.ReactionToImage(rxn, subImgSize=(200,200))` - Reaction image + +### Fingerprint Visualization + +- `Draw.DrawMorganBit(mol, bitId, bitInfo, whichExample=0)` - Visualize Morgan bit +- `Draw.DrawMorganBits(bits, mol, bitInfo, molsPerRow=3)` - Multiple Morgan bits +- `Draw.DrawRDKitBit(mol, bitId, bitInfo, whichExample=0)` - Visualize RDKit bit + +### IPython Integration + +- `Draw.IPythonConsole` - Module for Jupyter integration +- `Draw.IPythonConsole.ipython_useSVG` - Use SVG (True) or PNG (False) +- `Draw.IPythonConsole.molSize` - Default molecule image size + +### Drawing Options + +- `rdMolDraw2D.MolDrawOptions()` - Get drawing options object + - `.addAtomIndices` - Show atom indices + - `.addBondIndices` - Show bond indices + - `.addStereoAnnotation` - Show stereochemistry + - `.bondLineWidth` - Line width + - `.highlightBondWidthMultiplier` - Highlight width + - `.minFontSize` - Minimum font size + - `.maxFontSize` - Maximum font size + +## rdkit.Chem.rdMolDescriptors + +Additional descriptor calculations. + +- `rdMolDescriptors.CalcNumRings(mol)` - Number of rings +- `rdMolDescriptors.CalcNumAromaticRings(mol)` - Aromatic rings +- `rdMolDescriptors.CalcNumAliphaticRings(mol)` - Aliphatic rings +- `rdMolDescriptors.CalcNumSaturatedRings(mol)` - Saturated rings +- `rdMolDescriptors.CalcNumHeterocycles(mol)` - Heterocycles +- `rdMolDescriptors.CalcNumAromaticHeterocycles(mol)` - Aromatic heterocycles +- `rdMolDescriptors.CalcNumSpiroAtoms(mol)` - Spiro atoms +- `rdMolDescriptors.CalcNumBridgeheadAtoms(mol)` - Bridgehead atoms +- `rdMolDescriptors.CalcFractionCsp3(mol)` - Fraction of sp3 carbons +- `rdMolDescriptors.CalcLabuteASA(mol)` - Labute accessible surface area +- `rdMolDescriptors.CalcTPSA(mol)` - TPSA +- `rdMolDescriptors.CalcMolFormula(mol)` - Molecular formula + +## rdkit.Chem.Scaffolds + +Scaffold analysis. + +### Murcko Scaffolds + +- `MurckoScaffold.GetScaffoldForMol(mol)` - Get Murcko scaffold +- `MurckoScaffold.MakeScaffoldGeneric(mol)` - Generic scaffold +- `MurckoScaffold.MurckoDecompose(mol)` - Decompose to scaffold and sidechains + +## rdkit.Chem.rdMolHash + +Molecular hashing and standardization. + +- `rdMolHash.MolHash(mol, hashFunction)` - Generate hash + - `rdMolHash.HashFunction.AnonymousGraph` - Anonymized structure + - `rdMolHash.HashFunction.CanonicalSmiles` - Canonical SMILES + - `rdMolHash.HashFunction.ElementGraph` - Element graph + - `rdMolHash.HashFunction.MurckoScaffold` - Murcko scaffold + - `rdMolHash.HashFunction.Regioisomer` - Regioisomer (no stereo) + - `rdMolHash.HashFunction.NetCharge` - Net charge + - `rdMolHash.HashFunction.HetAtomProtomer` - Heteroatom protomer + - `rdMolHash.HashFunction.HetAtomTautomer` - Heteroatom tautomer + +## rdkit.Chem.MolStandardize + +Molecule standardization. 
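+
+Typical usage of the `Cleanup` and `Uncharger` entries listed below (a sketch; the acetate input is arbitrary):
+
+```python
+from rdkit import Chem
+from rdkit.Chem.MolStandardize import rdMolStandardize
+
+mol = Chem.MolFromSmiles('CC(=O)[O-]')  # acetate anion
+
+clean = rdMolStandardize.Cleanup(mol)                    # standard cleanup pass
+neutral = rdMolStandardize.Uncharger().uncharge(clean)   # neutralize charges
+print(Chem.MolToSmiles(neutral))                         # typically 'CC(=O)O'
+```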
+ +- `rdMolStandardize.Normalize(mol)` - Normalize functional groups +- `rdMolStandardize.Reionize(mol)` - Fix ionization state +- `rdMolStandardize.RemoveFragments(mol)` - Remove small fragments +- `rdMolStandardize.Cleanup(mol)` - Full cleanup (normalize + reionize + remove) +- `rdMolStandardize.Uncharger()` - Create uncharger object + - `.uncharge(mol)` - Remove charges +- `rdMolStandardize.TautomerEnumerator()` - Enumerate tautomers + - `.Enumerate(mol)` - Generate tautomers + - `.Canonicalize(mol)` - Get canonical tautomer + +## rdkit.DataStructs + +Fingerprint similarity and operations. + +### Similarity Metrics + +- `DataStructs.TanimotoSimilarity(fp1, fp2)` - Tanimoto coefficient +- `DataStructs.DiceSimilarity(fp1, fp2)` - Dice coefficient +- `DataStructs.CosineSimilarity(fp1, fp2)` - Cosine similarity +- `DataStructs.SokalSimilarity(fp1, fp2)` - Sokal similarity +- `DataStructs.KulczynskiSimilarity(fp1, fp2)` - Kulczynski similarity +- `DataStructs.McConnaugheySimilarity(fp1, fp2)` - McConnaughey similarity + +### Bulk Operations + +- `DataStructs.BulkTanimotoSimilarity(fp, fps)` - Tanimoto for list of fingerprints +- `DataStructs.BulkDiceSimilarity(fp, fps)` - Dice for list +- `DataStructs.BulkCosineSimilarity(fp, fps)` - Cosine for list + +### Distance Metrics + +- `DataStructs.TanimotoDistance(fp1, fp2)` - 1 - Tanimoto +- `DataStructs.DiceDistance(fp1, fp2)` - 1 - Dice + +## rdkit.Chem.AtomPairs + +Atom pair fingerprints. + +- `Pairs.GetAtomPairFingerprint(mol, minLength=1, maxLength=30)` - Atom pair fingerprint +- `Pairs.GetAtomPairFingerprintAsBitVect(mol, minLength=1, maxLength=30, nBits=2048)` - As bit vector +- `Pairs.GetHashedAtomPairFingerprint(mol, nBits=2048, minLength=1, maxLength=30)` - Hashed version + +## rdkit.Chem.Torsions + +Topological torsion fingerprints. + +- `Torsions.GetTopologicalTorsionFingerprint(mol, targetSize=4)` - Torsion fingerprint +- `Torsions.GetTopologicalTorsionFingerprintAsIntVect(mol, targetSize=4)` - As int vector +- `Torsions.GetHashedTopologicalTorsionFingerprint(mol, nBits=2048, targetSize=4)` - Hashed version + +## rdkit.Chem.MACCSkeys + +MACCS structural keys. + +- `MACCSkeys.GenMACCSKeys(mol)` - Generate 166-bit MACCS keys + +## rdkit.Chem.ChemicalFeatures + +Pharmacophore features. + +- `ChemicalFeatures.BuildFeatureFactory(featureFile)` - Create feature factory +- `factory.GetFeaturesForMol(mol)` - Get pharmacophore features +- `feature.GetFamily()` - Feature family (Donor, Acceptor, etc.) +- `feature.GetType()` - Feature type +- `feature.GetAtomIds()` - Atoms involved in feature + +## rdkit.ML.Cluster.Butina + +Clustering algorithms. + +- `Butina.ClusterData(distances, nPts, distThresh, isDistData=True)` - Butina clustering + - Returns tuple of tuples with cluster members + +## rdkit.Chem.rdFingerprintGenerator + +Modern fingerprint generation API (RDKit 2020.09+). 
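+
+A brief sketch of the generator-based API listed below; a single generator can be reused across many molecules:
+
+```python
+from rdkit import Chem
+from rdkit import DataStructs
+from rdkit.Chem import rdFingerprintGenerator
+
+gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
+
+fp1 = gen.GetFingerprint(Chem.MolFromSmiles('CCO'))
+fp2 = gen.GetFingerprint(Chem.MolFromSmiles('CCN'))
+print(DataStructs.TanimotoSimilarity(fp1, fp2))
+```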
+ +- `rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)` - Morgan generator +- `rdFingerprintGenerator.GetRDKitFPGenerator(minPath=1, maxPath=7, fpSize=2048)` - RDKit FP generator +- `rdFingerprintGenerator.GetAtomPairGenerator(minDistance=1, maxDistance=30)` - Atom pair generator +- `generator.GetFingerprint(mol)` - Generate fingerprint +- `generator.GetCountFingerprint(mol)` - Count-based fingerprint + +## Common Parameters + +### Sanitization Operations + +- `SANITIZE_NONE` - No sanitization +- `SANITIZE_ALL` - All operations (default) +- `SANITIZE_CLEANUP` - Basic cleanup +- `SANITIZE_PROPERTIES` - Calculate properties +- `SANITIZE_SYMMRINGS` - Symmetrize rings +- `SANITIZE_KEKULIZE` - Kekulize aromatic rings +- `SANITIZE_FINDRADICALS` - Find radical electrons +- `SANITIZE_SETAROMATICITY` - Set aromaticity +- `SANITIZE_SETCONJUGATION` - Set conjugation +- `SANITIZE_SETHYBRIDIZATION` - Set hybridization +- `SANITIZE_CLEANUPCHIRALITY` - Cleanup chirality + +### Bond Types + +- `BondType.SINGLE` - Single bond +- `BondType.DOUBLE` - Double bond +- `BondType.TRIPLE` - Triple bond +- `BondType.AROMATIC` - Aromatic bond +- `BondType.DATIVE` - Dative bond +- `BondType.UNSPECIFIED` - Unspecified + +### Hybridization + +- `HybridizationType.S` - S +- `HybridizationType.SP` - SP +- `HybridizationType.SP2` - SP2 +- `HybridizationType.SP3` - SP3 +- `HybridizationType.SP3D` - SP3D +- `HybridizationType.SP3D2` - SP3D2 + +### Chirality + +- `ChiralType.CHI_UNSPECIFIED` - Unspecified +- `ChiralType.CHI_TETRAHEDRAL_CW` - Clockwise +- `ChiralType.CHI_TETRAHEDRAL_CCW` - Counter-clockwise + +## Installation + +```bash +# Using conda (recommended) +conda install -c conda-forge rdkit + +# Using pip +pip install rdkit-pypi +``` + +## Importing + +```python +# Core functionality +from rdkit import Chem +from rdkit.Chem import AllChem + +# Descriptors +from rdkit.Chem import Descriptors + +# Drawing +from rdkit.Chem import Draw + +# Similarity +from rdkit import DataStructs +``` diff --git a/scientific-packages/rdkit/references/descriptors_reference.md b/scientific-packages/rdkit/references/descriptors_reference.md new file mode 100644 index 0000000..2d35287 --- /dev/null +++ b/scientific-packages/rdkit/references/descriptors_reference.md @@ -0,0 +1,595 @@ +# RDKit Molecular Descriptors Reference + +Complete reference for molecular descriptors available in RDKit's `Descriptors` module. + +## Usage + +```python +from rdkit import Chem +from rdkit.Chem import Descriptors + +mol = Chem.MolFromSmiles('CCO') + +# Calculate individual descriptor +mw = Descriptors.MolWt(mol) + +# Calculate all descriptors at once +all_desc = Descriptors.CalcMolDescriptors(mol) +``` + +## Molecular Weight and Mass + +### MolWt +Average molecular weight of the molecule. +```python +Descriptors.MolWt(mol) +``` + +### ExactMolWt +Exact molecular weight using isotopic composition. +```python +Descriptors.ExactMolWt(mol) +``` + +### HeavyAtomMolWt +Average molecular weight ignoring hydrogens. +```python +Descriptors.HeavyAtomMolWt(mol) +``` + +## Lipophilicity + +### MolLogP +Wildman-Crippen LogP (octanol-water partition coefficient). +```python +Descriptors.MolLogP(mol) +``` + +### MolMR +Wildman-Crippen molar refractivity. +```python +Descriptors.MolMR(mol) +``` + +## Polar Surface Area + +### TPSA +Topological polar surface area (TPSA) based on fragment contributions. +```python +Descriptors.TPSA(mol) +``` + +### LabuteASA +Labute's Approximate Surface Area (ASA). 
+```python +Descriptors.LabuteASA(mol) +``` + +## Hydrogen Bonding + +### NumHDonors +Number of hydrogen bond donors (N-H and O-H). +```python +Descriptors.NumHDonors(mol) +``` + +### NumHAcceptors +Number of hydrogen bond acceptors (N and O). +```python +Descriptors.NumHAcceptors(mol) +``` + +### NOCount +Number of N and O atoms. +```python +Descriptors.NOCount(mol) +``` + +### NHOHCount +Number of N-H and O-H bonds. +```python +Descriptors.NHOHCount(mol) +``` + +## Atom Counts + +### HeavyAtomCount +Number of heavy atoms (non-hydrogen). +```python +Descriptors.HeavyAtomCount(mol) +``` + +### NumHeteroatoms +Number of heteroatoms (non-C and non-H). +```python +Descriptors.NumHeteroatoms(mol) +``` + +### NumValenceElectrons +Total number of valence electrons. +```python +Descriptors.NumValenceElectrons(mol) +``` + +### NumRadicalElectrons +Number of radical electrons. +```python +Descriptors.NumRadicalElectrons(mol) +``` + +## Ring Descriptors + +### RingCount +Number of rings. +```python +Descriptors.RingCount(mol) +``` + +### NumAromaticRings +Number of aromatic rings. +```python +Descriptors.NumAromaticRings(mol) +``` + +### NumSaturatedRings +Number of saturated rings. +```python +Descriptors.NumSaturatedRings(mol) +``` + +### NumAliphaticRings +Number of aliphatic (non-aromatic) rings. +```python +Descriptors.NumAliphaticRings(mol) +``` + +### NumAromaticCarbocycles +Number of aromatic carbocycles (rings with only carbons). +```python +Descriptors.NumAromaticCarbocycles(mol) +``` + +### NumAromaticHeterocycles +Number of aromatic heterocycles (rings with heteroatoms). +```python +Descriptors.NumAromaticHeterocycles(mol) +``` + +### NumSaturatedCarbocycles +Number of saturated carbocycles. +```python +Descriptors.NumSaturatedCarbocycles(mol) +``` + +### NumSaturatedHeterocycles +Number of saturated heterocycles. +```python +Descriptors.NumSaturatedHeterocycles(mol) +``` + +### NumAliphaticCarbocycles +Number of aliphatic carbocycles. +```python +Descriptors.NumAliphaticCarbocycles(mol) +``` + +### NumAliphaticHeterocycles +Number of aliphatic heterocycles. +```python +Descriptors.NumAliphaticHeterocycles(mol) +``` + +## Rotatable Bonds + +### NumRotatableBonds +Number of rotatable bonds (flexibility). +```python +Descriptors.NumRotatableBonds(mol) +``` + +## Aromatic Atoms + +### NumAromaticAtoms +Number of aromatic atoms. +```python +Descriptors.NumAromaticAtoms(mol) +``` + +## Fraction Descriptors + +### FractionCsp3 +Fraction of carbons that are sp3 hybridized. +```python +Descriptors.FractionCsp3(mol) +``` + +## Complexity Descriptors + +### BertzCT +Bertz complexity index. +```python +Descriptors.BertzCT(mol) +``` + +### Ipc +Information content (complexity measure). +```python +Descriptors.Ipc(mol) +``` + +## Kappa Shape Indices + +Molecular shape descriptors based on graph invariants. + +### Kappa1 +First kappa shape index. +```python +Descriptors.Kappa1(mol) +``` + +### Kappa2 +Second kappa shape index. +```python +Descriptors.Kappa2(mol) +``` + +### Kappa3 +Third kappa shape index. +```python +Descriptors.Kappa3(mol) +``` + +## Chi Connectivity Indices + +Molecular connectivity indices. + +### Chi0, Chi1, Chi2, Chi3, Chi4 +Simple chi connectivity indices. +```python +Descriptors.Chi0(mol) +Descriptors.Chi1(mol) +Descriptors.Chi2(mol) +Descriptors.Chi3(mol) +Descriptors.Chi4(mol) +``` + +### Chi0n, Chi1n, Chi2n, Chi3n, Chi4n +Valence-modified chi connectivity indices. 
+```python +Descriptors.Chi0n(mol) +Descriptors.Chi1n(mol) +Descriptors.Chi2n(mol) +Descriptors.Chi3n(mol) +Descriptors.Chi4n(mol) +``` + +### Chi0v, Chi1v, Chi2v, Chi3v, Chi4v +Valence chi connectivity indices. +```python +Descriptors.Chi0v(mol) +Descriptors.Chi1v(mol) +Descriptors.Chi2v(mol) +Descriptors.Chi3v(mol) +Descriptors.Chi4v(mol) +``` + +## Hall-Kier Alpha + +### HallKierAlpha +Hall-Kier alpha value (molecular flexibility). +```python +Descriptors.HallKierAlpha(mol) +``` + +## Balaban's J Index + +### BalabanJ +Balaban's J index (branching descriptor). +```python +Descriptors.BalabanJ(mol) +``` + +## EState Indices + +Electrotopological state indices. + +### MaxEStateIndex +Maximum E-state value. +```python +Descriptors.MaxEStateIndex(mol) +``` + +### MinEStateIndex +Minimum E-state value. +```python +Descriptors.MinEStateIndex(mol) +``` + +### MaxAbsEStateIndex +Maximum absolute E-state value. +```python +Descriptors.MaxAbsEStateIndex(mol) +``` + +### MinAbsEStateIndex +Minimum absolute E-state value. +```python +Descriptors.MinAbsEStateIndex(mol) +``` + +## Partial Charges + +### MaxPartialCharge +Maximum partial charge. +```python +Descriptors.MaxPartialCharge(mol) +``` + +### MinPartialCharge +Minimum partial charge. +```python +Descriptors.MinPartialCharge(mol) +``` + +### MaxAbsPartialCharge +Maximum absolute partial charge. +```python +Descriptors.MaxAbsPartialCharge(mol) +``` + +### MinAbsPartialCharge +Minimum absolute partial charge. +```python +Descriptors.MinAbsPartialCharge(mol) +``` + +## Fingerprint Density + +Measures the density of molecular fingerprints. + +### FpDensityMorgan1 +Morgan fingerprint density at radius 1. +```python +Descriptors.FpDensityMorgan1(mol) +``` + +### FpDensityMorgan2 +Morgan fingerprint density at radius 2. +```python +Descriptors.FpDensityMorgan2(mol) +``` + +### FpDensityMorgan3 +Morgan fingerprint density at radius 3. +```python +Descriptors.FpDensityMorgan3(mol) +``` + +## PEOE VSA Descriptors + +Partial Equalization of Orbital Electronegativities (PEOE) VSA descriptors. + +### PEOE_VSA1 through PEOE_VSA14 +MOE-type descriptors using partial charges and surface area contributions. +```python +Descriptors.PEOE_VSA1(mol) +# ... through PEOE_VSA14 +``` + +## SMR VSA Descriptors + +Molecular refractivity VSA descriptors. + +### SMR_VSA1 through SMR_VSA10 +MOE-type descriptors using MR contributions and surface area. +```python +Descriptors.SMR_VSA1(mol) +# ... through SMR_VSA10 +``` + +## SLogP VSA Descriptors + +LogP VSA descriptors. + +### SLogP_VSA1 through SLogP_VSA12 +MOE-type descriptors using LogP contributions and surface area. +```python +Descriptors.SLogP_VSA1(mol) +# ... through SLogP_VSA12 +``` + +## EState VSA Descriptors + +### EState_VSA1 through EState_VSA11 +MOE-type descriptors using E-state indices and surface area. +```python +Descriptors.EState_VSA1(mol) +# ... through EState_VSA11 +``` + +## VSA Descriptors + +van der Waals surface area descriptors. + +### VSA_EState1 through VSA_EState10 +EState VSA descriptors. +```python +Descriptors.VSA_EState1(mol) +# ... through VSA_EState10 +``` + +## BCUT Descriptors + +Burden-CAS-University of Texas eigenvalue descriptors. + +### BCUT2D_MWHI +Highest eigenvalue of Burden matrix weighted by molecular weight. +```python +Descriptors.BCUT2D_MWHI(mol) +``` + +### BCUT2D_MWLOW +Lowest eigenvalue of Burden matrix weighted by molecular weight. +```python +Descriptors.BCUT2D_MWLOW(mol) +``` + +### BCUT2D_CHGHI +Highest eigenvalue weighted by partial charges. 
+```python +Descriptors.BCUT2D_CHGHI(mol) +``` + +### BCUT2D_CHGLO +Lowest eigenvalue weighted by partial charges. +```python +Descriptors.BCUT2D_CHGLO(mol) +``` + +### BCUT2D_LOGPHI +Highest eigenvalue weighted by LogP. +```python +Descriptors.BCUT2D_LOGPHI(mol) +``` + +### BCUT2D_LOGPLOW +Lowest eigenvalue weighted by LogP. +```python +Descriptors.BCUT2D_LOGPLOW(mol) +``` + +### BCUT2D_MRHI +Highest eigenvalue weighted by molar refractivity. +```python +Descriptors.BCUT2D_MRHI(mol) +``` + +### BCUT2D_MRLOW +Lowest eigenvalue weighted by molar refractivity. +```python +Descriptors.BCUT2D_MRLOW(mol) +``` + +## Autocorrelation Descriptors + +### AUTOCORR2D +2D autocorrelation descriptors (if enabled). +Various autocorrelation indices measuring spatial distribution of properties. + +## MQN Descriptors + +Molecular Quantum Numbers - 42 simple descriptors. + +### mqn1 through mqn42 +Integer descriptors counting various molecular features. +```python +# Access via CalcMolDescriptors +desc = Descriptors.CalcMolDescriptors(mol) +mqns = {k: v for k, v in desc.items() if k.startswith('mqn')} +``` + +## QED + +### qed +Quantitative Estimate of Drug-likeness. +```python +Descriptors.qed(mol) +``` + +## Lipinski's Rule of Five + +Check drug-likeness using Lipinski's criteria: + +```python +def lipinski_rule_of_five(mol): + mw = Descriptors.MolWt(mol) <= 500 + logp = Descriptors.MolLogP(mol) <= 5 + hbd = Descriptors.NumHDonors(mol) <= 5 + hba = Descriptors.NumHAcceptors(mol) <= 10 + return mw and logp and hbd and hba +``` + +## Batch Descriptor Calculation + +Calculate all descriptors at once: + +```python +from rdkit import Chem +from rdkit.Chem import Descriptors + +mol = Chem.MolFromSmiles('CCO') + +# Get all descriptors as dictionary +all_descriptors = Descriptors.CalcMolDescriptors(mol) + +# Access specific descriptor +mw = all_descriptors['MolWt'] +logp = all_descriptors['MolLogP'] + +# Get list of available descriptor names +from rdkit.Chem import Descriptors +descriptor_names = [desc[0] for desc in Descriptors._descList] +``` + +## Descriptor Categories Summary + +1. **Physicochemical**: MolWt, MolLogP, MolMR, TPSA +2. **Topological**: BertzCT, BalabanJ, Kappa indices +3. **Electronic**: Partial charges, E-state indices +4. **Shape**: Kappa indices, BCUT descriptors +5. **Connectivity**: Chi indices +6. **2D Fingerprints**: FpDensity descriptors +7. **Atom counts**: Heavy atoms, heteroatoms, rings +8. **Drug-likeness**: QED, Lipinski parameters +9. **Flexibility**: NumRotatableBonds, HallKierAlpha +10. 
**Surface area**: VSA-based descriptors + +## Common Use Cases + +### Drug-likeness Screening + +```python +def screen_druglikeness(mol): + return { + 'MW': Descriptors.MolWt(mol), + 'LogP': Descriptors.MolLogP(mol), + 'HBD': Descriptors.NumHDonors(mol), + 'HBA': Descriptors.NumHAcceptors(mol), + 'TPSA': Descriptors.TPSA(mol), + 'RotBonds': Descriptors.NumRotatableBonds(mol), + 'AromaticRings': Descriptors.NumAromaticRings(mol), + 'QED': Descriptors.qed(mol) + } +``` + +### Lead-like Filtering + +```python +def is_leadlike(mol): + mw = 250 <= Descriptors.MolWt(mol) <= 350 + logp = Descriptors.MolLogP(mol) <= 3.5 + rot_bonds = Descriptors.NumRotatableBonds(mol) <= 7 + return mw and logp and rot_bonds +``` + +### Diversity Analysis + +```python +def molecular_complexity(mol): + return { + 'BertzCT': Descriptors.BertzCT(mol), + 'NumRings': Descriptors.RingCount(mol), + 'NumRotBonds': Descriptors.NumRotatableBonds(mol), + 'FractionCsp3': Descriptors.FractionCsp3(mol), + 'NumAromaticRings': Descriptors.NumAromaticRings(mol) + } +``` + +## Tips + +1. **Use batch calculation** for multiple descriptors to avoid redundant computations +2. **Check for None** - some descriptors may return None for invalid molecules +3. **Normalize descriptors** for machine learning applications +4. **Select relevant descriptors** - not all 200+ descriptors are useful for every task +5. **Consider 3D descriptors** separately (require 3D coordinates) +6. **Validate ranges** - check if descriptor values are in expected ranges diff --git a/scientific-packages/rdkit/references/smarts_patterns.md b/scientific-packages/rdkit/references/smarts_patterns.md new file mode 100644 index 0000000..7b7caff --- /dev/null +++ b/scientific-packages/rdkit/references/smarts_patterns.md @@ -0,0 +1,668 @@ +# Common SMARTS Patterns for RDKit + +This document provides a collection of commonly used SMARTS patterns for substructure searching in RDKit. 
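+
+Every pattern in this reference is applied the same way: compile the SMARTS once with `Chem.MolFromSmarts`, then query molecules with `HasSubstructMatch` or `GetSubstructMatches`. A minimal sketch (the pattern and input SMILES below are purely illustrative):
+
+```python
+from rdkit import Chem
+
+# Compile once, reuse across many molecules (illustrative pattern: carboxylic acid)
+acid = Chem.MolFromSmarts('C(=O)[OH1]')
+
+smiles_list = ['CC(=O)O', 'CCO', 'CC(=O)Oc1ccccc1C(=O)O']  # illustrative inputs
+for smi in smiles_list:
+    mol = Chem.MolFromSmiles(smi)
+    if mol is not None and mol.HasSubstructMatch(acid):
+        # GetSubstructMatches returns one tuple of atom indices per hit
+        print(smi, '-> matches:', mol.GetSubstructMatches(acid))
+```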
+ +## Functional Groups + +### Alcohols + +```python +# Primary alcohol +'[CH2][OH1]' + +# Secondary alcohol +'[CH1]([OH1])[CH3,CH2]' + +# Tertiary alcohol +'[C]([OH1])([C])([C])[C]' + +# Any alcohol +'[OH1][C]' + +# Phenol +'c[OH1]' +``` + +### Aldehydes and Ketones + +```python +# Aldehyde +'[CH1](=O)' + +# Ketone +'[C](=O)[C]' + +# Any carbonyl +'[C](=O)' +``` + +### Carboxylic Acids and Derivatives + +```python +# Carboxylic acid +'C(=O)[OH1]' +'[CX3](=O)[OX2H1]' # More specific + +# Ester +'C(=O)O[C]' +'[CX3](=O)[OX2][C]' # More specific + +# Amide +'C(=O)N' +'[CX3](=O)[NX3]' # More specific + +# Acyl chloride +'C(=O)Cl' + +# Anhydride +'C(=O)OC(=O)' +``` + +### Amines + +```python +# Primary amine +'[NH2][C]' + +# Secondary amine +'[NH1]([C])[C]' + +# Tertiary amine +'[N]([C])([C])[C]' + +# Aromatic amine (aniline) +'c[NH2]' + +# Any amine +'[NX3]' +``` + +### Ethers + +```python +# Aliphatic ether +'[C][O][C]' + +# Aromatic ether +'c[O][C,c]' +``` + +### Halides + +```python +# Alkyl halide +'[C][F,Cl,Br,I]' + +# Aryl halide +'c[F,Cl,Br,I]' + +# Specific halides +'[C]F' # Fluoride +'[C]Cl' # Chloride +'[C]Br' # Bromide +'[C]I' # Iodide +``` + +### Nitriles and Nitro Groups + +```python +# Nitrile +'C#N' + +# Nitro group +'[N+](=O)[O-]' + +# Nitro on aromatic +'c[N+](=O)[O-]' +``` + +### Thiols and Sulfides + +```python +# Thiol +'[C][SH1]' + +# Sulfide +'[C][S][C]' + +# Disulfide +'[C][S][S][C]' + +# Sulfoxide +'[C][S](=O)[C]' + +# Sulfone +'[C][S](=O)(=O)[C]' +``` + +## Ring Systems + +### Simple Rings + +```python +# Benzene ring +'c1ccccc1' +'[#6]1:[#6]:[#6]:[#6]:[#6]:[#6]:1' # Explicit atoms + +# Cyclohexane +'C1CCCCC1' + +# Cyclopentane +'C1CCCC1' + +# Any 3-membered ring +'[r3]' + +# Any 4-membered ring +'[r4]' + +# Any 5-membered ring +'[r5]' + +# Any 6-membered ring +'[r6]' + +# Any 7-membered ring +'[r7]' +``` + +### Aromatic Rings + +```python +# Aromatic carbon in ring +'[cR]' + +# Aromatic nitrogen in ring (pyridine, etc.) +'[nR]' + +# Aromatic oxygen in ring (furan, etc.) +'[oR]' + +# Aromatic sulfur in ring (thiophene, etc.) 
+'[sR]' + +# Any aromatic ring +'a1aaaaa1' +``` + +### Heterocycles + +```python +# Pyridine +'n1ccccc1' + +# Pyrrole +'n1cccc1' + +# Furan +'o1cccc1' + +# Thiophene +'s1cccc1' + +# Imidazole +'n1cncc1' + +# Pyrimidine +'n1cnccc1' + +# Thiazole +'n1ccsc1' + +# Oxazole +'n1ccoc1' +``` + +### Fused Rings + +```python +# Naphthalene +'c1ccc2ccccc2c1' + +# Indole +'c1ccc2[nH]ccc2c1' + +# Quinoline +'n1cccc2ccccc12' + +# Benzimidazole +'c1ccc2[nH]cnc2c1' + +# Purine +'n1cnc2ncnc2c1' +``` + +### Macrocycles + +```python +# Rings with 8 or more atoms +'[r{8-}]' + +# Rings with 9-15 atoms +'[r{9-15}]' + +# Rings with more than 12 atoms (macrocycles) +'[r{12-}]' +``` + +## Specific Structural Features + +### Aliphatic vs Aromatic + +```python +# Aliphatic carbon +'[C]' + +# Aromatic carbon +'[c]' + +# Aliphatic carbon in ring +'[CR]' + +# Aromatic carbon (alternative) +'[cR]' +``` + +### Stereochemistry + +```python +# Tetrahedral center with clockwise chirality +'[C@]' + +# Tetrahedral center with counterclockwise chirality +'[C@@]' + +# Any chiral center +'[C@,C@@]' + +# E double bond +'C/C=C/C' + +# Z double bond +'C/C=C\\C' +``` + +### Hybridization + +```python +# SP hybridization (triple bond) +'[CX2]' + +# SP2 hybridization (double bond or aromatic) +'[CX3]' + +# SP3 hybridization (single bonds) +'[CX4]' +``` + +### Charge + +```python +# Positive charge +'[+]' + +# Negative charge +'[-]' + +# Specific charge +'[+1]' +'[-1]' +'[+2]' + +# Positively charged nitrogen +'[N+]' + +# Negatively charged oxygen +'[O-]' + +# Carboxylate anion +'C(=O)[O-]' + +# Ammonium cation +'[N+]([C])([C])([C])[C]' +``` + +## Pharmacophore Features + +### Hydrogen Bond Donors + +```python +# Hydroxyl +'[OH]' + +# Amine +'[NH,NH2]' + +# Amide NH +'[N][C](=O)' + +# Any H-bond donor +'[OH,NH,NH2,NH3+]' +``` + +### Hydrogen Bond Acceptors + +```python +# Carbonyl oxygen +'[O]=[C,S,P]' + +# Ether oxygen +'[OX2]' + +# Ester oxygen +'C(=O)[O]' + +# Nitrogen acceptor +'[N;!H0]' + +# Any H-bond acceptor +'[O,N]' +``` + +### Hydrophobic Groups + +```python +# Alkyl chain (4+ carbons) +'CCCC' + +# Branched alkyl +'C(C)(C)C' + +# Aromatic rings (hydrophobic) +'c1ccccc1' +``` + +### Aromatic Interactions + +```python +# Benzene for pi-pi stacking +'c1ccccc1' + +# Heterocycle for pi-pi +'[a]1[a][a][a][a][a]1' + +# Any aromatic ring +'[aR]' +``` + +## Drug-like Fragments + +### Lipinski Fragments + +```python +# Aromatic ring with substituents +'c1cc(*)ccc1' + +# Aliphatic chain +'CCCC' + +# Ether linkage +'[C][O][C]' + +# Amine (basic center) +'[N]([C])([C])' +``` + +### Common Scaffolds + +```python +# Benzamide +'c1ccccc1C(=O)N' + +# Sulfonamide +'S(=O)(=O)N' + +# Urea +'[N][C](=O)[N]' + +# Guanidine +'[N]C(=[N])[N]' + +# Phosphate +'P(=O)([O-])([O-])[O-]' +``` + +### Privileged Structures + +```python +# Biphenyl +'c1ccccc1-c2ccccc2' + +# Benzopyran +'c1ccc2OCCCc2c1' + +# Piperazine +'N1CCNCC1' + +# Piperidine +'N1CCCCC1' + +# Morpholine +'N1CCOCC1' +``` + +## Reactive Groups + +### Electrophiles + +```python +# Acyl chloride +'C(=O)Cl' + +# Alkyl halide +'[C][Cl,Br,I]' + +# Epoxide +'C1OC1' + +# Michael acceptor +'C=C[C](=O)' +``` + +### Nucleophiles + +```python +# Primary amine +'[NH2][C]' + +# Thiol +'[SH][C]' + +# Alcohol +'[OH][C]' +``` + +## Toxicity Alerts (PAINS) + +```python +# Rhodanine +'S1C(=O)NC(=S)C1' + +# Catechol +'c1ccc(O)c(O)c1' + +# Quinone +'O=C1C=CC(=O)C=C1' + +# Hydroquinone +'OC1=CC=C(O)C=C1' + +# Alkyl halide (reactive) +'[C][I,Br]' + +# Michael acceptor (reactive) +'C=CC(=O)[C,N]' +``` + +## Metal 
Binding + +```python +# Carboxylate (metal chelator) +'C(=O)[O-]' + +# Hydroxamic acid +'C(=O)N[OH]' + +# Catechol (iron chelator) +'c1c(O)c(O)ccc1' + +# Thiol (metal binding) +'[SH]' + +# Histidine-like (metal binding) +'c1ncnc1' +``` + +## Size and Complexity Filters + +```python +# Long aliphatic chains (>6 carbons) +'CCCCCCC' + +# Highly branched (quaternary carbon) +'C(C)(C)(C)C' + +# Multiple rings +'[R]~[R]' # Two rings connected + +# Spiro center +'[C]12[C][C][C]1[C][C]2' +``` + +## Special Patterns + +### Atom Counts + +```python +# Any atom +'[*]' + +# Heavy atom (not H) +'[!H]' + +# Carbon +'[C,c]' + +# Heteroatom +'[!C;!H]' + +# Halogen +'[F,Cl,Br,I]' +``` + +### Bond Types + +```python +# Single bond +'C-C' + +# Double bond +'C=C' + +# Triple bond +'C#C' + +# Aromatic bond +'c:c' + +# Any bond +'C~C' +``` + +### Ring Membership + +```python +# In any ring +'[R]' + +# Not in ring +'[!R]' + +# In exactly one ring +'[R1]' + +# In exactly two rings +'[R2]' + +# Ring bond +'[R]~[R]' +``` + +### Degree and Connectivity + +```python +# Total degree 1 (terminal atom) +'[D1]' + +# Total degree 2 (chain) +'[D2]' + +# Total degree 3 (branch point) +'[D3]' + +# Total degree 4 (highly branched) +'[D4]' + +# Connected to exactly 2 carbons +'[C]([C])[C]' +``` + +## Usage Examples + +```python +from rdkit import Chem + +# Create SMARTS query +pattern = Chem.MolFromSmarts('[CH2][OH1]') # Primary alcohol + +# Search molecule +mol = Chem.MolFromSmiles('CCO') +matches = mol.GetSubstructMatches(pattern) + +# Multiple patterns +patterns = { + 'alcohol': '[OH1][C]', + 'amine': '[NH2,NH1][C]', + 'carboxylic_acid': 'C(=O)[OH1]' +} + +# Check for functional groups +for name, smarts in patterns.items(): + query = Chem.MolFromSmarts(smarts) + if mol.HasSubstructMatch(query): + print(f"Found {name}") +``` + +## Tips for Writing SMARTS + +1. **Be specific when needed:** Use atom properties [CX3] instead of just [C] +2. **Use brackets for clarity:** [C] is different from C (aromatic) +3. **Consider aromaticity:** lowercase letters (c, n, o) are aromatic +4. **Check ring membership:** [R] for in-ring, [!R] for not in-ring +5. **Use recursive SMARTS:** $(...) for complex patterns +6. **Test patterns:** Always validate SMARTS on known molecules +7. **Start simple:** Build complex patterns incrementally + +## Common SMARTS Syntax + +- `[C]` - Aliphatic carbon +- `[c]` - Aromatic carbon +- `[CX4]` - Carbon with 4 connections (sp3) +- `[CX3]` - Carbon with 3 connections (sp2) +- `[CX2]` - Carbon with 2 connections (sp) +- `[CH3]` - Methyl group +- `[R]` - In ring +- `[r6]` - In 6-membered ring +- `[r{5-7}]` - In 5, 6, or 7-membered ring +- `[D2]` - Degree 2 (2 neighbors) +- `[+]` - Positive charge +- `[-]` - Negative charge +- `[!C]` - Not carbon +- `[#6]` - Element with atomic number 6 (carbon) +- `~` - Any bond type +- `-` - Single bond +- `=` - Double bond +- `#` - Triple bond +- `:` - Aromatic bond +- `@` - Clockwise chirality +- `@@` - Counter-clockwise chirality diff --git a/scientific-packages/rdkit/scripts/molecular_properties.py b/scientific-packages/rdkit/scripts/molecular_properties.py new file mode 100644 index 0000000..4444a7d --- /dev/null +++ b/scientific-packages/rdkit/scripts/molecular_properties.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +""" +Molecular Properties Calculator + +Calculate comprehensive molecular properties and descriptors for molecules. +Supports single molecules or batch processing from files. 
+ +Usage: + python molecular_properties.py "CCO" + python molecular_properties.py --file molecules.smi --output properties.csv +""" + +import argparse +import sys +from pathlib import Path + +try: + from rdkit import Chem + from rdkit.Chem import Descriptors, Lipinski +except ImportError: + print("Error: RDKit not installed. Install with: conda install -c conda-forge rdkit") + sys.exit(1) + + +def calculate_properties(mol): + """Calculate comprehensive molecular properties.""" + if mol is None: + return None + + properties = { + # Basic properties + 'SMILES': Chem.MolToSmiles(mol), + 'Molecular_Formula': Chem.rdMolDescriptors.CalcMolFormula(mol), + + # Molecular weight + 'MW': Descriptors.MolWt(mol), + 'ExactMW': Descriptors.ExactMolWt(mol), + + # Lipophilicity + 'LogP': Descriptors.MolLogP(mol), + 'MR': Descriptors.MolMR(mol), + + # Polar surface area + 'TPSA': Descriptors.TPSA(mol), + 'LabuteASA': Descriptors.LabuteASA(mol), + + # Hydrogen bonding + 'HBD': Descriptors.NumHDonors(mol), + 'HBA': Descriptors.NumHAcceptors(mol), + + # Atom counts + 'Heavy_Atoms': Descriptors.HeavyAtomCount(mol), + 'Heteroatoms': Descriptors.NumHeteroatoms(mol), + 'Valence_Electrons': Descriptors.NumValenceElectrons(mol), + + # Ring information + 'Rings': Descriptors.RingCount(mol), + 'Aromatic_Rings': Descriptors.NumAromaticRings(mol), + 'Saturated_Rings': Descriptors.NumSaturatedRings(mol), + 'Aliphatic_Rings': Descriptors.NumAliphaticRings(mol), + 'Aromatic_Heterocycles': Descriptors.NumAromaticHeterocycles(mol), + + # Flexibility + 'Rotatable_Bonds': Descriptors.NumRotatableBonds(mol), + 'Fraction_Csp3': Descriptors.FractionCsp3(mol), + + # Complexity + 'BertzCT': Descriptors.BertzCT(mol), + + # Drug-likeness + 'QED': Descriptors.qed(mol), + } + + # Lipinski's Rule of Five + properties['Lipinski_Pass'] = ( + properties['MW'] <= 500 and + properties['LogP'] <= 5 and + properties['HBD'] <= 5 and + properties['HBA'] <= 10 + ) + + # Lead-likeness + properties['Lead-like'] = ( + 250 <= properties['MW'] <= 350 and + properties['LogP'] <= 3.5 and + properties['Rotatable_Bonds'] <= 7 + ) + + return properties + + +def process_single_molecule(smiles): + """Process a single SMILES string.""" + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"Error: Failed to parse SMILES: {smiles}") + return None + + props = calculate_properties(mol) + return props + + +def process_file(input_file, output_file=None): + """Process molecules from a file.""" + input_path = Path(input_file) + + if not input_path.exists(): + print(f"Error: File not found: {input_file}") + return + + # Determine file type + if input_path.suffix.lower() in ['.sdf', '.mol']: + suppl = Chem.SDMolSupplier(str(input_path)) + elif input_path.suffix.lower() in ['.smi', '.smiles', '.txt']: + suppl = Chem.SmilesMolSupplier(str(input_path), titleLine=False) + else: + print(f"Error: Unsupported file format: {input_path.suffix}") + return + + results = [] + for idx, mol in enumerate(suppl): + if mol is None: + print(f"Warning: Failed to parse molecule {idx+1}") + continue + + props = calculate_properties(mol) + if props: + props['Index'] = idx + 1 + results.append(props) + + # Output results + if output_file: + write_csv(results, output_file) + print(f"Results written to: {output_file}") + else: + # Print to console + for props in results: + print("\n" + "="*60) + for key, value in props.items(): + print(f"{key:25s}: {value}") + + return results + + +def write_csv(results, output_file): + """Write results to CSV file.""" + import csv + + if not results: 
+ print("No results to write") + return + + with open(output_file, 'w', newline='') as f: + fieldnames = results[0].keys() + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(results) + + +def print_properties(props): + """Print properties in formatted output.""" + print("\nMolecular Properties:") + print("="*60) + + # Group related properties + print("\n[Basic Information]") + print(f" SMILES: {props['SMILES']}") + print(f" Formula: {props['Molecular_Formula']}") + + print("\n[Size & Weight]") + print(f" Molecular Weight: {props['MW']:.2f}") + print(f" Exact MW: {props['ExactMW']:.4f}") + print(f" Heavy Atoms: {props['Heavy_Atoms']}") + print(f" Heteroatoms: {props['Heteroatoms']}") + + print("\n[Lipophilicity]") + print(f" LogP: {props['LogP']:.2f}") + print(f" Molar Refractivity: {props['MR']:.2f}") + + print("\n[Polarity]") + print(f" TPSA: {props['TPSA']:.2f}") + print(f" Labute ASA: {props['LabuteASA']:.2f}") + print(f" H-bond Donors: {props['HBD']}") + print(f" H-bond Acceptors: {props['HBA']}") + + print("\n[Ring Systems]") + print(f" Total Rings: {props['Rings']}") + print(f" Aromatic Rings: {props['Aromatic_Rings']}") + print(f" Saturated Rings: {props['Saturated_Rings']}") + print(f" Aliphatic Rings: {props['Aliphatic_Rings']}") + print(f" Aromatic Heterocycles: {props['Aromatic_Heterocycles']}") + + print("\n[Flexibility & Complexity]") + print(f" Rotatable Bonds: {props['Rotatable_Bonds']}") + print(f" Fraction Csp3: {props['Fraction_Csp3']:.3f}") + print(f" Bertz Complexity: {props['BertzCT']:.1f}") + + print("\n[Drug-likeness]") + print(f" QED Score: {props['QED']:.3f}") + print(f" Lipinski Pass: {'Yes' if props['Lipinski_Pass'] else 'No'}") + print(f" Lead-like: {'Yes' if props['Lead-like'] else 'No'}") + print("="*60) + + +def main(): + parser = argparse.ArgumentParser( + description='Calculate molecular properties for molecules', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Single molecule + python molecular_properties.py "CCO" + + # From file + python molecular_properties.py --file molecules.smi + + # Save to CSV + python molecular_properties.py --file molecules.sdf --output properties.csv + """ + ) + + parser.add_argument('smiles', nargs='?', help='SMILES string to analyze') + parser.add_argument('--file', '-f', help='Input file (SDF or SMILES)') + parser.add_argument('--output', '-o', help='Output CSV file') + + args = parser.parse_args() + + if not args.smiles and not args.file: + parser.print_help() + sys.exit(1) + + if args.smiles: + # Process single molecule + props = process_single_molecule(args.smiles) + if props: + print_properties(props) + elif args.file: + # Process file + process_file(args.file, args.output) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/rdkit/scripts/similarity_search.py b/scientific-packages/rdkit/scripts/similarity_search.py new file mode 100644 index 0000000..4469ef1 --- /dev/null +++ b/scientific-packages/rdkit/scripts/similarity_search.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +""" +Molecular Similarity Search + +Perform fingerprint-based similarity screening against a database of molecules. +Supports multiple fingerprint types and similarity metrics. 
+ +Usage: + python similarity_search.py "CCO" database.smi --threshold 0.7 + python similarity_search.py query.smi database.sdf --method morgan --output hits.csv +""" + +import argparse +import sys +from pathlib import Path + +try: + from rdkit import Chem + from rdkit.Chem import AllChem, MACCSkeys + from rdkit import DataStructs +except ImportError: + print("Error: RDKit not installed. Install with: conda install -c conda-forge rdkit") + sys.exit(1) + + +FINGERPRINT_METHODS = { + 'morgan': 'Morgan fingerprint (ECFP-like)', + 'rdkit': 'RDKit topological fingerprint', + 'maccs': 'MACCS structural keys', + 'atompair': 'Atom pair fingerprint', + 'torsion': 'Topological torsion fingerprint' +} + + +def generate_fingerprint(mol, method='morgan', radius=2, n_bits=2048): + """Generate molecular fingerprint based on specified method.""" + if mol is None: + return None + + method = method.lower() + + if method == 'morgan': + return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits) + elif method == 'rdkit': + return Chem.RDKFingerprint(mol, maxPath=7, fpSize=n_bits) + elif method == 'maccs': + return MACCSkeys.GenMACCSKeys(mol) + elif method == 'atompair': + from rdkit.Chem.AtomPairs import Pairs + return Pairs.GetAtomPairFingerprintAsBitVect(mol, nBits=n_bits) + elif method == 'torsion': + from rdkit.Chem.AtomPairs import Torsions + return Torsions.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=n_bits) + else: + raise ValueError(f"Unknown fingerprint method: {method}") + + +def load_molecules(file_path): + """Load molecules from file.""" + path = Path(file_path) + + if not path.exists(): + print(f"Error: File not found: {file_path}") + return [] + + molecules = [] + + if path.suffix.lower() in ['.sdf', '.mol']: + suppl = Chem.SDMolSupplier(str(path)) + elif path.suffix.lower() in ['.smi', '.smiles', '.txt']: + suppl = Chem.SmilesMolSupplier(str(path), titleLine=False) + else: + print(f"Error: Unsupported file format: {path.suffix}") + return [] + + for idx, mol in enumerate(suppl): + if mol is None: + print(f"Warning: Failed to parse molecule {idx+1}") + continue + + # Try to get molecule name + name = mol.GetProp('_Name') if mol.HasProp('_Name') else f"Mol_{idx+1}" + smiles = Chem.MolToSmiles(mol) + + molecules.append({ + 'index': idx + 1, + 'name': name, + 'smiles': smiles, + 'mol': mol + }) + + return molecules + + +def similarity_search(query_mol, database, method='morgan', threshold=0.7, + radius=2, n_bits=2048, metric='tanimoto'): + """ + Perform similarity search. 
+ + Args: + query_mol: Query molecule (RDKit Mol object) + database: List of database molecules + method: Fingerprint method + threshold: Similarity threshold (0-1) + radius: Morgan fingerprint radius + n_bits: Fingerprint size + metric: Similarity metric (tanimoto, dice, cosine) + + Returns: + List of hits with similarity scores + """ + if query_mol is None: + print("Error: Invalid query molecule") + return [] + + # Generate query fingerprint + query_fp = generate_fingerprint(query_mol, method, radius, n_bits) + if query_fp is None: + print("Error: Failed to generate query fingerprint") + return [] + + # Choose similarity function + if metric.lower() == 'tanimoto': + sim_func = DataStructs.TanimotoSimilarity + elif metric.lower() == 'dice': + sim_func = DataStructs.DiceSimilarity + elif metric.lower() == 'cosine': + sim_func = DataStructs.CosineSimilarity + else: + raise ValueError(f"Unknown similarity metric: {metric}") + + # Search database + hits = [] + for db_entry in database: + db_fp = generate_fingerprint(db_entry['mol'], method, radius, n_bits) + if db_fp is None: + continue + + similarity = sim_func(query_fp, db_fp) + + if similarity >= threshold: + hits.append({ + 'index': db_entry['index'], + 'name': db_entry['name'], + 'smiles': db_entry['smiles'], + 'similarity': similarity + }) + + # Sort by similarity (descending) + hits.sort(key=lambda x: x['similarity'], reverse=True) + + return hits + + +def write_results(hits, output_file): + """Write results to CSV file.""" + import csv + + with open(output_file, 'w', newline='') as f: + fieldnames = ['Rank', 'Index', 'Name', 'SMILES', 'Similarity'] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for rank, hit in enumerate(hits, 1): + writer.writerow({ + 'Rank': rank, + 'Index': hit['index'], + 'Name': hit['name'], + 'SMILES': hit['smiles'], + 'Similarity': f"{hit['similarity']:.4f}" + }) + + +def print_results(hits, max_display=20): + """Print results to console.""" + if not hits: + print("\nNo hits found above threshold") + return + + print(f"\nFound {len(hits)} similar molecules:") + print("="*80) + print(f"{'Rank':<6} {'Index':<8} {'Similarity':<12} {'Name':<20} {'SMILES'}") + print("-"*80) + + for rank, hit in enumerate(hits[:max_display], 1): + name = hit['name'][:18] + '..' if len(hit['name']) > 20 else hit['name'] + smiles = hit['smiles'][:40] + '...' if len(hit['smiles']) > 43 else hit['smiles'] + print(f"{rank:<6} {hit['index']:<8} {hit['similarity']:<12.4f} {name:<20} {smiles}") + + if len(hits) > max_display: + print(f"\n... 
and {len(hits) - max_display} more") + + print("="*80) + + +def main(): + parser = argparse.ArgumentParser( + description='Molecular similarity search using fingerprints', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +Available fingerprint methods: +{chr(10).join(f' {k:12s} - {v}' for k, v in FINGERPRINT_METHODS.items())} + +Similarity metrics: + tanimoto - Tanimoto coefficient (default) + dice - Dice coefficient + cosine - Cosine similarity + +Examples: + # Search with SMILES query + python similarity_search.py "CCO" database.smi --threshold 0.7 + + # Use different fingerprint + python similarity_search.py query.smi database.sdf --method maccs + + # Save results + python similarity_search.py "c1ccccc1" database.smi --output hits.csv + + # Adjust Morgan radius + python similarity_search.py "CCO" database.smi --method morgan --radius 3 + """ + ) + + parser.add_argument('query', help='Query SMILES or file') + parser.add_argument('database', help='Database file (SDF or SMILES)') + parser.add_argument('--method', '-m', default='morgan', + choices=FINGERPRINT_METHODS.keys(), + help='Fingerprint method (default: morgan)') + parser.add_argument('--threshold', '-t', type=float, default=0.7, + help='Similarity threshold (default: 0.7)') + parser.add_argument('--radius', '-r', type=int, default=2, + help='Morgan fingerprint radius (default: 2)') + parser.add_argument('--bits', '-b', type=int, default=2048, + help='Fingerprint size (default: 2048)') + parser.add_argument('--metric', default='tanimoto', + choices=['tanimoto', 'dice', 'cosine'], + help='Similarity metric (default: tanimoto)') + parser.add_argument('--output', '-o', help='Output CSV file') + parser.add_argument('--max-display', type=int, default=20, + help='Maximum hits to display (default: 20)') + + args = parser.parse_args() + + # Load query + query_path = Path(args.query) + if query_path.exists(): + # Query is a file + query_mols = load_molecules(args.query) + if not query_mols: + print("Error: No valid molecules in query file") + sys.exit(1) + query_mol = query_mols[0]['mol'] + query_smiles = query_mols[0]['smiles'] + else: + # Query is SMILES string + query_mol = Chem.MolFromSmiles(args.query) + query_smiles = args.query + if query_mol is None: + print(f"Error: Failed to parse query SMILES: {args.query}") + sys.exit(1) + + print(f"Query: {query_smiles}") + print(f"Method: {args.method}") + print(f"Threshold: {args.threshold}") + print(f"Loading database: {args.database}...") + + # Load database + database = load_molecules(args.database) + if not database: + print("Error: No valid molecules in database") + sys.exit(1) + + print(f"Loaded {len(database)} molecules") + print("Searching...") + + # Perform search + hits = similarity_search( + query_mol, database, + method=args.method, + threshold=args.threshold, + radius=args.radius, + n_bits=args.bits, + metric=args.metric + ) + + # Output results + if args.output: + write_results(hits, args.output) + print(f"\nResults written to: {args.output}") + + print_results(hits, args.max_display) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/rdkit/scripts/substructure_filter.py b/scientific-packages/rdkit/scripts/substructure_filter.py new file mode 100644 index 0000000..596cfc3 --- /dev/null +++ b/scientific-packages/rdkit/scripts/substructure_filter.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 +""" +Substructure Filter + +Filter molecules based on substructure patterns using SMARTS. 
+Supports inclusion and exclusion filters, and custom pattern libraries. + +Usage: + python substructure_filter.py molecules.smi --pattern "c1ccccc1" --output filtered.smi + python substructure_filter.py database.sdf --exclude "C(=O)Cl" --filter-type functional-groups +""" + +import argparse +import sys +from pathlib import Path + +try: + from rdkit import Chem +except ImportError: + print("Error: RDKit not installed. Install with: conda install -c conda-forge rdkit") + sys.exit(1) + + +# Common SMARTS pattern libraries +PATTERN_LIBRARIES = { + 'functional-groups': { + 'alcohol': '[OH][C]', + 'aldehyde': '[CH1](=O)', + 'ketone': '[C](=O)[C]', + 'carboxylic_acid': 'C(=O)[OH]', + 'ester': 'C(=O)O[C]', + 'amide': 'C(=O)N', + 'amine': '[NX3]', + 'ether': '[C][O][C]', + 'nitrile': 'C#N', + 'nitro': '[N+](=O)[O-]', + 'halide': '[C][F,Cl,Br,I]', + 'thiol': '[C][SH]', + 'sulfide': '[C][S][C]', + }, + 'rings': { + 'benzene': 'c1ccccc1', + 'pyridine': 'n1ccccc1', + 'pyrrole': 'n1cccc1', + 'furan': 'o1cccc1', + 'thiophene': 's1cccc1', + 'imidazole': 'n1cncc1', + 'indole': 'c1ccc2[nH]ccc2c1', + 'naphthalene': 'c1ccc2ccccc2c1', + }, + 'pains': { + 'rhodanine': 'S1C(=O)NC(=S)C1', + 'catechol': 'c1ccc(O)c(O)c1', + 'quinone': 'O=C1C=CC(=O)C=C1', + 'michael_acceptor': 'C=CC(=O)', + 'alkyl_halide': '[C][I,Br]', + }, + 'privileged': { + 'biphenyl': 'c1ccccc1-c2ccccc2', + 'piperazine': 'N1CCNCC1', + 'piperidine': 'N1CCCCC1', + 'morpholine': 'N1CCOCC1', + } +} + + +def load_molecules(file_path, keep_props=True): + """Load molecules from file.""" + path = Path(file_path) + + if not path.exists(): + print(f"Error: File not found: {file_path}") + return [] + + molecules = [] + + if path.suffix.lower() in ['.sdf', '.mol']: + suppl = Chem.SDMolSupplier(str(path)) + elif path.suffix.lower() in ['.smi', '.smiles', '.txt']: + suppl = Chem.SmilesMolSupplier(str(path), titleLine=False) + else: + print(f"Error: Unsupported file format: {path.suffix}") + return [] + + for idx, mol in enumerate(suppl): + if mol is None: + print(f"Warning: Failed to parse molecule {idx+1}") + continue + + molecules.append(mol) + + return molecules + + +def create_pattern_query(pattern_string): + """Create SMARTS query from string or SMILES.""" + # Try as SMARTS first + query = Chem.MolFromSmarts(pattern_string) + if query is not None: + return query + + # Try as SMILES + query = Chem.MolFromSmiles(pattern_string) + if query is not None: + return query + + print(f"Error: Invalid pattern: {pattern_string}") + return None + + +def filter_molecules(molecules, include_patterns=None, exclude_patterns=None, + match_all_include=False): + """ + Filter molecules based on substructure patterns. 
+ + Args: + molecules: List of RDKit Mol objects + include_patterns: List of (name, pattern) tuples to include + exclude_patterns: List of (name, pattern) tuples to exclude + match_all_include: If True, molecule must match ALL include patterns + + Returns: + Tuple of (filtered_molecules, match_info) + """ + filtered = [] + match_info = [] + + for idx, mol in enumerate(molecules): + if mol is None: + continue + + # Check exclusion patterns first + excluded = False + exclude_matches = [] + if exclude_patterns: + for name, pattern in exclude_patterns: + if mol.HasSubstructMatch(pattern): + excluded = True + exclude_matches.append(name) + + if excluded: + match_info.append({ + 'index': idx + 1, + 'smiles': Chem.MolToSmiles(mol), + 'status': 'excluded', + 'matches': exclude_matches + }) + continue + + # Check inclusion patterns + if include_patterns: + include_matches = [] + for name, pattern in include_patterns: + if mol.HasSubstructMatch(pattern): + include_matches.append(name) + + # Decide if molecule passes inclusion filter + if match_all_include: + passed = len(include_matches) == len(include_patterns) + else: + passed = len(include_matches) > 0 + + if passed: + filtered.append(mol) + match_info.append({ + 'index': idx + 1, + 'smiles': Chem.MolToSmiles(mol), + 'status': 'included', + 'matches': include_matches + }) + else: + match_info.append({ + 'index': idx + 1, + 'smiles': Chem.MolToSmiles(mol), + 'status': 'no_match', + 'matches': [] + }) + else: + # No inclusion patterns, keep all non-excluded + filtered.append(mol) + match_info.append({ + 'index': idx + 1, + 'smiles': Chem.MolToSmiles(mol), + 'status': 'included', + 'matches': [] + }) + + return filtered, match_info + + +def write_molecules(molecules, output_file): + """Write molecules to file.""" + output_path = Path(output_file) + + if output_path.suffix.lower() in ['.sdf']: + writer = Chem.SDWriter(str(output_path)) + for mol in molecules: + writer.write(mol) + writer.close() + elif output_path.suffix.lower() in ['.smi', '.smiles', '.txt']: + with open(output_path, 'w') as f: + for mol in molecules: + smiles = Chem.MolToSmiles(mol) + name = mol.GetProp('_Name') if mol.HasProp('_Name') else '' + f.write(f"{smiles} {name}\n") + else: + print(f"Error: Unsupported output format: {output_path.suffix}") + return + + print(f"Wrote {len(molecules)} molecules to {output_file}") + + +def write_report(match_info, output_file): + """Write detailed match report.""" + import csv + + with open(output_file, 'w', newline='') as f: + fieldnames = ['Index', 'SMILES', 'Status', 'Matches'] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for info in match_info: + writer.writerow({ + 'Index': info['index'], + 'SMILES': info['smiles'], + 'Status': info['status'], + 'Matches': ', '.join(info['matches']) + }) + + +def print_summary(total, filtered, match_info): + """Print filtering summary.""" + print("\n" + "="*60) + print("Filtering Summary") + print("="*60) + print(f"Total molecules: {total}") + print(f"Passed filter: {len(filtered)}") + print(f"Filtered out: {total - len(filtered)}") + print(f"Pass rate: {len(filtered)/total*100:.1f}%") + + # Count by status + status_counts = {} + for info in match_info: + status = info['status'] + status_counts[status] = status_counts.get(status, 0) + 1 + + print("\nBreakdown:") + for status, count in status_counts.items(): + print(f" {status:15s}: {count}") + + print("="*60) + + +def main(): + parser = argparse.ArgumentParser( + description='Filter molecules by substructure patterns', + 
formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +Pattern libraries: + --filter-type functional-groups Common functional groups + --filter-type rings Ring systems + --filter-type pains PAINS (Pan-Assay Interference) + --filter-type privileged Privileged structures + +Examples: + # Include molecules with benzene ring + python substructure_filter.py molecules.smi --pattern "c1ccccc1" -o filtered.smi + + # Exclude reactive groups + python substructure_filter.py database.sdf --exclude "C(=O)Cl" -o clean.sdf + + # Filter by functional groups + python substructure_filter.py molecules.smi --filter-type functional-groups -o fg.smi + + # Remove PAINS + python substructure_filter.py compounds.smi --filter-type pains --exclude-mode -o clean.smi + + # Multiple patterns + python substructure_filter.py mol.smi --pattern "c1ccccc1" --pattern "N" -o aromatic_amines.smi + """ + ) + + parser.add_argument('input', help='Input file (SDF or SMILES)') + parser.add_argument('--pattern', '-p', action='append', + help='SMARTS/SMILES pattern to include (can specify multiple)') + parser.add_argument('--exclude', '-e', action='append', + help='SMARTS/SMILES pattern to exclude (can specify multiple)') + parser.add_argument('--filter-type', choices=PATTERN_LIBRARIES.keys(), + help='Use predefined pattern library') + parser.add_argument('--exclude-mode', action='store_true', + help='Use filter-type patterns for exclusion instead of inclusion') + parser.add_argument('--match-all', action='store_true', + help='Molecule must match ALL include patterns') + parser.add_argument('--output', '-o', help='Output file') + parser.add_argument('--report', '-r', help='Write detailed report to CSV') + parser.add_argument('--list-patterns', action='store_true', + help='List available pattern libraries and exit') + + args = parser.parse_args() + + # List patterns mode + if args.list_patterns: + print("\nAvailable Pattern Libraries:") + print("="*60) + for lib_name, patterns in PATTERN_LIBRARIES.items(): + print(f"\n{lib_name}:") + for name, pattern in patterns.items(): + print(f" {name:25s}: {pattern}") + sys.exit(0) + + # Load molecules + print(f"Loading molecules from: {args.input}") + molecules = load_molecules(args.input) + if not molecules: + print("Error: No valid molecules loaded") + sys.exit(1) + + print(f"Loaded {len(molecules)} molecules") + + # Prepare patterns + include_patterns = [] + exclude_patterns = [] + + # Add custom include patterns + if args.pattern: + for pattern_str in args.pattern: + query = create_pattern_query(pattern_str) + if query: + include_patterns.append(('custom', query)) + + # Add custom exclude patterns + if args.exclude: + for pattern_str in args.exclude: + query = create_pattern_query(pattern_str) + if query: + exclude_patterns.append(('custom', query)) + + # Add library patterns + if args.filter_type: + lib_patterns = PATTERN_LIBRARIES[args.filter_type] + for name, pattern_str in lib_patterns.items(): + query = create_pattern_query(pattern_str) + if query: + if args.exclude_mode: + exclude_patterns.append((name, query)) + else: + include_patterns.append((name, query)) + + if not include_patterns and not exclude_patterns: + print("Error: No patterns specified") + sys.exit(1) + + # Print filter configuration + print(f"\nFilter configuration:") + if include_patterns: + print(f" Include patterns: {len(include_patterns)}") + if args.match_all: + print(" Mode: Match ALL") + else: + print(" Mode: Match ANY") + if exclude_patterns: + print(f" Exclude patterns: {len(exclude_patterns)}") + 
+ # Perform filtering + print("\nFiltering...") + filtered, match_info = filter_molecules( + molecules, + include_patterns=include_patterns if include_patterns else None, + exclude_patterns=exclude_patterns if exclude_patterns else None, + match_all_include=args.match_all + ) + + # Print summary + print_summary(len(molecules), filtered, match_info) + + # Write output + if args.output: + write_molecules(filtered, args.output) + + if args.report: + write_report(match_info, args.report) + print(f"Detailed report written to: {args.report}") + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/reportlab/SKILL.md b/scientific-packages/reportlab/SKILL.md new file mode 100644 index 0000000..96fc349 --- /dev/null +++ b/scientific-packages/reportlab/SKILL.md @@ -0,0 +1,621 @@ +--- +name: reportlab +description: This skill provides comprehensive guidance for creating PDF documents using the ReportLab Python library. Use this skill when generating PDFs programmatically, including invoices, reports, certificates, labels, forms, and any document requiring precise layout control. The skill covers both low-level Canvas API for pixel-perfect positioning and high-level Platypus for flowing multi-page documents, along with tables, charts, barcodes, text formatting, and PDF features. +--- + +# ReportLab PDF Generation + +## Overview + +ReportLab is a powerful Python library for programmatic PDF generation. Create anything from simple documents to complex reports with tables, charts, images, and interactive forms. + +**Two main approaches:** +- **Canvas API** (low-level): Direct drawing with coordinate-based positioning - use for precise layouts +- **Platypus** (high-level): Flowing document layout with automatic page breaks - use for multi-page documents + +**Core capabilities:** +- Text with rich formatting and custom fonts +- Tables with complex styling and cell spanning +- Charts (bar, line, pie, area, scatter) +- Barcodes and QR codes (Code128, EAN, QR, etc.) 
+- Images with transparency +- PDF features (links, bookmarks, forms, encryption) + +## Choosing the Right Approach + +### Use Canvas API when: +- Creating labels, business cards, certificates +- Precise positioning is critical (x, y coordinates) +- Single-page documents or simple layouts +- Drawing graphics, shapes, and custom designs +- Adding barcodes or QR codes at specific locations + +### Use Platypus when: +- Creating multi-page documents (reports, articles, books) +- Content should flow automatically across pages +- Need headers/footers that repeat on each page +- Working with paragraphs that can split across pages +- Building complex documents with table of contents + +### Use Both when: +- Complex reports need both flowing content AND precise positioning +- Adding headers/footers to Platypus documents (use `onPage` callback with Canvas) +- Embedding custom graphics (Canvas) within flowing documents (Platypus) + +## Quick Start Examples + +### Simple Canvas Document + +```python +from reportlab.pdfgen import canvas +from reportlab.lib.pagesizes import letter +from reportlab.lib.units import inch + +c = canvas.Canvas("output.pdf", pagesize=letter) +width, height = letter + +# Draw text +c.setFont("Helvetica-Bold", 24) +c.drawString(inch, height - inch, "Hello ReportLab!") + +# Draw a rectangle +c.setFillColorRGB(0.2, 0.4, 0.8) +c.rect(inch, 5*inch, 4*inch, 2*inch, fill=1) + +# Save +c.showPage() +c.save() +``` + +### Simple Platypus Document + +```python +from reportlab.lib.pagesizes import letter +from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer +from reportlab.lib.styles import getSampleStyleSheet +from reportlab.lib.units import inch + +doc = SimpleDocTemplate("output.pdf", pagesize=letter) +story = [] +styles = getSampleStyleSheet() + +# Add content +story.append(Paragraph("Document Title", styles['Title'])) +story.append(Spacer(1, 0.2*inch)) +story.append(Paragraph("This is body text with bold and italic.", styles['BodyText'])) + +# Build PDF +doc.build(story) +``` + +## Common Tasks + +### Creating Tables + +Tables work with both Canvas (via Drawing) and Platypus (as Flowables): + +```python +from reportlab.platypus import Table, TableStyle +from reportlab.lib import colors +from reportlab.lib.units import inch + +# Define data +data = [ + ['Product', 'Q1', 'Q2', 'Q3', 'Q4'], + ['Widget A', '100', '150', '130', '180'], + ['Widget B', '80', '120', '110', '160'], +] + +# Create table +table = Table(data, colWidths=[2*inch, 1*inch, 1*inch, 1*inch, 1*inch]) + +# Apply styling +style = TableStyle([ + # Header row + ('BACKGROUND', (0, 0), (-1, 0), colors.darkblue), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + + # Data rows + ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]), + ('GRID', (0, 0), (-1, -1), 1, colors.black), +]) + +table.setStyle(style) + +# Add to Platypus story +story.append(table) + +# Or draw on Canvas +table.wrapOn(c, width, height) +table.drawOn(c, x, y) +``` + +**Detailed table reference:** See `references/tables_reference.md` for cell spanning, borders, alignment, and advanced styling. 
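+
+To merge cells (for example, a banner header spanning several columns), add a `SPAN` command to the `TableStyle`. A minimal sketch with illustrative data and widths:
+
+```python
+from reportlab.platypus import Table, TableStyle
+from reportlab.lib import colors
+from reportlab.lib.units import inch
+
+data = [
+    ['', 'Quarterly Sales', '', '', ''],   # banner row
+    ['Product', 'Q1', 'Q2', 'Q3', 'Q4'],
+    ['Widget A', '100', '150', '130', '180'],
+]
+
+table = Table(data, colWidths=[2*inch] + [1*inch]*4)
+table.setStyle(TableStyle([
+    ('SPAN', (1, 0), (4, 0)),              # merge columns 1-4 of the banner row
+    ('ALIGN', (1, 0), (4, 0), 'CENTER'),
+    ('FONTNAME', (0, 0), (-1, 1), 'Helvetica-Bold'),
+    ('GRID', (0, 1), (-1, -1), 0.5, colors.grey),
+]))
+
+# Add to a Platypus story or draw on a Canvas exactly as shown above
+```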
+ +### Creating Charts + +Charts use the graphics framework and can be added to both Canvas and Platypus: + +```python +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.charts.barcharts import VerticalBarChart +from reportlab.lib import colors + +# Create drawing +drawing = Drawing(400, 200) + +# Create chart +chart = VerticalBarChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 125 + +# Set data +chart.data = [[100, 150, 130, 180, 140]] +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5'] + +# Style +chart.bars[0].fillColor = colors.blue +chart.valueAxis.valueMin = 0 +chart.valueAxis.valueMax = 200 + +# Add to drawing +drawing.add(chart) + +# Use in Platypus +story.append(drawing) + +# Or render directly to PDF +from reportlab.graphics import renderPDF +renderPDF.drawToFile(drawing, 'chart.pdf', 'Chart Title') +``` + +**Available chart types:** Bar (vertical/horizontal), Line, Pie, Area, Scatter +**Detailed charts reference:** See `references/charts_reference.md` for all chart types, styling, legends, and customization. + +### Adding Barcodes and QR Codes + +```python +from reportlab.graphics.barcode import code128 +from reportlab.graphics.barcode.qr import QrCodeWidget +from reportlab.graphics.shapes import Drawing +from reportlab.graphics import renderPDF + +# Code128 barcode (general purpose) +barcode = code128.Code128("ABC123456789", barHeight=0.5*inch) + +# On Canvas +barcode.drawOn(c, x, y) + +# QR Code +qr = QrCodeWidget("https://example.com") +qr.barWidth = 2*inch +qr.barHeight = 2*inch + +# Wrap in Drawing for Platypus +d = Drawing() +d.add(qr) +story.append(d) +``` + +**Supported formats:** Code128, Code39, EAN-13, EAN-8, UPC-A, ISBN, QR, Data Matrix, and 20+ more +**Detailed barcode reference:** See `references/barcodes_reference.md` for all formats and usage examples. + +### Working with Text and Fonts + +```python +from reportlab.platypus import Paragraph +from reportlab.lib.styles import ParagraphStyle +from reportlab.lib.enums import TA_JUSTIFY + +# Create custom style +custom_style = ParagraphStyle( + 'CustomStyle', + fontSize=12, + leading=14, # Line spacing + alignment=TA_JUSTIFY, + spaceAfter=10, + textColor=colors.black, +) + +# Paragraph with inline formatting +text = """ +This paragraph has bold, italic, and underlined text. +You can also use colors and different sizes. +Chemical formula: H2O, Einstein: E=mc2 +""" + +para = Paragraph(text, custom_style) +story.append(para) +``` + +**Using custom fonts:** + +```python +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont + +# Register TrueType font +pdfmetrics.registerFont(TTFont('CustomFont', 'CustomFont.ttf')) + +# Use in Canvas +c.setFont('CustomFont', 12) + +# Use in Paragraph style +style = ParagraphStyle('Custom', fontName='CustomFont', fontSize=12) +``` + +**Detailed text reference:** See `references/text_and_fonts.md` for paragraph styles, font families, Asian languages, Greek letters, and formatting. 
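+
+When registering a TrueType family, also map its bold and italic variants with `registerFontFamily` so inline `<b>` and `<i>` markup in Paragraphs resolves to the right faces. A minimal sketch assuming hypothetical font files:
+
+```python
+from reportlab.pdfbase import pdfmetrics
+from reportlab.pdfbase.ttfonts import TTFont
+
+# Font file names here are placeholders -- point them at real .ttf files
+pdfmetrics.registerFont(TTFont('MyFont', 'MyFont-Regular.ttf'))
+pdfmetrics.registerFont(TTFont('MyFont-Bold', 'MyFont-Bold.ttf'))
+pdfmetrics.registerFont(TTFont('MyFont-Italic', 'MyFont-Italic.ttf'))
+pdfmetrics.registerFont(TTFont('MyFont-BoldItalic', 'MyFont-BoldItalic.ttf'))
+
+# Map the variants onto the family name used in styles and markup
+pdfmetrics.registerFontFamily(
+    'MyFont',
+    normal='MyFont',
+    bold='MyFont-Bold',
+    italic='MyFont-Italic',
+    boldItalic='MyFont-BoldItalic',
+)
+```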
+ +### Adding Images + +```python +from reportlab.platypus import Image +from reportlab.lib.units import inch + +# In Platypus +img = Image('photo.jpg', width=4*inch, height=3*inch) +story.append(img) + +# Maintain aspect ratio +img = Image('photo.jpg', width=4*inch, height=3*inch, kind='proportional') + +# In Canvas +c.drawImage('photo.jpg', x, y, width=4*inch, height=3*inch) + +# With transparency (mask white background) +c.drawImage('logo.png', x, y, mask=[255,255,255,255,255,255]) +``` + +### Creating Forms + +```python +from reportlab.pdfgen import canvas +from reportlab.lib.colors import black, white, lightgrey + +c = canvas.Canvas("form.pdf") + +# Text field +c.acroForm.textfield( + name="name", + tooltip="Enter your name", + x=100, y=700, + width=200, height=20, + borderColor=black, + fillColor=lightgrey, + forceBorder=True +) + +# Checkbox +c.acroForm.checkbox( + name="agree", + x=100, y=650, + size=20, + buttonStyle='check', + checked=False +) + +# Dropdown +c.acroForm.choice( + name="country", + x=100, y=600, + width=150, height=20, + options=[("United States", "US"), ("Canada", "CA")], + forceBorder=True +) + +c.save() +``` + +**Detailed PDF features reference:** See `references/pdf_features.md` for forms, links, bookmarks, encryption, and metadata. + +### Headers and Footers + +For Platypus documents, use page callbacks: + +```python +from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame + +def add_header_footer(canvas, doc): + """Called on each page""" + canvas.saveState() + + # Header + canvas.setFont('Helvetica', 9) + canvas.drawString(inch, height - 0.5*inch, "Document Title") + + # Footer + canvas.drawRightString(width - inch, 0.5*inch, f"Page {doc.page}") + + canvas.restoreState() + +# Set up document +doc = BaseDocTemplate("output.pdf") +frame = Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height, id='normal') +template = PageTemplate(id='normal', frames=[frame], onPage=add_header_footer) +doc.addPageTemplates([template]) + +# Build with story +doc.build(story) +``` + +## Helper Scripts + +This skill includes helper scripts for common tasks: + +### Quick Document Generator + +Use `scripts/quick_document.py` for rapid document creation: + +```python +from scripts.quick_document import create_simple_document, create_styled_table + +# Simple document from content blocks +content = [ + {'type': 'heading', 'content': 'Introduction'}, + {'type': 'paragraph', 'content': 'Your text here...'}, + {'type': 'bullet', 'content': 'Bullet point'}, +] + +create_simple_document("output.pdf", "My Document", content_blocks=content) + +# Styled tables with presets +data = [['Header1', 'Header2'], ['Data1', 'Data2']] +table = create_styled_table(data, style_name='striped') # 'default', 'striped', 'minimal', 'report' +``` + +## Template Examples + +Complete working examples in `assets/`: + +### Invoice Template + +`assets/invoice_template.py` - Professional invoice with: +- Company and client information +- Itemized table with calculations +- Tax and totals +- Terms and notes +- Logo placement + +```python +from assets.invoice_template import create_invoice + +create_invoice( + filename="invoice.pdf", + invoice_number="INV-2024-001", + invoice_date="January 15, 2024", + due_date="February 15, 2024", + company_info={'name': 'Acme Corp', 'address': '...', 'phone': '...', 'email': '...'}, + client_info={'name': 'Client Name', ...}, + items=[ + {'description': 'Service', 'quantity': 1, 'unit_price': 500.00}, + ... 
+ ], + tax_rate=0.08, + notes="Thank you for your business!", +) +``` + +### Report Template + +`assets/report_template.py` - Multi-page business report with: +- Cover page +- Table of contents +- Multiple sections with subsections +- Charts and tables +- Headers and footers + +```python +from assets.report_template import create_report + +report_data = { + 'title': 'Quarterly Report', + 'subtitle': 'Q4 2023', + 'author': 'Analytics Team', + 'sections': [ + { + 'title': 'Executive Summary', + 'content': 'Report content...', + 'table_data': {...}, + 'chart_data': {...} + }, + ... + ] +} + +create_report("report.pdf", report_data) +``` + +## Reference Documentation + +Comprehensive API references organized by feature: + +- **`references/canvas_api.md`** - Low-level Canvas: drawing primitives, coordinates, transformations, state management, images, paths +- **`references/platypus_guide.md`** - High-level Platypus: document templates, frames, flowables, page layouts, TOC +- **`references/text_and_fonts.md`** - Text formatting: paragraph styles, inline markup, custom fonts, Asian languages, bullets, sequences +- **`references/tables_reference.md`** - Tables: creation, styling, cell spanning, borders, alignment, colors, gradients +- **`references/charts_reference.md`** - Charts: all chart types, data handling, axes, legends, colors, rendering +- **`references/barcodes_reference.md`** - Barcodes: Code128, QR codes, EAN, UPC, postal codes, and 20+ formats +- **`references/pdf_features.md`** - PDF features: links, bookmarks, forms, encryption, metadata, page transitions + +## Best Practices + +### Coordinate System (Canvas) +- Origin (0, 0) is **lower-left corner** (not top-left) +- Y-axis points **upward** +- Units are in **points** (72 points = 1 inch) +- Always specify page size explicitly + +```python +from reportlab.lib.pagesizes import letter +from reportlab.lib.units import inch + +width, height = letter +margin = inch + +# Top of page +y_top = height - margin + +# Bottom of page +y_bottom = margin +``` + +### Choosing Page Size + +```python +from reportlab.lib.pagesizes import letter, A4, landscape + +# US Letter (8.5" x 11") +pagesize=letter + +# ISO A4 (210mm x 297mm) +pagesize=A4 + +# Landscape +pagesize=landscape(letter) + +# Custom +pagesize=(6*inch, 9*inch) +``` + +### Performance Tips + +1. **Use `drawImage()` over `drawInlineImage()`** - caches images for reuse +2. **Enable compression for large files:** `canvas.Canvas("file.pdf", pageCompression=1)` +3. **Reuse styles** - create once, use throughout document +4. 
**Use Forms/XObjects** for repeated graphics + +### Common Patterns + +**Centering text on Canvas:** +```python +text = "Centered Text" +text_width = c.stringWidth(text, "Helvetica", 12) +x = (width - text_width) / 2 +c.drawString(x, y, text) + +# Or use built-in +c.drawCentredString(width/2, y, text) +``` + +**Page breaks in Platypus:** +```python +from reportlab.platypus import PageBreak + +story.append(PageBreak()) +``` + +**Keep content together (no split):** +```python +from reportlab.platypus import KeepTogether + +story.append(KeepTogether([ + heading, + paragraph1, + paragraph2, +])) +``` + +**Alternate row colors:** +```python +style = TableStyle([ + ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]), +]) +``` + +## Troubleshooting + +**Text overlaps or disappears:** +- Check Y-coordinates - remember origin is bottom-left +- Ensure text fits within page bounds +- Verify `leading` (line spacing) is greater than `fontSize` + +**Table doesn't fit on page:** +- Reduce column widths +- Decrease font size +- Use landscape orientation +- Enable table splitting with `repeatRows` + +**Barcode not scanning:** +- Increase `barHeight` (try 0.5 inch minimum) +- Set `quiet=1` for quiet zones +- Test print quality (300+ DPI recommended) +- Validate data format for barcode type + +**Font not found:** +- Register TrueType fonts with `pdfmetrics.registerFont()` +- Use font family name exactly as registered +- Check font file path is correct + +**Images have white background:** +- Use `mask` parameter to make white transparent +- Provide RGB range to mask: `mask=[255,255,255,255,255,255]` +- Or use PNG with alpha channel + +## Example Workflows + +### Creating an Invoice + +1. Start with invoice template from `assets/invoice_template.py` +2. Customize company info, logo path +3. Add items with descriptions, quantities, prices +4. Set tax rate if applicable +5. Add notes and payment terms +6. Generate PDF + +### Creating a Report + +1. Start with report template from `assets/report_template.py` +2. Define sections with titles and content +3. Add tables for data using `create_styled_table()` +4. Add charts using graphics framework +5. Build with `doc.multiBuild(story)` for TOC + +### Creating a Certificate + +1. Use Canvas API for precise positioning +2. Load custom fonts for elegant typography +3. Add border graphics or image background +4. Position text elements (name, date, achievement) +5. Optional: Add QR code for verification + +### Creating Labels with Barcodes + +1. Use Canvas with custom page size (label dimensions) +2. Calculate grid positions for multiple labels per page +3. Draw label content (text, images) +4. Add barcode at specific position +5. 
Use `showPage()` between labels or grids + +## Installation + +```bash +pip install reportlab + +# For image support +pip install pillow + +# For charts +pip install reportlab[renderPM] + +# For barcode support (included in reportlab) +# QR codes require: pip install qrcode +``` + +## When to Use This Skill + +Invoke this skill when: +- Generating PDF documents programmatically +- Creating invoices, receipts, or billing documents +- Building reports with tables and charts +- Generating certificates, badges, or credentials +- Creating shipping labels or product labels with barcodes +- Designing forms or fillable PDFs +- Producing multi-page documents with consistent formatting +- Converting data to PDF format for archival or distribution +- Creating custom layouts that require precise positioning + +This skill provides comprehensive guidance for all ReportLab capabilities, from simple documents to complex multi-page reports with charts, tables, and interactive elements. diff --git a/scientific-packages/reportlab/assets/invoice_template.py b/scientific-packages/reportlab/assets/invoice_template.py new file mode 100644 index 0000000..9d8f592 --- /dev/null +++ b/scientific-packages/reportlab/assets/invoice_template.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Invoice Template - Complete example of a professional invoice + +This template demonstrates: +- Company header with logo placement +- Client information +- Invoice details table +- Calculations (subtotal, tax, total) +- Professional styling +- Terms and conditions footer +""" + +from reportlab.lib.pagesizes import letter +from reportlab.lib.units import inch +from reportlab.lib import colors +from reportlab.platypus import ( + SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image +) +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.lib.enums import TA_LEFT, TA_RIGHT, TA_CENTER +from datetime import datetime + + +def create_invoice( + filename, + invoice_number, + invoice_date, + due_date, + company_info, + client_info, + items, + tax_rate=0.0, + notes="", + terms="Payment due within 30 days.", + logo_path=None +): + """ + Create a professional invoice PDF. + + Args: + filename: Output PDF filename + invoice_number: Invoice number (e.g., "INV-2024-001") + invoice_date: Date of invoice (datetime or string) + due_date: Payment due date (datetime or string) + company_info: Dict with company details + {'name': 'Company Name', 'address': 'Address', 'phone': 'Phone', 'email': 'Email'} + client_info: Dict with client details (same structure as company_info) + items: List of dicts with item details + [{'description': 'Item', 'quantity': 1, 'unit_price': 100.00}, ...] + tax_rate: Tax rate as decimal (e.g., 0.08 for 8%) + notes: Additional notes to client + terms: Payment terms + logo_path: Path to company logo image (optional) + """ + # Create document + doc = SimpleDocTemplate(filename, pagesize=letter, + rightMargin=0.5*inch, leftMargin=0.5*inch, + topMargin=0.5*inch, bottomMargin=0.5*inch) + + # Container for elements + story = [] + styles = getSampleStyleSheet() + + # Create custom styles + title_style = ParagraphStyle( + 'InvoiceTitle', + parent=styles['Heading1'], + fontSize=24, + textColor=colors.HexColor('#2C3E50'), + spaceAfter=12, + ) + + header_style = ParagraphStyle( + 'Header', + parent=styles['Normal'], + fontSize=10, + textColor=colors.HexColor('#34495E'), + ) + + # --- HEADER SECTION --- + header_data = [] + + # Company info (left side) + company_text = f""" + {company_info['name']}
+    {company_info.get('address', '')}<br/>
+    Phone: {company_info.get('phone', '')}<br/>
+ Email: {company_info.get('email', '')} + """ + + # Invoice title and number (right side) + invoice_text = f""" + INVOICE<br/>
+ Invoice #: {invoice_number}<br/>
+ Date: {invoice_date}<br/>
+ Due Date: {due_date} + """ + + if logo_path: + logo = Image(logo_path, width=1.5*inch, height=1*inch) + header_data = [[logo, Paragraph(company_text, header_style), Paragraph(invoice_text, header_style)]] + header_table = Table(header_data, colWidths=[1.5*inch, 3*inch, 2.5*inch]) + else: + header_data = [[Paragraph(company_text, header_style), Paragraph(invoice_text, header_style)]] + header_table = Table(header_data, colWidths=[4.5*inch, 2.5*inch]) + + header_table.setStyle(TableStyle([ + ('VALIGN', (0, 0), (-1, -1), 'TOP'), + ('ALIGN', (-1, 0), (-1, -1), 'RIGHT'), + ])) + + story.append(header_table) + story.append(Spacer(1, 0.3*inch)) + + # --- CLIENT INFORMATION --- + client_label = Paragraph("Bill To:", header_style) + client_text = f""" + {client_info['name']}
+ {client_info.get('address', '')}<br/>
+ Phone: {client_info.get('phone', '')}<br/>
+ Email: {client_info.get('email', '')} + """ + client_para = Paragraph(client_text, header_style) + + client_table = Table([[client_label, client_para]], colWidths=[1*inch, 6*inch]) + story.append(client_table) + story.append(Spacer(1, 0.3*inch)) + + # --- ITEMS TABLE --- + # Table header + items_data = [['Description', 'Quantity', 'Unit Price', 'Amount']] + + # Calculate items + subtotal = 0 + for item in items: + desc = item['description'] + qty = item['quantity'] + price = item['unit_price'] + amount = qty * price + subtotal += amount + + items_data.append([ + desc, + str(qty), + f"${price:,.2f}", + f"${amount:,.2f}" + ]) + + # Create items table + items_table = Table(items_data, colWidths=[3.5*inch, 1*inch, 1.5*inch, 1*inch]) + + items_table.setStyle(TableStyle([ + # Header row + ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495E')), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, 0), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 11), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + + # Data rows + ('BACKGROUND', (0, 1), (-1, -1), colors.white), + ('ALIGN', (1, 1), (-1, -1), 'RIGHT'), + ('ALIGN', (0, 1), (0, -1), 'LEFT'), + ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'), + ('FONTSIZE', (0, 1), (-1, -1), 10), + ('TOPPADDING', (0, 1), (-1, -1), 6), + ('BOTTOMPADDING', (0, 1), (-1, -1), 6), + ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), + ])) + + story.append(items_table) + story.append(Spacer(1, 0.2*inch)) + + # --- TOTALS SECTION --- + tax_amount = subtotal * tax_rate + total = subtotal + tax_amount + + totals_data = [ + ['Subtotal:', f"${subtotal:,.2f}"], + ] + + if tax_rate > 0: + totals_data.append([f'Tax ({tax_rate*100:.1f}%):', f"${tax_amount:,.2f}"]) + + totals_data.append(['Total:', f"${total:,.2f}"]) + + totals_table = Table(totals_data, colWidths=[5*inch, 2*inch]) + totals_table.setStyle(TableStyle([ + ('ALIGN', (0, 0), (-1, -1), 'RIGHT'), + ('FONTNAME', (0, 0), (-1, -2), 'Helvetica'), + ('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, -1), 11), + ('TOPPADDING', (0, 0), (-1, -1), 6), + ('LINEABOVE', (1, -1), (1, -1), 2, colors.HexColor('#34495E')), + ])) + + story.append(totals_table) + story.append(Spacer(1, 0.4*inch)) + + # --- NOTES --- + if notes: + notes_style = ParagraphStyle('Notes', parent=styles['Normal'], fontSize=9) + story.append(Paragraph(f"Notes:
{notes}", notes_style)) + story.append(Spacer(1, 0.2*inch)) + + # --- TERMS --- + terms_style = ParagraphStyle('Terms', parent=styles['Normal'], + fontSize=9, textColor=colors.grey) + story.append(Paragraph(f"Payment Terms:
{terms}", terms_style)) + + # Build PDF + doc.build(story) + return filename + + +# Example usage +if __name__ == "__main__": + # Sample data + company = { + 'name': 'Acme Corporation', + 'address': '123 Business St, Suite 100\nNew York, NY 10001', + 'phone': '(555) 123-4567', + 'email': 'info@acme.com' + } + + client = { + 'name': 'John Doe', + 'address': '456 Client Ave\nLos Angeles, CA 90001', + 'phone': '(555) 987-6543', + 'email': 'john@example.com' + } + + items = [ + {'description': 'Web Design Services', 'quantity': 1, 'unit_price': 2500.00}, + {'description': 'Content Writing (10 pages)', 'quantity': 10, 'unit_price': 50.00}, + {'description': 'SEO Optimization', 'quantity': 1, 'unit_price': 750.00}, + {'description': 'Hosting Setup', 'quantity': 1, 'unit_price': 200.00}, + ] + + create_invoice( + filename="sample_invoice.pdf", + invoice_number="INV-2024-001", + invoice_date="January 15, 2024", + due_date="February 15, 2024", + company_info=company, + client_info=client, + items=items, + tax_rate=0.08, + notes="Thank you for your business! We appreciate your prompt payment.", + terms="Payment due within 30 days. Late payments subject to 1.5% monthly fee.", + logo_path=None # Set to your logo path if available + ) + + print("Invoice created: sample_invoice.pdf") diff --git a/scientific-packages/reportlab/assets/report_template.py b/scientific-packages/reportlab/assets/report_template.py new file mode 100644 index 0000000..9f50b3c --- /dev/null +++ b/scientific-packages/reportlab/assets/report_template.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +""" +Report Template - Complete example of a professional multi-page report + +This template demonstrates: +- Cover page +- Table of contents +- Multiple sections with headers +- Charts and graphs integration +- Tables with data +- Headers and footers +- Professional styling +""" + +from reportlab.lib.pagesizes import letter +from reportlab.lib.units import inch +from reportlab.lib import colors +from reportlab.platypus import ( + BaseDocTemplate, PageTemplate, Frame, Paragraph, Spacer, + Table, TableStyle, PageBreak, KeepTogether, TableOfContents +) +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.lib.enums import TA_LEFT, TA_RIGHT, TA_CENTER, TA_JUSTIFY +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.charts.barcharts import VerticalBarChart +from reportlab.graphics.charts.linecharts import HorizontalLineChart +from datetime import datetime + + +def header_footer(canvas, doc): + """Draw header and footer on each page (except cover)""" + canvas.saveState() + + # Skip header/footer on cover page (page 1) + if doc.page > 1: + # Header + canvas.setFont('Helvetica', 9) + canvas.setFillColor(colors.grey) + canvas.drawString(inch, letter[1] - 0.5*inch, "Quarterly Business Report") + canvas.line(inch, letter[1] - 0.55*inch, letter[0] - inch, letter[1] - 0.55*inch) + + # Footer + canvas.drawString(inch, 0.5*inch, f"Generated: {datetime.now().strftime('%B %d, %Y')}") + canvas.drawRightString(letter[0] - inch, 0.5*inch, f"Page {doc.page - 1}") + + canvas.restoreState() + + +def create_report(filename, report_data): + """ + Create a comprehensive business report. 
+ + Args: + filename: Output PDF filename + report_data: Dict containing report information + { + 'title': 'Report Title', + 'subtitle': 'Report Subtitle', + 'author': 'Author Name', + 'date': 'Date', + 'sections': [ + { + 'title': 'Section Title', + 'content': 'Section content...', + 'subsections': [...], + 'table': {...}, + 'chart': {...} + }, + ... + ] + } + """ + # Create document with custom page template + doc = BaseDocTemplate(filename, pagesize=letter, + rightMargin=72, leftMargin=72, + topMargin=inch, bottomMargin=inch) + + # Define frame for content + frame = Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height - 0.5*inch, id='normal') + + # Create page template with header/footer + template = PageTemplate(id='normal', frames=[frame], onPage=header_footer) + doc.addPageTemplates([template]) + + # Get styles + styles = getSampleStyleSheet() + + # Custom styles + title_style = ParagraphStyle( + 'ReportTitle', + parent=styles['Title'], + fontSize=28, + textColor=colors.HexColor('#2C3E50'), + spaceAfter=20, + alignment=TA_CENTER, + ) + + subtitle_style = ParagraphStyle( + 'ReportSubtitle', + parent=styles['Normal'], + fontSize=14, + textColor=colors.grey, + alignment=TA_CENTER, + spaceAfter=30, + ) + + heading1_style = ParagraphStyle( + 'CustomHeading1', + parent=styles['Heading1'], + fontSize=18, + textColor=colors.HexColor('#2C3E50'), + spaceAfter=12, + spaceBefore=12, + ) + + heading2_style = ParagraphStyle( + 'CustomHeading2', + parent=styles['Heading2'], + fontSize=14, + textColor=colors.HexColor('#34495E'), + spaceAfter=10, + spaceBefore=10, + ) + + body_style = ParagraphStyle( + 'ReportBody', + parent=styles['BodyText'], + fontSize=11, + alignment=TA_JUSTIFY, + spaceAfter=12, + leading=14, + ) + + # Build story + story = [] + + # --- COVER PAGE --- + story.append(Spacer(1, 2*inch)) + story.append(Paragraph(report_data['title'], title_style)) + story.append(Paragraph(report_data.get('subtitle', ''), subtitle_style)) + story.append(Spacer(1, inch)) + + # Cover info table + cover_info = [ + ['Prepared by:', report_data.get('author', '')], + ['Date:', report_data.get('date', datetime.now().strftime('%B %d, %Y'))], + ['Period:', report_data.get('period', 'Q4 2023')], + ] + + cover_table = Table(cover_info, colWidths=[2*inch, 4*inch]) + cover_table.setStyle(TableStyle([ + ('ALIGN', (0, 0), (0, -1), 'RIGHT'), + ('ALIGN', (1, 0), (1, -1), 'LEFT'), + ('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, -1), 11), + ('TOPPADDING', (0, 0), (-1, -1), 6), + ])) + + story.append(cover_table) + story.append(PageBreak()) + + # --- TABLE OF CONTENTS --- + toc = TableOfContents() + toc.levelStyles = [ + ParagraphStyle(name='TOCHeading1', fontSize=14, leftIndent=20, spaceBefore=10, spaceAfter=5), + ParagraphStyle(name='TOCHeading2', fontSize=12, leftIndent=40, spaceBefore=3, spaceAfter=3), + ] + + story.append(Paragraph("Table of Contents", heading1_style)) + story.append(toc) + story.append(PageBreak()) + + # --- SECTIONS --- + for section in report_data.get('sections', []): + # Section heading + section_title = section['title'] + story.append(Paragraph(f'{section_title}', heading1_style)) + + # Add to TOC + toc.addEntry(0, section_title, doc.page) + + # Section content + if 'content' in section: + for para in section['content'].split('\n\n'): + if para.strip(): + story.append(Paragraph(para.strip(), body_style)) + + story.append(Spacer(1, 0.2*inch)) + + # Subsections + for subsection in section.get('subsections', []): + 
story.append(Paragraph(subsection['title'], heading2_style)) + + if 'content' in subsection: + story.append(Paragraph(subsection['content'], body_style)) + + story.append(Spacer(1, 0.1*inch)) + + # Add table if provided + if 'table_data' in section: + table = create_section_table(section['table_data']) + story.append(table) + story.append(Spacer(1, 0.2*inch)) + + # Add chart if provided + if 'chart_data' in section: + chart = create_section_chart(section['chart_data']) + story.append(chart) + story.append(Spacer(1, 0.2*inch)) + + story.append(Spacer(1, 0.3*inch)) + + # Build PDF (twice for TOC to populate) + doc.multiBuild(story) + return filename + + +def create_section_table(table_data): + """Create a styled table for report sections""" + data = table_data['data'] + table = Table(data, colWidths=table_data.get('colWidths')) + + table.setStyle(TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495E')), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 11), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('BACKGROUND', (0, 1), (-1, -1), colors.white), + ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]), + ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), + ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'), + ('FONTSIZE', (0, 1), (-1, -1), 10), + ])) + + return table + + +def create_section_chart(chart_data): + """Create a chart for report sections""" + chart_type = chart_data.get('type', 'bar') + drawing = Drawing(400, 200) + + if chart_type == 'bar': + chart = VerticalBarChart() + chart.x = 50 + chart.y = 30 + chart.width = 300 + chart.height = 150 + chart.data = chart_data['data'] + chart.categoryAxis.categoryNames = chart_data.get('categories', []) + chart.valueAxis.valueMin = 0 + + # Style bars + for i in range(len(chart_data['data'])): + chart.bars[i].fillColor = colors.HexColor(['#3498db', '#e74c3c', '#2ecc71'][i % 3]) + + elif chart_type == 'line': + chart = HorizontalLineChart() + chart.x = 50 + chart.y = 30 + chart.width = 300 + chart.height = 150 + chart.data = chart_data['data'] + chart.categoryAxis.categoryNames = chart_data.get('categories', []) + + # Style lines + for i in range(len(chart_data['data'])): + chart.lines[i].strokeColor = colors.HexColor(['#3498db', '#e74c3c', '#2ecc71'][i % 3]) + chart.lines[i].strokeWidth = 2 + + drawing.add(chart) + return drawing + + +# Example usage +if __name__ == "__main__": + report = { + 'title': 'Quarterly Business Report', + 'subtitle': 'Q4 2023 Performance Analysis', + 'author': 'Analytics Team', + 'date': 'January 15, 2024', + 'period': 'October - December 2023', + 'sections': [ + { + 'title': 'Executive Summary', + 'content': """ + This report provides a comprehensive analysis of our Q4 2023 performance. + Overall, the quarter showed strong growth across all key metrics, with + revenue increasing by 25% year-over-year and customer satisfaction + scores reaching an all-time high of 4.8/5.0. + + Key highlights include the successful launch of three new products, + expansion into two new markets, and the completion of our digital + transformation initiative. + """, + 'subsections': [ + { + 'title': 'Key Achievements', + 'content': 'Successfully launched Product X with 10,000 units sold in first month.' + } + ] + }, + { + 'title': 'Financial Performance', + 'content': """ + The financial results for Q4 exceeded expectations across all categories. 
+ Revenue growth was driven primarily by strong product sales and increased + market share in key regions. + """, + 'table_data': { + 'data': [ + ['Metric', 'Q3 2023', 'Q4 2023', 'Change'], + ['Revenue', '$2.5M', '$3.1M', '+24%'], + ['Profit', '$500K', '$680K', '+36%'], + ['Expenses', '$2.0M', '$2.4M', '+20%'], + ], + 'colWidths': [2*inch, 1.5*inch, 1.5*inch, 1*inch] + }, + 'chart_data': { + 'type': 'bar', + 'data': [[2.5, 3.1], [0.5, 0.68], [2.0, 2.4]], + 'categories': ['Q3', 'Q4'] + } + }, + { + 'title': 'Market Analysis', + 'content': """ + Market conditions remained favorable throughout the quarter, with + strong consumer confidence and increasing demand for our products. + """, + 'chart_data': { + 'type': 'line', + 'data': [[100, 120, 115, 140, 135, 150]], + 'categories': ['Oct', 'Nov', 'Dec', 'Oct', 'Nov', 'Dec'] + } + }, + ] + } + + create_report("sample_report.pdf", report) + print("Report created: sample_report.pdf") diff --git a/scientific-packages/reportlab/references/barcodes_reference.md b/scientific-packages/reportlab/references/barcodes_reference.md new file mode 100644 index 0000000..72376ad --- /dev/null +++ b/scientific-packages/reportlab/references/barcodes_reference.md @@ -0,0 +1,504 @@ +# Barcodes Reference + +Comprehensive guide to creating barcodes and QR codes in ReportLab. + +## Available Barcode Types + +ReportLab supports a wide range of 1D and 2D barcode formats. + +### 1D Barcodes (Linear) + +- **Code128** - Compact, encodes full ASCII +- **Code39** (Standard39) - Alphanumeric, widely supported +- **Code93** (Standard93) - Compressed Code39 +- **EAN-13** - European Article Number (retail) +- **EAN-8** - Short form of EAN +- **EAN-5** - 5-digit add-on (pricing) +- **UPC-A** - Universal Product Code (North America) +- **ISBN** - International Standard Book Number +- **Code11** - Telecommunications +- **Codabar** - Blood banks, FedEx, libraries +- **I2of5** (Interleaved 2 of 5) - Warehouse/distribution +- **MSI** - Inventory control +- **POSTNET** - US Postal Service +- **USPS_4State** - US Postal Service +- **FIM** (A, B, C, D) - Facing Identification Mark (mail sorting) + +### 2D Barcodes + +- **QR** - QR Code (widely used for URLs, contact info) +- **ECC200DataMatrix** - Data Matrix format + +## Using Barcodes with Canvas + +### Code128 (Recommended for General Use) + +Code128 is versatile and compact - encodes full ASCII character set with mandatory checksum. + +```python +from reportlab.pdfgen import canvas +from reportlab.graphics.barcode import code128 +from reportlab.lib.units import inch + +c = canvas.Canvas("barcode.pdf") + +# Create barcode +barcode = code128.Code128("HELLO123") + +# Draw on canvas +barcode.drawOn(c, 1*inch, 5*inch) + +c.save() +``` + +### Code128 Options + +```python +barcode = code128.Code128( + value="ABC123", # Required: data to encode + barWidth=0.01*inch, # Width of narrowest bar + barHeight=0.5*inch, # Height of bars + quiet=1, # Add quiet zones (margins) + lquiet=None, # Left quiet zone width + rquiet=None, # Right quiet zone width + stop=1, # Show stop symbol +) + +# Draw with specific size +barcode.drawOn(canvas, x, y) + +# Get dimensions +width = barcode.width +height = barcode.height +``` + +### Code39 (Standard39) + +Supports: 0-9, A-Z (uppercase), space, and special chars (-.$/+%*). 
+ +```python +from reportlab.graphics.barcode import code39 + +barcode = code39.Standard39( + value="HELLO", + barWidth=0.01*inch, + barHeight=0.5*inch, + quiet=1, + checksum=0, # 0 or 1 +) + +barcode.drawOn(canvas, x, y) +``` + +### Extended Code39 + +Encodes full ASCII (pairs of Code39 characters). + +```python +from reportlab.graphics.barcode import code39 + +barcode = code39.Extended39( + value="Hello World!", # Can include lowercase and symbols + barWidth=0.01*inch, + barHeight=0.5*inch, +) + +barcode.drawOn(canvas, x, y) +``` + +### Code93 + +```python +from reportlab.graphics.barcode import code93 + +# Standard93 - uppercase, digits, some symbols +barcode = code93.Standard93( + value="HELLO93", + barWidth=0.01*inch, + barHeight=0.5*inch, +) + +# Extended93 - full ASCII +barcode = code93.Extended93( + value="Hello 93!", + barWidth=0.01*inch, + barHeight=0.5*inch, +) + +barcode.drawOn(canvas, x, y) +``` + +### EAN-13 (European Article Number) + +13-digit barcode for retail products. + +```python +from reportlab.graphics.barcode import eanbc + +# Must be exactly 12 digits (13th is calculated checksum) +barcode = eanbc.Ean13BarcodeWidget( + value="123456789012" +) + +# Draw +from reportlab.graphics import renderPDF +from reportlab.graphics.shapes import Drawing + +d = Drawing() +d.add(barcode) +renderPDF.draw(d, canvas, x, y) +``` + +### EAN-8 + +Short form, 8 digits. + +```python +from reportlab.graphics.barcode import eanbc + +# Must be exactly 7 digits (8th is calculated) +barcode = eanbc.Ean8BarcodeWidget( + value="1234567" +) +``` + +### UPC-A + +12-digit barcode used in North America. + +```python +from reportlab.graphics.barcode import usps + +# 11 digits (12th is checksum) +barcode = usps.UPCA( + value="01234567890" +) + +barcode.drawOn(canvas, x, y) +``` + +### ISBN (Books) + +```python +from reportlab.graphics.barcode.widgets import ISBNBarcodeWidget + +# 10 or 13 digit ISBN +barcode = ISBNBarcodeWidget( + value="978-0-123456-78-9" +) + +# With pricing (EAN-5 add-on) +barcode = ISBNBarcodeWidget( + value="978-0-123456-78-9", + price=True, +) +``` + +### QR Codes + +Most versatile 2D barcode - can encode URLs, text, contact info, etc. + +```python +from reportlab.graphics.barcode.qr import QrCodeWidget +from reportlab.graphics.shapes import Drawing +from reportlab.graphics import renderPDF + +# Create QR code +qr = QrCodeWidget("https://example.com") + +# Size in pixels (QR codes are square) +qr.barWidth = 100 # Width in points +qr.barHeight = 100 # Height in points + +# Error correction level +# L = 7% recovery, M = 15%, Q = 25%, H = 30% +qr.qrVersion = 1 # Auto-size (1-40, or None for auto) +qr.errorLevel = 'M' # L, M, Q, H + +# Draw +d = Drawing() +d.add(qr) +renderPDF.draw(d, canvas, x, y) +``` + +### QR Code - More Options + +```python +# URL QR Code +qr = QrCodeWidget("https://example.com") + +# Contact information (vCard) +vcard_data = """BEGIN:VCARD +VERSION:3.0 +FN:John Doe +TEL:+1-555-1234 +EMAIL:john@example.com +END:VCARD""" +qr = QrCodeWidget(vcard_data) + +# WiFi credentials +wifi_data = "WIFI:T:WPA;S:NetworkName;P:Password;;" +qr = QrCodeWidget(wifi_data) + +# Plain text +qr = QrCodeWidget("Any text here") +``` + +### Data Matrix (ECC200) + +Compact 2D barcode for small items. 
+ +```python +from reportlab.graphics.barcode.datamatrix import DataMatrixWidget + +barcode = DataMatrixWidget( + value="DATA123" +) + +d = Drawing() +d.add(barcode) +renderPDF.draw(d, canvas, x, y) +``` + +### Postal Barcodes + +```python +from reportlab.graphics.barcode import usps + +# POSTNET (older format) +barcode = usps.POSTNET( + value="55555-1234", # ZIP or ZIP+4 +) + +# USPS 4-State (newer) +barcode = usps.USPS_4State( + value="12345678901234567890", # 20-digit routing code + routing="12345678901" +) + +barcode.drawOn(canvas, x, y) +``` + +### FIM (Facing Identification Mark) + +Used for mail sorting. + +```python +from reportlab.graphics.barcode import usps + +# FIM-A, FIM-B, FIM-C, or FIM-D +barcode = usps.FIM( + value="A" # A, B, C, or D +) + +barcode.drawOn(canvas, x, y) +``` + +## Using Barcodes with Platypus + +For flowing documents, wrap barcodes in Flowables. + +### Simple Approach - Drawing Flowable + +```python +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.barcode.qr import QrCodeWidget +from reportlab.lib.units import inch + +# Create drawing +d = Drawing(2*inch, 2*inch) + +# Create barcode +qr = QrCodeWidget("https://example.com") +qr.barWidth = 2*inch +qr.barHeight = 2*inch +qr.x = 0 +qr.y = 0 + +d.add(qr) + +# Add to story +story.append(d) +``` + +### Custom Flowable Wrapper + +```python +from reportlab.platypus import Flowable +from reportlab.graphics.barcode import code128 +from reportlab.lib.units import inch + +class BarcodeFlowable(Flowable): + def __init__(self, code, barcode_type='code128', width=2*inch, height=0.5*inch): + Flowable.__init__(self) + self.code = code + self.barcode_type = barcode_type + self.width_val = width + self.height_val = height + + # Create barcode + if barcode_type == 'code128': + self.barcode = code128.Code128(code, barWidth=width/100, barHeight=height) + # Add other types as needed + + def draw(self): + self.barcode.drawOn(self.canv, 0, 0) + + def wrap(self, availWidth, availHeight): + return (self.barcode.width, self.barcode.height) + +# Use in story +story.append(BarcodeFlowable("PRODUCT123")) +``` + +## Complete Examples + +### Product Label with Barcode + +```python +from reportlab.pdfgen import canvas +from reportlab.graphics.barcode import code128 +from reportlab.lib.pagesizes import letter +from reportlab.lib.units import inch + +def create_product_label(filename, product_code, product_name): + c = canvas.Canvas(filename, pagesize=(4*inch, 2*inch)) + + # Product name + c.setFont("Helvetica-Bold", 14) + c.drawCentredString(2*inch, 1.5*inch, product_name) + + # Barcode + barcode = code128.Code128(product_code) + barcode_width = barcode.width + barcode_height = barcode.height + + # Center barcode + x = (4*inch - barcode_width) / 2 + y = 0.5*inch + + barcode.drawOn(c, x, y) + + # Code text + c.setFont("Courier", 10) + c.drawCentredString(2*inch, 0.3*inch, product_code) + + c.save() + +create_product_label("label.pdf", "ABC123456789", "Premium Widget") +``` + +### QR Code Contact Card + +```python +from reportlab.pdfgen import canvas +from reportlab.graphics.barcode.qr import QrCodeWidget +from reportlab.graphics.shapes import Drawing +from reportlab.graphics import renderPDF +from reportlab.lib.units import inch + +def create_contact_card(filename, name, phone, email): + c = canvas.Canvas(filename, pagesize=(3.5*inch, 2*inch)) + + # Contact info + c.setFont("Helvetica-Bold", 12) + c.drawString(0.5*inch, 1.5*inch, name) + c.setFont("Helvetica", 10) + c.drawString(0.5*inch, 1.3*inch, phone) + 
c.drawString(0.5*inch, 1.1*inch, email) + + # Create vCard data + vcard = f"""BEGIN:VCARD +VERSION:3.0 +FN:{name} +TEL:{phone} +EMAIL:{email} +END:VCARD""" + + # QR code + qr = QrCodeWidget(vcard) + qr.barWidth = 1.5*inch + qr.barHeight = 1.5*inch + + d = Drawing() + d.add(qr) + + renderPDF.draw(d, c, 1.8*inch, 0.2*inch) + + c.save() + +create_contact_card("contact.pdf", "John Doe", "+1-555-1234", "john@example.com") +``` + +### Shipping Label with Multiple Barcodes + +```python +from reportlab.pdfgen import canvas +from reportlab.graphics.barcode import code128 +from reportlab.lib.units import inch + +def create_shipping_label(filename, tracking_code, zip_code): + c = canvas.Canvas(filename, pagesize=(6*inch, 4*inch)) + + # Title + c.setFont("Helvetica-Bold", 16) + c.drawString(0.5*inch, 3.5*inch, "SHIPPING LABEL") + + # Tracking barcode + c.setFont("Helvetica", 10) + c.drawString(0.5*inch, 2.8*inch, "Tracking Number:") + + tracking_barcode = code128.Code128(tracking_code, barHeight=0.5*inch) + tracking_barcode.drawOn(c, 0.5*inch, 2*inch) + + c.setFont("Courier", 9) + c.drawString(0.5*inch, 1.8*inch, tracking_code) + + # Additional info can be added + + c.save() + +create_shipping_label("shipping.pdf", "1Z999AA10123456784", "12345") +``` + +## Barcode Selection Guide + +**Choose Code128 when:** +- General purpose encoding +- Need to encode numbers and letters +- Want compact size +- Widely supported + +**Choose Code39 when:** +- Older systems require it +- Don't need lowercase letters +- Want maximum compatibility + +**Choose QR Code when:** +- Need to encode URLs +- Want mobile device scanning +- Need high data capacity +- Want error correction + +**Choose EAN/UPC when:** +- Retail product identification +- Need industry-standard format +- Global distribution + +**Choose Data Matrix when:** +- Very limited space +- Small items (PCB, electronics) +- Need 2D compact format + +## Best Practices + +1. **Test scanning** early with actual barcode scanners/readers +2. **Add quiet zones** (white space) around barcodes - set `quiet=1` +3. **Choose appropriate height** - taller barcodes are easier to scan +4. **Include human-readable text** below barcode for manual entry +5. **Use Code128** as default for general purpose - it's compact and versatile +6. **For URLs, use QR codes** - much easier for mobile users +7. **Check barcode standards** for your industry (retail uses EAN/UPC) +8. **Test print quality** - low DPI can make barcodes unscannable +9. **Validate data** before encoding - wrong check digits cause issues +10. **Consider error correction** for QR codes - use 'M' or 'H' for important data diff --git a/scientific-packages/reportlab/references/canvas_api.md b/scientific-packages/reportlab/references/canvas_api.md new file mode 100644 index 0000000..a930651 --- /dev/null +++ b/scientific-packages/reportlab/references/canvas_api.md @@ -0,0 +1,241 @@ +# Canvas API Reference + +The Canvas API provides low-level, precise control over PDF generation using coordinate-based drawing. 
+ +## Coordinate System + +- Origin (0, 0) is at the **lower-left corner** (not top-left like web graphics) +- X-axis points right, Y-axis points upward +- Units are in points (72 points = 1 inch) +- Default page size is A4; explicitly specify page size for consistency + +## Basic Setup + +```python +from reportlab.pdfgen import canvas +from reportlab.lib.pagesizes import letter, A4 +from reportlab.lib.units import inch + +# Create canvas +c = canvas.Canvas("output.pdf", pagesize=letter) + +# Get page dimensions +width, height = letter + +# Draw content +c.drawString(100, 100, "Hello World") + +# Finish page and save +c.showPage() # Complete current page +c.save() # Write PDF to disk +``` + +## Text Drawing + +### Basic String Methods +```python +# Basic text placement +c.drawString(x, y, text) # Left-aligned at x, y +c.drawRightString(x, y, text) # Right-aligned at x, y +c.drawCentredString(x, y, text) # Center-aligned at x, y + +# Font control +c.setFont(fontname, size) # e.g., "Helvetica", 12 +c.setFillColor(color) # Text color +``` + +### Text Objects (Advanced) +For complex text operations with multiple lines and precise control: + +```python +t = c.beginText(x, y) +t.setFont("Times-Roman", 14) +t.textLine("First line") +t.textLine("Second line") +t.setTextOrigin(x, y) # Reset position +c.drawText(t) +``` + +## Drawing Primitives + +### Lines +```python +c.line(x1, y1, x2, y2) # Single line +c.lines([(x1,y1,x2,y2), (x3,y3,x4,y4)]) # Multiple lines +c.grid(xlist, ylist) # Grid from coordinate lists +``` + +### Shapes +```python +c.rect(x, y, width, height, stroke=1, fill=0) +c.roundRect(x, y, width, height, radius, stroke=1, fill=0) +c.circle(x_ctr, y_ctr, r, stroke=1, fill=0) +c.ellipse(x1, y1, x2, y2, stroke=1, fill=0) +c.wedge(x, y, radius, startAng, extent, stroke=1, fill=0) +``` + +### Bezier Curves +```python +c.bezier(x1, y1, x2, y2, x3, y3, x4, y4) +``` + +## Path Objects + +For complex shapes, use path objects: + +```python +p = c.beginPath() +p.moveTo(x, y) # Move without drawing +p.lineTo(x, y) # Draw line to point +p.curveTo(x1, y1, x2, y2, x3, y3) # Bezier curve +p.arc(x1, y1, x2, y2, startAng, extent) +p.arcTo(x1, y1, x2, y2, startAng, extent) +p.close() # Close path to start point + +# Draw the path +c.drawPath(p, stroke=1, fill=0) +``` + +## Colors + +### RGB (Screen Display) +```python +from reportlab.lib.colors import red, blue, Color + +c.setFillColorRGB(r, g, b) # r, g, b are 0-1 +c.setStrokeColorRGB(r, g, b) +c.setFillColor(red) # Named colors +c.setStrokeColor(blue) + +# Custom with alpha transparency +c.setFillColor(Color(0.5, 0, 0, alpha=0.5)) +``` + +### CMYK (Professional Printing) +```python +from reportlab.lib.colors import CMYKColor, PCMYKColor + +c.setFillColorCMYK(c, m, y, k) # 0-1 range +c.setStrokeColorCMYK(c, m, y, k) + +# Integer percentages (0-100) +c.setFillColor(PCMYKColor(100, 50, 0, 0)) +``` + +## Line Styling + +```python +c.setLineWidth(width) # Thickness in points +c.setLineCap(mode) # 0=butt, 1=round, 2=square +c.setLineJoin(mode) # 0=miter, 1=round, 2=bevel +c.setDash(array, phase) # e.g., [3, 3] for dotted line +``` + +## Coordinate Transformations + +**IMPORTANT:** Transformations are incremental and cumulative. + +```python +# Translation (move origin) +c.translate(dx, dy) + +# Rotation (in degrees, counterclockwise) +c.rotate(theta) + +# Scaling +c.scale(xscale, yscale) + +# Skewing +c.skew(alpha, beta) +``` + +### State Management +```python +# Save current graphics state +c.saveState() + +# ... apply transformations and draw ... 
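+# An illustrative sketch of what might go here (values in points, chosen arbitrarily):
+c.translate(144, 360)   # move the origin 2 inches right and 5 inches up
+c.rotate(45)            # rotate subsequent drawing 45 degrees counterclockwise
+c.drawString(0, 0, "Rotated label")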
+ +# Restore previous state +c.restoreState() +``` + +**Note:** State cannot be preserved across `showPage()` calls. + +## Images + +```python +from reportlab.lib.utils import ImageReader + +# Preferred method (with caching) +c.drawImage(image_source, x, y, width=None, height=None, + mask=None, preserveAspectRatio=False) + +# image_source can be: +# - Filename string +# - PIL Image object +# - ImageReader object + +# For transparency, specify RGB mask range +c.drawImage("logo.png", 100, 500, mask=[255, 255, 255, 255, 255, 255]) + +# Inline (inefficient, no caching) +c.drawInlineImage(image_source, x, y, width=None, height=None) +``` + +## Page Management + +```python +# Complete current page +c.showPage() + +# Set page size for next page +c.setPageSize(size) # e.g., letter, A4 + +# Page compression (smaller files, slower generation) +c = canvas.Canvas("output.pdf", pageCompression=1) +``` + +## Common Patterns + +### Margins and Layout +```python +from reportlab.lib.units import inch +from reportlab.lib.pagesizes import letter + +width, height = letter +margin = inch + +# Draw within margins +content_width = width - 2*margin +content_height = height - 2*margin + +# Text at top margin +c.drawString(margin, height - margin, "Header") + +# Text at bottom margin +c.drawString(margin, margin, "Footer") +``` + +### Headers and Footers +```python +def draw_header_footer(c, width, height): + c.saveState() + c.setFont("Helvetica", 9) + c.drawString(inch, height - 0.5*inch, "Company Name") + c.drawRightString(width - inch, 0.5*inch, f"Page {c.getPageNumber()}") + c.restoreState() + +# Call on each page +draw_header_footer(c, width, height) +c.showPage() +``` + +## Best Practices + +1. **Always specify page size** - Different platforms have different defaults +2. **Use variables for measurements** - `margin = inch` instead of hardcoded values +3. **Match saveState/restoreState** - Always balance these calls +4. **Apply transformations externally** for engineering drawings to prevent line width scaling +5. **Use drawImage over drawInlineImage** for better performance with repeated images +6. **Draw from bottom-up** - Remember Y-axis points upward diff --git a/scientific-packages/reportlab/references/charts_reference.md b/scientific-packages/reportlab/references/charts_reference.md new file mode 100644 index 0000000..d33981f --- /dev/null +++ b/scientific-packages/reportlab/references/charts_reference.md @@ -0,0 +1,624 @@ +# Charts and Graphics Reference + +Comprehensive guide to creating charts and data visualizations in ReportLab. 
+ +## Graphics Architecture + +ReportLab's graphics system provides platform-independent drawing: + +- **Drawings** - Container for shapes and charts +- **Shapes** - Primitives (rectangles, circles, lines, polygons, paths) +- **Renderers** - Convert to PDF, PostScript, SVG, or bitmaps (PNG, GIF, JPG) +- **Coordinate System** - Y-axis points upward (like PDF, unlike web graphics) + +## Quick Start + +```python +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.charts.barcharts import VerticalBarChart +from reportlab.graphics import renderPDF + +# Create drawing (canvas for chart) +drawing = Drawing(400, 200) + +# Create chart +chart = VerticalBarChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 125 +chart.data = [[100, 150, 130, 180]] +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Add chart to drawing +drawing.add(chart) + +# Render to PDF +renderPDF.drawToFile(drawing, 'chart.pdf', 'Chart Title') + +# Or add as flowable to Platypus document +story.append(drawing) +``` + +## Available Chart Types + +### Bar Charts + +```python +from reportlab.graphics.charts.barcharts import ( + VerticalBarChart, + HorizontalBarChart, +) + +# Vertical bar chart +chart = VerticalBarChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Single series +chart.data = [[100, 150, 130, 180, 140]] + +# Multiple series (grouped bars) +chart.data = [ + [100, 150, 130, 180], # Series 1 + [80, 120, 110, 160], # Series 2 +] + +# Categories +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Colors for each series +chart.bars[0].fillColor = colors.blue +chart.bars[1].fillColor = colors.red + +# Bar spacing +chart.barWidth = 10 +chart.groupSpacing = 10 +chart.barSpacing = 2 +``` + +### Stacked Bar Charts + +```python +from reportlab.graphics.charts.barcharts import VerticalBarChart + +chart = VerticalBarChart() +# ... set position and size ... 
+ +chart.data = [ + [100, 150, 130, 180], # Bottom layer + [50, 70, 60, 90], # Top layer +] +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Enable stacking +chart.barLabelFormat = 'values' +chart.valueAxis.visible = 1 +``` + +### Horizontal Bar Charts + +```python +from reportlab.graphics.charts.barcharts import HorizontalBarChart + +chart = HorizontalBarChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +chart.data = [[100, 150, 130, 180]] +chart.categoryAxis.categoryNames = ['Product A', 'Product B', 'Product C', 'Product D'] + +# Horizontal charts use valueAxis horizontally +chart.valueAxis.valueMin = 0 +chart.valueAxis.valueMax = 200 +``` + +### Line Charts + +```python +from reportlab.graphics.charts.linecharts import HorizontalLineChart + +chart = HorizontalLineChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Multiple lines +chart.data = [ + [100, 150, 130, 180, 140], # Line 1 + [80, 120, 110, 160, 130], # Line 2 +] + +chart.categoryAxis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr', 'May'] + +# Line styling +chart.lines[0].strokeColor = colors.blue +chart.lines[0].strokeWidth = 2 +chart.lines[1].strokeColor = colors.red +chart.lines[1].strokeWidth = 2 + +# Show/hide points +chart.lines[0].symbol = None # No symbols +# Or use symbols from makeMarker() +``` + +### Line Plots (X-Y Plots) + +```python +from reportlab.graphics.charts.lineplots import LinePlot + +chart = LinePlot() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Data as (x, y) tuples +chart.data = [ + [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)], # y = x^2 + [(0, 0), (1, 2), (2, 4), (3, 6), (4, 8)], # y = 2x +] + +# Both axes are value axes (not category) +chart.xValueAxis.valueMin = 0 +chart.xValueAxis.valueMax = 5 +chart.yValueAxis.valueMin = 0 +chart.yValueAxis.valueMax = 20 + +# Line styling +chart.lines[0].strokeColor = colors.blue +chart.lines[1].strokeColor = colors.red +``` + +### Pie Charts + +```python +from reportlab.graphics.charts.piecharts import Pie + +chart = Pie() +chart.x = 100 +chart.y = 50 +chart.width = 200 +chart.height = 200 + +chart.data = [25, 35, 20, 20] +chart.labels = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Slice colors +chart.slices[0].fillColor = colors.blue +chart.slices[1].fillColor = colors.red +chart.slices[2].fillColor = colors.green +chart.slices[3].fillColor = colors.yellow + +# Pop out a slice +chart.slices[1].popout = 10 + +# Label positioning +chart.slices.strokeColor = colors.white +chart.slices.strokeWidth = 2 +``` + +### Pie Chart with Side Labels + +```python +from reportlab.graphics.charts.piecharts import Pie + +chart = Pie() +# ... set position, data, labels ... 
+ +# Side label mode (labels in columns beside pie) +chart.sideLabels = 1 +chart.sideLabelsOffset = 0.1 # Distance from pie + +# Simple labels (not fancy layout) +chart.simpleLabels = 1 +``` + +### Area Charts + +```python +from reportlab.graphics.charts.areacharts import HorizontalAreaChart + +chart = HorizontalAreaChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Areas stack on top of each other +chart.data = [ + [100, 150, 130, 180], # Bottom area + [50, 70, 60, 90], # Top area +] + +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Area colors +chart.strands[0].fillColor = colors.lightblue +chart.strands[1].fillColor = colors.pink +``` + +### Scatter Charts + +```python +from reportlab.graphics.charts.lineplots import ScatterPlot + +chart = ScatterPlot() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Data points +chart.data = [ + [(1, 2), (2, 3), (3, 5), (4, 4), (5, 6)], # Series 1 + [(1, 1), (2, 2), (3, 3), (4, 3), (5, 4)], # Series 2 +] + +# Hide lines, show points only +chart.lines[0].strokeColor = None +chart.lines[1].strokeColor = None + +# Marker symbols +from reportlab.graphics.widgets.markers import makeMarker +chart.lines[0].symbol = makeMarker('Circle') +chart.lines[1].symbol = makeMarker('Square') +``` + +## Axes Configuration + +### Category Axis (XCategoryAxis) + +For categorical data (labels, not numbers): + +```python +# Access via chart +axis = chart.categoryAxis + +# Labels +axis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr'] + +# Label angle (for long labels) +axis.labels.angle = 45 +axis.labels.dx = 0 +axis.labels.dy = -5 + +# Label formatting +axis.labels.fontSize = 10 +axis.labels.fontName = 'Helvetica' + +# Visibility +axis.visible = 1 +``` + +### Value Axis (YValueAxis) + +For numeric data: + +```python +# Access via chart +axis = chart.valueAxis + +# Range +axis.valueMin = 0 +axis.valueMax = 200 +axis.valueStep = 50 # Tick interval + +# Or auto-configure +axis.valueSteps = [0, 50, 100, 150, 200] # Explicit steps + +# Label formatting +axis.labels.fontSize = 10 +axis.labelTextFormat = '%d%%' # Add percentage sign + +# Grid lines +axis.strokeWidth = 1 +axis.strokeColor = colors.black +``` + +## Styling and Customization + +### Colors + +```python +from reportlab.lib import colors + +# Named colors +colors.blue, colors.red, colors.green, colors.yellow + +# RGB +colors.Color(0.5, 0.5, 0.5) # Grey + +# With alpha +colors.Color(1, 0, 0, alpha=0.5) # Semi-transparent red + +# Hex colors +colors.HexColor('#FF5733') +``` + +### Line Styling + +```python +# For line charts +chart.lines[0].strokeColor = colors.blue +chart.lines[0].strokeWidth = 2 +chart.lines[0].strokeDashArray = [2, 2] # Dashed line +``` + +### Bar Labels + +```python +# Show values on bars +chart.barLabels.nudge = 5 # Offset from bar top +chart.barLabels.fontSize = 8 +chart.barLabelFormat = '%d' # Number format + +# For negative values +chart.barLabels.dy = -5 # Position below bar +``` + +## Legends + +Charts can have associated legends: + +```python +from reportlab.graphics.charts.legends import Legend + +# Create legend +legend = Legend() +legend.x = 350 +legend.y = 150 +legend.columnMaximum = 10 + +# Link to chart (share colors) +legend.colorNamePairs = [ + (chart.bars[0].fillColor, 'Series 1'), + (chart.bars[1].fillColor, 'Series 2'), +] + +# Add to drawing +drawing.add(legend) +``` + +## Drawing Shapes + +### Basic Shapes + +```python +from reportlab.graphics.shapes import ( + Drawing, Rect, Circle, Ellipse, Line, Polygon, String +) 
+from reportlab.lib import colors + +drawing = Drawing(400, 200) + +# Rectangle +rect = Rect(50, 50, 100, 50) +rect.fillColor = colors.blue +rect.strokeColor = colors.black +rect.strokeWidth = 1 +drawing.add(rect) + +# Circle +circle = Circle(200, 100, 30) +circle.fillColor = colors.red +drawing.add(circle) + +# Line +line = Line(50, 150, 350, 150) +line.strokeColor = colors.black +line.strokeWidth = 2 +drawing.add(line) + +# Text +text = String(50, 175, "Label Text") +text.fontSize = 12 +text.fontName = 'Helvetica' +drawing.add(text) +``` + +### Paths (Complex Shapes) + +```python +from reportlab.graphics.shapes import Path + +path = Path() +path.moveTo(50, 50) +path.lineTo(100, 100) +path.curveTo(120, 120, 140, 100, 150, 50) +path.closePath() + +path.fillColor = colors.lightblue +path.strokeColor = colors.blue +path.strokeWidth = 2 + +drawing.add(path) +``` + +## Rendering Options + +### Render to PDF + +```python +from reportlab.graphics import renderPDF + +# Direct to file +renderPDF.drawToFile(drawing, 'output.pdf', 'Chart Title') + +# As flowable in Platypus +story.append(drawing) +``` + +### Render to Image + +```python +from reportlab.graphics import renderPM + +# PNG +renderPM.drawToFile(drawing, 'chart.png', fmt='PNG') + +# GIF +renderPM.drawToFile(drawing, 'chart.gif', fmt='GIF') + +# JPG +renderPM.drawToFile(drawing, 'chart.jpg', fmt='JPG') + +# With specific DPI +renderPM.drawToFile(drawing, 'chart.png', fmt='PNG', dpi=150) +``` + +### Render to SVG + +```python +from reportlab.graphics import renderSVG + +renderSVG.drawToFile(drawing, 'chart.svg') +``` + +## Advanced Customization + +### Inspect Properties + +```python +# List all properties +print(chart.getProperties()) + +# Dump properties (for debugging) +chart.dumpProperties() + +# Set multiple properties +chart.setProperties({ + 'width': 400, + 'height': 200, + 'data': [[100, 150, 130]], +}) +``` + +### Custom Colors for Series + +```python +# Define color scheme +from reportlab.lib.colors import PCMYKColor + +colors_list = [ + PCMYKColor(100, 67, 0, 23), # Blue + PCMYKColor(0, 100, 100, 0), # Red + PCMYKColor(66, 13, 0, 22), # Green +] + +# Apply to chart +for i, color in enumerate(colors_list): + chart.bars[i].fillColor = color +``` + +## Complete Examples + +### Sales Report Bar Chart + +```python +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.charts.barcharts import VerticalBarChart +from reportlab.graphics.charts.legends import Legend +from reportlab.lib import colors + +drawing = Drawing(400, 250) + +# Create chart +chart = VerticalBarChart() +chart.x = 50 +chart.y = 50 +chart.width = 300 +chart.height = 150 + +# Data +chart.data = [ + [120, 150, 180, 200], # 2023 + [100, 130, 160, 190], # 2022 +] +chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4'] + +# Styling +chart.bars[0].fillColor = colors.HexColor('#3498db') +chart.bars[1].fillColor = colors.HexColor('#e74c3c') +chart.valueAxis.valueMin = 0 +chart.valueAxis.valueMax = 250 +chart.categoryAxis.labels.fontSize = 10 +chart.valueAxis.labels.fontSize = 10 + +# Add legend +legend = Legend() +legend.x = 325 +legend.y = 200 +legend.columnMaximum = 2 +legend.colorNamePairs = [ + (chart.bars[0].fillColor, '2023'), + (chart.bars[1].fillColor, '2022'), +] + +drawing.add(chart) +drawing.add(legend) + +# Add to story or save +story.append(drawing) +``` + +### Multi-Line Trend Chart + +```python +from reportlab.graphics.shapes import Drawing +from reportlab.graphics.charts.linecharts import HorizontalLineChart +from reportlab.lib import 
colors + +drawing = Drawing(400, 250) + +chart = HorizontalLineChart() +chart.x = 50 +chart.y = 50 +chart.width = 320 +chart.height = 170 + +# Data +chart.data = [ + [10, 15, 12, 18, 20, 25], # Product A + [8, 10, 14, 16, 18, 22], # Product B + [12, 11, 13, 15, 17, 19], # Product C +] + +chart.categoryAxis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'] + +# Line styling +chart.lines[0].strokeColor = colors.blue +chart.lines[0].strokeWidth = 2 +chart.lines[1].strokeColor = colors.red +chart.lines[1].strokeWidth = 2 +chart.lines[2].strokeColor = colors.green +chart.lines[2].strokeWidth = 2 + +# Axes +chart.valueAxis.valueMin = 0 +chart.valueAxis.valueMax = 30 +chart.categoryAxis.labels.angle = 0 +chart.categoryAxis.labels.fontSize = 9 +chart.valueAxis.labels.fontSize = 9 + +drawing.add(chart) +story.append(drawing) +``` + +## Best Practices + +1. **Set explicit dimensions** for Drawing to ensure consistent sizing +2. **Position charts** with enough margin (x, y at least 30-50 from edge) +3. **Use consistent color schemes** throughout document +4. **Set valueMin and valueMax** explicitly for consistent scales +5. **Test with realistic data** to ensure labels fit and don't overlap +6. **Add legends** for multi-series charts +7. **Angle category labels** if they're long (45° works well) +8. **Keep it simple** - fewer data series are easier to read +9. **Use appropriate chart types** - bars for comparisons, lines for trends, pies for proportions +10. **Consider colorblind-friendly palettes** - avoid red/green combinations diff --git a/scientific-packages/reportlab/references/pdf_features.md b/scientific-packages/reportlab/references/pdf_features.md new file mode 100644 index 0000000..223c7c1 --- /dev/null +++ b/scientific-packages/reportlab/references/pdf_features.md @@ -0,0 +1,561 @@ +# PDF Features Reference + +Advanced PDF capabilities: links, bookmarks, forms, encryption, and metadata. + +## Document Metadata + +Set PDF document properties viewable in PDF readers. + +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("output.pdf") + +# Set metadata +c.setAuthor("John Doe") +c.setTitle("Annual Report 2024") +c.setSubject("Financial Analysis") +c.setKeywords("finance, annual, report, 2024") +c.setCreator("MyApp v1.0") + +# ... draw content ... + +c.save() +``` + +With Platypus: + +```python +from reportlab.platypus import SimpleDocTemplate + +doc = SimpleDocTemplate( + "output.pdf", + title="Annual Report 2024", + author="John Doe", + subject="Financial Analysis", +) + +doc.build(story) +``` + +## Bookmarks and Destinations + +Create internal navigation structure. 
+ +### Simple Bookmarks + +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("output.pdf") + +# Create bookmark for current page +c.bookmarkPage("intro") # Internal key +c.addOutlineEntry("Introduction", "intro", level=0) + +c.showPage() + +# Another bookmark +c.bookmarkPage("chapter1") +c.addOutlineEntry("Chapter 1", "chapter1", level=0) + +# Sub-sections +c.bookmarkPage("section1_1") +c.addOutlineEntry("Section 1.1", "section1_1", level=1) # Nested + +c.save() +``` + +### Bookmark Levels + +```python +# Create hierarchical outline +c.bookmarkPage("ch1") +c.addOutlineEntry("Chapter 1", "ch1", level=0) + +c.bookmarkPage("ch1_s1") +c.addOutlineEntry("Section 1.1", "ch1_s1", level=1) + +c.bookmarkPage("ch1_s1_1") +c.addOutlineEntry("Subsection 1.1.1", "ch1_s1_1", level=2) + +c.bookmarkPage("ch2") +c.addOutlineEntry("Chapter 2", "ch2", level=0) +``` + +### Destination Fit Modes + +Control how the page displays when navigating: + +```python +# bookmarkPage with fit mode +c.bookmarkPage( + key="chapter1", + fit="Fit" # Fit entire page in window +) + +# Or use bookmarkHorizontalAbsolute +c.bookmarkHorizontalAbsolute(key="section", top=500) + +# Available fit modes: +# "Fit" - Fit whole page +# "FitH" - Fit horizontally +# "FitV" - Fit vertically +# "FitR" - Fit rectangle +# "XYZ" - Specific position and zoom +``` + +## Hyperlinks + +### External Links + +```python +from reportlab.pdfgen import canvas +from reportlab.lib.units import inch + +c = canvas.Canvas("output.pdf") + +# Draw link rectangle +c.linkURL( + "https://www.example.com", + rect=(1*inch, 5*inch, 3*inch, 5.5*inch), # (x1, y1, x2, y2) + relative=0, # 0 for absolute positioning + thickness=1, + color=(0, 0, 1), # Blue + dashArray=None +) + +# Draw text over link area +c.setFillColorRGB(0, 0, 1) # Blue text +c.drawString(1*inch, 5.2*inch, "Click here to visit example.com") + +c.save() +``` + +### Internal Links + +Link to bookmarked locations within the document: + +```python +# Create destination +c.bookmarkPage("target_section") + +# Later, create link to that destination +c.linkRect( + "Link Text", + "target_section", # Bookmark key + rect=(1*inch, 3*inch, 2*inch, 3.2*inch), + relative=0 +) +``` + +### Links in Paragraphs + +For Platypus documents: + +```python +from reportlab.platypus import Paragraph + +# External link +text = 'Visit our website' +para = Paragraph(text, style) + +# Internal link (to anchor) +text = 'Go to Section 1' +para1 = Paragraph(text, style) + +# Create anchor +text = 'Section 1 Heading' +para2 = Paragraph(text, heading_style) + +story.append(para1) +story.append(para2) +``` + +## Interactive Forms + +Create fillable PDF forms. 
+ +### Text Fields + +```python +from reportlab.pdfgen import canvas +from reportlab.pdfbase import pdfform +from reportlab.lib.colors import black, white + +c = canvas.Canvas("form.pdf") + +# Create text field +c.acroForm.textfield( + name="name", + tooltip="Enter your name", + x=100, + y=700, + width=200, + height=20, + borderColor=black, + fillColor=white, + textColor=black, + forceBorder=True, + fontSize=12, + maxlen=100, # Maximum character length +) + +# Label +c.drawString(100, 725, "Name:") + +c.save() +``` + +### Checkboxes + +```python +# Create checkbox +c.acroForm.checkbox( + name="agree", + tooltip="I agree to terms", + x=100, + y=650, + size=20, + buttonStyle='check', # 'check', 'circle', 'cross', 'diamond', 'square', 'star' + borderColor=black, + fillColor=white, + textColor=black, + forceBorder=True, + checked=False, # Initial state +) + +c.drawString(130, 655, "I agree to the terms and conditions") +``` + +### Radio Buttons + +```python +# Radio button group - only one can be selected +c.acroForm.radio( + name="payment", # Same name for group + tooltip="Credit Card", + value="credit", # Value when selected + x=100, + y=600, + size=15, + selected=False, +) +c.drawString(125, 603, "Credit Card") + +c.acroForm.radio( + name="payment", # Same name + tooltip="PayPal", + value="paypal", + x=100, + y=580, + size=15, + selected=False, +) +c.drawString(125, 583, "PayPal") +``` + +### List Boxes + +```python +# Listbox with multiple options +c.acroForm.listbox( + name="country", + tooltip="Select your country", + value="US", # Default selected + x=100, + y=500, + width=150, + height=80, + borderColor=black, + fillColor=white, + textColor=black, + forceBorder=True, + options=[ + ("United States", "US"), + ("Canada", "CA"), + ("Mexico", "MX"), + ("Other", "OTHER"), + ], # List of (label, value) tuples + multiple=False, # Allow multiple selections +) +``` + +### Choice (Dropdown) + +```python +# Dropdown menu +c.acroForm.choice( + name="state", + tooltip="Select state", + value="CA", + x=100, + y=450, + width=150, + height=20, + borderColor=black, + fillColor=white, + textColor=black, + forceBorder=True, + options=[ + ("California", "CA"), + ("New York", "NY"), + ("Texas", "TX"), + ], +) +``` + +### Complete Form Example + +```python +from reportlab.pdfgen import canvas +from reportlab.lib.pagesizes import letter +from reportlab.lib.colors import black, white, lightgrey +from reportlab.lib.units import inch + +def create_registration_form(filename): + c = canvas.Canvas(filename, pagesize=letter) + c.setFont("Helvetica-Bold", 16) + c.drawString(inch, 10*inch, "Registration Form") + + y = 9*inch + c.setFont("Helvetica", 12) + + # Name field + c.drawString(inch, y, "Full Name:") + c.acroForm.textfield( + name="fullname", + x=2*inch, y=y-5, width=4*inch, height=20, + borderColor=black, fillColor=lightgrey, forceBorder=True + ) + + # Email field + y -= 0.5*inch + c.drawString(inch, y, "Email:") + c.acroForm.textfield( + name="email", + x=2*inch, y=y-5, width=4*inch, height=20, + borderColor=black, fillColor=lightgrey, forceBorder=True + ) + + # Age dropdown + y -= 0.5*inch + c.drawString(inch, y, "Age Group:") + c.acroForm.choice( + name="age_group", + x=2*inch, y=y-5, width=2*inch, height=20, + borderColor=black, fillColor=lightgrey, forceBorder=True, + options=[("18-25", "18-25"), ("26-35", "26-35"), + ("36-50", "36-50"), ("51+", "51+")] + ) + + # Newsletter checkbox + y -= 0.5*inch + c.acroForm.checkbox( + name="newsletter", + x=inch, y=y-5, size=15, + buttonStyle='check', 
borderColor=black, forceBorder=True + ) + c.drawString(inch + 25, y, "Subscribe to newsletter") + + c.save() + +create_registration_form("registration.pdf") +``` + +## Encryption and Security + +Protect PDFs with passwords and permissions. + +### Basic Encryption + +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("secure.pdf") + +# Encrypt with user password +c.encrypt( + userPassword="user123", # Password to open + ownerPassword="owner456", # Password to change permissions + canPrint=1, # Allow printing + canModify=0, # Disallow modifications + canCopy=1, # Allow text copying + canAnnotate=0, # Disallow annotations + strength=128, # 40 or 128 bit encryption +) + +# ... draw content ... + +c.save() +``` + +### Permission Settings + +```python +c.encrypt( + userPassword="user123", + ownerPassword="owner456", + canPrint=1, # 1 = allow, 0 = deny + canModify=0, # Prevent content modification + canCopy=1, # Allow text/graphics copying + canAnnotate=0, # Prevent comments/annotations + strength=128, # Use 128-bit encryption +) +``` + +### Advanced Encryption + +```python +from reportlab.lib.pdfencrypt import StandardEncryption + +# Create encryption object +encrypt = StandardEncryption( + userPassword="user123", + ownerPassword="owner456", + canPrint=1, + canModify=0, + canCopy=1, + canAnnotate=1, + strength=128, +) + +# Use with canvas +c = canvas.Canvas("secure.pdf") +c._doc.encrypt = encrypt + +# ... draw content ... + +c.save() +``` + +### Platypus with Encryption + +```python +from reportlab.platypus import SimpleDocTemplate + +doc = SimpleDocTemplate("secure.pdf") + +# Set encryption +doc.encrypt = True +doc.canPrint = 1 +doc.canModify = 0 + +# Or use encrypt() method +doc.encrypt = encrypt_object + +doc.build(story) +``` + +## Page Transitions + +Add visual effects for presentations. + +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("presentation.pdf") + +# Set transition for current page +c.setPageTransition( + effectname="Wipe", # Transition effect + duration=1, # Duration in seconds + direction=0 # Direction (effect-specific) +) + +# Available effects: +# "Split", "Blinds", "Box", "Wipe", "Dissolve", +# "Glitter", "R" (Replace), "Fly", "Push", "Cover", +# "Uncover", "Fade" + +# Direction values (effect-dependent): +# 0, 90, 180, 270 for most directional effects + +# Example: Slide with fade transition +c.setFont("Helvetica-Bold", 24) +c.drawString(100, 400, "Slide 1") +c.setPageTransition("Fade", 0.5) +c.showPage() + +c.drawString(100, 400, "Slide 2") +c.setPageTransition("Wipe", 1, 90) +c.showPage() + +c.save() +``` + +## PDF/A Compliance + +Create archival-quality PDFs. + +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("pdfa.pdf") + +# Enable PDF/A-1b compliance +c.setPageCompression(0) # PDF/A requires uncompressed +# Note: Full PDF/A requires additional XMP metadata +# This is simplified - full compliance needs more setup + +# ... draw content ... + +c.save() +``` + +## Compression + +Control file size vs generation speed. + +```python +# Enable page compression +c = canvas.Canvas("output.pdf", pageCompression=1) + +# Compression reduces file size but slows generation +# 0 = no compression (faster, larger files) +# 1 = compression (slower, smaller files) +``` + +## Forms and XObjects + +Reusable graphics elements. 
+ +```python +from reportlab.pdfgen import canvas + +c = canvas.Canvas("output.pdf") + +# Begin form (reusable object) +c.beginForm("logo") +c.setFillColorRGB(0, 0, 1) +c.rect(0, 0, 100, 50, fill=1) +c.setFillColorRGB(1, 1, 1) +c.drawString(10, 20, "LOGO") +c.endForm() + +# Use form multiple times +c.doForm("logo") # At current position +c.translate(200, 0) +c.doForm("logo") # At translated position +c.translate(200, 0) +c.doForm("logo") + +c.save() + +# Benefits: Smaller file size, faster rendering +``` + +## Best Practices + +1. **Always set metadata** for professional documents +2. **Use bookmarks** for documents > 10 pages +3. **Make links visually distinct** (blue, underlined) +4. **Test forms** in multiple PDF readers (behavior varies) +5. **Use strong encryption (128-bit)** for sensitive data +6. **Set both user and owner passwords** for full security +7. **Enable printing** unless specifically restricted +8. **Test page transitions** - some readers don't support all effects +9. **Use meaningful bookmark titles** for navigation +10. **Consider PDF/A** for long-term archival needs +11. **Validate form field names** - must be unique and valid identifiers +12. **Add tooltips** to form fields for better UX diff --git a/scientific-packages/reportlab/references/platypus_guide.md b/scientific-packages/reportlab/references/platypus_guide.md new file mode 100644 index 0000000..a7546fc --- /dev/null +++ b/scientific-packages/reportlab/references/platypus_guide.md @@ -0,0 +1,343 @@ +# Platypus Guide - High-Level Page Layout + +Platypus ("Page Layout and Typography Using Scripts") provides high-level document layout for complex, flowing documents with minimal code. + +## Architecture Overview + +Platypus uses a layered design: + +1. **DocTemplates** - Document container with page formatting rules +2. **PageTemplates** - Specifications for different page layouts +3. **Frames** - Regions where content flows +4. **Flowables** - Content elements (paragraphs, tables, images, spacers) +5. 
**Canvas** - Underlying rendering engine (usually hidden) + +## Quick Start + +```python +from reportlab.lib.pagesizes import letter +from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak +from reportlab.lib.styles import getSampleStyleSheet +from reportlab.lib.units import inch + +# Create document +doc = SimpleDocTemplate("output.pdf", pagesize=letter, + rightMargin=72, leftMargin=72, + topMargin=72, bottomMargin=18) + +# Create story (list of flowables) +story = [] +styles = getSampleStyleSheet() + +# Add content +story.append(Paragraph("Title", styles['Title'])) +story.append(Spacer(1, 0.2*inch)) +story.append(Paragraph("Body text here", styles['BodyText'])) +story.append(PageBreak()) + +# Build PDF +doc.build(story) +``` + +## Core Components + +### DocTemplates + +#### SimpleDocTemplate +Most common template for standard documents: + +```python +doc = SimpleDocTemplate( + filename, + pagesize=letter, + rightMargin=72, # 1 inch = 72 points + leftMargin=72, + topMargin=72, + bottomMargin=18, + title=None, # PDF metadata + author=None, + subject=None +) +``` + +#### BaseDocTemplate (Advanced) +For complex documents with multiple page layouts: + +```python +from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame +from reportlab.lib.pagesizes import letter + +doc = BaseDocTemplate("output.pdf", pagesize=letter) + +# Define frames (content regions) +frame1 = Frame(doc.leftMargin, doc.bottomMargin, + doc.width/2-6, doc.height, id='col1') +frame2 = Frame(doc.leftMargin+doc.width/2+6, doc.bottomMargin, + doc.width/2-6, doc.height, id='col2') + +# Create page template +template = PageTemplate(id='TwoCol', frames=[frame1, frame2]) +doc.addPageTemplates([template]) + +# Build with story +doc.build(story) +``` + +### Frames + +Frames define regions where content flows: + +```python +from reportlab.platypus import Frame + +frame = Frame( + x1, y1, # Lower-left corner + width, height, # Dimensions + leftPadding=6, # Internal padding + bottomPadding=6, + rightPadding=6, + topPadding=6, + id=None, # Optional identifier + showBoundary=0 # 1 to show frame border (debugging) +) +``` + +### PageTemplates + +Define page layouts with frames and optional functions: + +```python +def header_footer(canvas, doc): + """Called on each page for headers/footers""" + canvas.saveState() + canvas.setFont('Helvetica', 9) + canvas.drawString(inch, 0.75*inch, f"Page {doc.page}") + canvas.restoreState() + +template = PageTemplate( + id='Normal', + frames=[frame], + onPage=header_footer, # Function called for each page + onPageEnd=None, + pagesize=letter +) +``` + +## Flowables + +Flowables are content elements that flow through frames. + +### Common Flowables + +```python +from reportlab.platypus import ( + Paragraph, Spacer, PageBreak, FrameBreak, + Image, Table, KeepTogether, CondPageBreak +) + +# Spacer - vertical whitespace +Spacer(width, height) + +# Page break - force new page +PageBreak() + +# Frame break - move to next frame +FrameBreak() + +# Conditional page break - break if less than N space remaining +CondPageBreak(height) + +# Keep together - prevent splitting across pages +KeepTogether([flowable1, flowable2, ...]) +``` + +### Paragraph Flowable +See `text_and_fonts.md` for detailed Paragraph usage. 
+ +```python +from reportlab.platypus import Paragraph +from reportlab.lib.styles import ParagraphStyle + +style = ParagraphStyle( + 'CustomStyle', + fontSize=12, + leading=14, + alignment=0 # 0=left, 1=center, 2=right, 4=justify +) + +para = Paragraph("Text with bold and italic", style) +story.append(para) +``` + +### Image Flowable + +```python +from reportlab.platypus import Image + +# Auto-size to fit +img = Image('photo.jpg') + +# Fixed size +img = Image('photo.jpg', width=2*inch, height=2*inch) + +# Maintain aspect ratio with max width +img = Image('photo.jpg', width=4*inch, height=3*inch, + kind='proportional') + +story.append(img) +``` + +### Table Flowable +See `tables_reference.md` for detailed Table usage. + +```python +from reportlab.platypus import Table + +data = [['Header1', 'Header2'], + ['Row1Col1', 'Row1Col2'], + ['Row2Col1', 'Row2Col2']] + +table = Table(data, colWidths=[2*inch, 2*inch]) +story.append(table) +``` + +## Page Layouts + +### Single Column Document + +```python +doc = SimpleDocTemplate("output.pdf", pagesize=letter) +story = [] +# Add flowables... +doc.build(story) +``` + +### Two-Column Layout + +```python +from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame + +doc = BaseDocTemplate("output.pdf", pagesize=letter) +width, height = letter +margin = inch + +# Two side-by-side frames +frame1 = Frame(margin, margin, width/2 - 1.5*margin, height - 2*margin, id='col1') +frame2 = Frame(width/2 + 0.5*margin, margin, width/2 - 1.5*margin, height - 2*margin, id='col2') + +template = PageTemplate(id='TwoCol', frames=[frame1, frame2]) +doc.addPageTemplates([template]) + +story = [] +# Content flows left column first, then right column +# Add flowables... +doc.build(story) +``` + +### Multiple Page Templates + +```python +from reportlab.platypus import NextPageTemplate + +# Define templates +cover_template = PageTemplate(id='Cover', frames=[cover_frame]) +body_template = PageTemplate(id='Body', frames=[body_frame]) + +doc.addPageTemplates([cover_template, body_template]) + +story = [] +# Cover page content +story.append(Paragraph("Cover", title_style)) +story.append(NextPageTemplate('Body')) # Switch to body template +story.append(PageBreak()) + +# Body content +story.append(Paragraph("Chapter 1", heading_style)) +# ... 
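+# NextPageTemplate applies from the next new page onward, so issue it before
+# the PageBreak; it stays in effect until another NextPageTemplate is added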
+ +doc.build(story) +``` + +## Headers and Footers + +Headers and footers are added via `onPage` callback functions: + +```python +def header_footer(canvas, doc): + """Draw header and footer on each page""" + canvas.saveState() + + # Header + canvas.setFont('Helvetica-Bold', 12) + canvas.drawCentredString(letter[0]/2, letter[1] - 0.5*inch, + "Document Title") + + # Footer + canvas.setFont('Helvetica', 9) + canvas.drawString(inch, 0.75*inch, "Left Footer") + canvas.drawRightString(letter[0] - inch, 0.75*inch, + f"Page {doc.page}") + + canvas.restoreState() + +# Apply to template +template = PageTemplate(id='Normal', frames=[frame], onPage=header_footer) +``` + +## Table of Contents + +```python +from reportlab.platypus import TableOfContents +from reportlab.lib.styles import ParagraphStyle + +# Create TOC +toc = TableOfContents() +toc.levelStyles = [ + ParagraphStyle(name='TOC1', fontSize=14, leftIndent=0), + ParagraphStyle(name='TOC2', fontSize=12, leftIndent=20), +] + +story = [] +story.append(toc) +story.append(PageBreak()) + +# Add entries +story.append(Paragraph("Chapter 1", heading_style)) +toc.addEntry(0, "Chapter 1", doc.page, 'ch1') + +# Must call build twice for TOC to populate +doc.build(story) +``` + +## Document Properties + +```python +from reportlab.lib.pagesizes import letter, A4 +from reportlab.lib.units import inch, cm, mm + +# Page sizes +letter # US Letter (8.5" x 11") +A4 # ISO A4 (210mm x 297mm) +landscape(letter) # Rotate to landscape + +# Units +inch # 72 points +cm # 28.35 points +mm # 2.835 points + +# Custom page size +custom_size = (6*inch, 9*inch) +``` + +## Best Practices + +1. **Use SimpleDocTemplate** for most documents - it handles common layouts +2. **Build story list** completely before calling `doc.build(story)` +3. **Use Spacer** for vertical spacing instead of empty Paragraphs +4. **Group related content** with KeepTogether to prevent awkward page breaks +5. **Test page breaks** early with realistic content amounts +6. **Use styles consistently** - create style once, reuse throughout document +7. **Set showBoundary=1** on Frames during development to visualize layout +8. **Headers/footers go in onPage** callback, not in story +9. **For long documents**, use BaseDocTemplate with multiple page templates +10. **Build TOC documents twice** to properly populate table of contents diff --git a/scientific-packages/reportlab/references/tables_reference.md b/scientific-packages/reportlab/references/tables_reference.md new file mode 100644 index 0000000..c1ea21f --- /dev/null +++ b/scientific-packages/reportlab/references/tables_reference.md @@ -0,0 +1,442 @@ +# Tables Reference + +Comprehensive guide to creating and styling tables in ReportLab. 
+ +## Basic Table Creation + +```python +from reportlab.platypus import Table, TableStyle +from reportlab.lib import colors + +# Simple data (list of lists or tuples) +data = [ + ['Header 1', 'Header 2', 'Header 3'], + ['Row 1, Col 1', 'Row 1, Col 2', 'Row 1, Col 3'], + ['Row 2, Col 1', 'Row 2, Col 2', 'Row 2, Col 3'], +] + +# Create table +table = Table(data) + +# Add to story +story.append(table) +``` + +## Table Constructor + +```python +table = Table( + data, # Required: list of lists/tuples + colWidths=None, # List of column widths or single value + rowHeights=None, # List of row heights or single value + style=None, # TableStyle object + splitByRow=1, # Split across pages by rows (not columns) + repeatRows=0, # Number of header rows to repeat + repeatCols=0, # Number of header columns to repeat + rowSplitRange=None, # Tuple (start, end) of splittable rows + spaceBefore=None, # Space before table + spaceAfter=None, # Space after table + cornerRadii=None, # [TL, TR, BL, BR] for rounded corners +) +``` + +### Column Widths + +```python +from reportlab.lib.units import inch + +# Equal widths +table = Table(data, colWidths=2*inch) + +# Different widths per column +table = Table(data, colWidths=[1.5*inch, 2*inch, 1*inch]) + +# Auto-calculate widths (default) +table = Table(data) + +# Percentage-based (of available width) +table = Table(data, colWidths=[None, None, None]) # Equal auto-sizing +``` + +## Cell Content Types + +### Text and Newlines + +```python +# Newlines work in cells +data = [ + ['Line 1\nLine 2', 'Single line'], + ['Another\nmulti-line\ncell', 'Text'], +] +``` + +### Paragraph Objects + +```python +from reportlab.platypus import Paragraph +from reportlab.lib.styles import getSampleStyleSheet + +styles = getSampleStyleSheet() + +data = [ + [Paragraph("Formatted bold text", styles['Normal']), + Paragraph("More italic text", styles['Normal'])], +] + +table = Table(data) +``` + +### Images + +```python +from reportlab.platypus import Image + +data = [ + ['Description', Image('logo.png', width=1*inch, height=1*inch)], + ['Product', Image('product.jpg', width=2*inch, height=1.5*inch)], +] + +table = Table(data) +``` + +### Nested Tables + +```python +# Create inner table +inner_data = [['A', 'B'], ['C', 'D']] +inner_table = Table(inner_data) + +# Use in outer table +outer_data = [ + ['Label', inner_table], + ['Other', 'Content'], +] + +outer_table = Table(outer_data) +``` + +## TableStyle + +Styles are applied using command lists: + +```python +from reportlab.platypus import TableStyle +from reportlab.lib import colors + +style = TableStyle([ + # Command format: ('COMMAND', (startcol, startrow), (endcol, endrow), *args) + ('GRID', (0, 0), (-1, -1), 1, colors.black), # Grid over all cells + ('BACKGROUND', (0, 0), (-1, 0), colors.grey), # Header background + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), # Header text color +]) + +table = Table(data) +table.setStyle(style) +``` + +### Cell Coordinate System + +- Columns and rows are 0-indexed: `(col, row)` +- Negative indices count from end: `-1` is last column/row +- `(0, 0)` is top-left cell +- `(-1, -1)` is bottom-right cell + +```python +# Examples: +(0, 0), (2, 0) # First three cells of header row +(0, 1), (-1, -1) # All cells except header +(0, 0), (-1, -1) # Entire table +``` + +## Styling Commands + +### Text Formatting + +```python +style = TableStyle([ + # Font name + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + + # Font size + ('FONTSIZE', (0, 0), (-1, -1), 10), + + # Text color + ('TEXTCOLOR', (0, 0), (-1, 
0), colors.white), + ('TEXTCOLOR', (0, 1), (-1, -1), colors.black), + + # Combined font command + ('FONT', (0, 0), (-1, 0), 'Helvetica-Bold', 12), # name, size +]) +``` + +### Alignment + +```python +style = TableStyle([ + # Horizontal alignment: LEFT, CENTER, RIGHT, DECIMAL + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + ('ALIGN', (0, 1), (0, -1), 'LEFT'), # First column left + ('ALIGN', (1, 1), (-1, -1), 'RIGHT'), # Other columns right + + # Vertical alignment: TOP, MIDDLE, BOTTOM + ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'), + ('VALIGN', (0, 0), (-1, 0), 'BOTTOM'), # Header bottom-aligned +]) +``` + +### Cell Padding + +```python +style = TableStyle([ + # Individual padding + ('LEFTPADDING', (0, 0), (-1, -1), 12), + ('RIGHTPADDING', (0, 0), (-1, -1), 12), + ('TOPPADDING', (0, 0), (-1, -1), 6), + ('BOTTOMPADDING', (0, 0), (-1, -1), 6), + + # Or set all at once by setting each +]) +``` + +### Background Colors + +```python +style = TableStyle([ + # Solid background + ('BACKGROUND', (0, 0), (-1, 0), colors.blue), + ('BACKGROUND', (0, 1), (-1, -1), colors.lightgrey), + + # Alternating row colors + ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightblue]), + + # Alternating column colors + ('COLBACKGROUNDS', (0, 0), (-1, -1), [colors.white, colors.lightgrey]), +]) +``` + +### Gradient Backgrounds + +```python +from reportlab.lib.colors import Color + +style = TableStyle([ + # Vertical gradient (top to bottom) + ('BACKGROUND', (0, 0), (-1, 0), colors.blue), + ('VERTICALGRADIENT', (0, 0), (-1, 0), + [colors.blue, colors.lightblue]), + + # Horizontal gradient (left to right) + ('HORIZONTALGRADIENT', (0, 1), (-1, 1), + [colors.red, colors.yellow]), +]) +``` + +### Lines and Borders + +```python +style = TableStyle([ + # Complete grid + ('GRID', (0, 0), (-1, -1), 1, colors.black), + + # Box/outline only + ('BOX', (0, 0), (-1, -1), 2, colors.black), + ('OUTLINE', (0, 0), (-1, -1), 2, colors.black), # Same as BOX + + # Inner grid only + ('INNERGRID', (0, 0), (-1, -1), 0.5, colors.grey), + + # Directional lines + ('LINEABOVE', (0, 0), (-1, 0), 2, colors.black), # Header border + ('LINEBELOW', (0, 0), (-1, 0), 1, colors.black), # Header bottom + ('LINEBEFORE', (0, 0), (0, -1), 1, colors.black), # Left border + ('LINEAFTER', (-1, 0), (-1, -1), 1, colors.black), # Right border + + # Thickness and color + ('LINEABOVE', (0, 1), (-1, 1), 0.5, colors.grey), # Thin grey line +]) +``` + +### Cell Spanning + +```python +data = [ + ['Spanning Header', '', ''], # Span will merge these + ['A', 'B', 'C'], + ['D', 'E', 'F'], +] + +style = TableStyle([ + # Span 3 columns in first row + ('SPAN', (0, 0), (2, 0)), + + # Center the spanning cell + ('ALIGN', (0, 0), (2, 0), 'CENTER'), +]) + +table = Table(data) +table.setStyle(style) +``` + +**Important:** Cells that are spanned over must contain empty strings `''`. 
+ +### Advanced Spanning Examples + +```python +# Span multiple rows and columns +data = [ + ['A', 'B', 'B', 'C'], + ['A', 'D', 'E', 'F'], + ['A', 'G', 'H', 'I'], +] + +style = TableStyle([ + # Span rows in column 0 + ('SPAN', (0, 0), (0, 2)), # Merge A cells vertically + + # Span columns in row 0 + ('SPAN', (1, 0), (2, 0)), # Merge B cells horizontally + + ('GRID', (0, 0), (-1, -1), 1, colors.black), +]) +``` + +## Special Commands + +### Rounded Corners + +```python +table = Table(data, cornerRadii=[5, 5, 5, 5]) # [TL, TR, BL, BR] + +# Or in style +style = TableStyle([ + ('ROUNDEDCORNERS', [10, 10, 0, 0]), # Rounded top corners only +]) +``` + +### No Split + +Prevent table from splitting at specific locations: + +```python +style = TableStyle([ + # Don't split between rows 0 and 2 + ('NOSPLIT', (0, 0), (-1, 2)), +]) +``` + +### Split-Specific Styling + +Apply styles only to first or last part when table splits: + +```python +style = TableStyle([ + # Style for first part after split + ('LINEBELOW', (0, 'splitfirst'), (-1, 'splitfirst'), 2, colors.red), + + # Style for last part after split + ('LINEABOVE', (0, 'splitlast'), (-1, 'splitlast'), 2, colors.blue), +]) +``` + +## Repeating Headers + +```python +# Repeat first row on each page +table = Table(data, repeatRows=1) + +# Repeat first 2 rows +table = Table(data, repeatRows=2) +``` + +## Complete Examples + +### Styled Report Table + +```python +from reportlab.platypus import Table, TableStyle +from reportlab.lib import colors +from reportlab.lib.units import inch + +data = [ + ['Product', 'Quantity', 'Unit Price', 'Total'], + ['Widget A', '10', '$5.00', '$50.00'], + ['Widget B', '5', '$12.00', '$60.00'], + ['Widget C', '20', '$3.00', '$60.00'], + ['', '', 'Subtotal:', '$170.00'], +] + +table = Table(data, colWidths=[2.5*inch, 1*inch, 1*inch, 1*inch]) + +style = TableStyle([ + # Header row + ('BACKGROUND', (0, 0), (-1, 0), colors.darkblue), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 12), + ('ALIGN', (0, 0), (-1, 0), 'CENTER'), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + + # Data rows + ('BACKGROUND', (0, 1), (-1, -2), colors.beige), + ('GRID', (0, 0), (-1, -2), 0.5, colors.grey), + ('ALIGN', (1, 1), (-1, -1), 'RIGHT'), + ('ALIGN', (0, 1), (0, -1), 'LEFT'), + + # Total row + ('BACKGROUND', (0, -1), (-1, -1), colors.lightgrey), + ('LINEABOVE', (0, -1), (-1, -1), 2, colors.black), + ('FONTNAME', (2, -1), (-1, -1), 'Helvetica-Bold'), +]) + +table.setStyle(style) +``` + +### Alternating Row Colors + +```python +data = [ + ['Name', 'Age', 'City'], + ['Alice', '30', 'New York'], + ['Bob', '25', 'Boston'], + ['Charlie', '35', 'Chicago'], + ['Diana', '28', 'Denver'], +] + +table = Table(data, colWidths=[2*inch, 1*inch, 1.5*inch]) + +style = TableStyle([ + # Header + ('BACKGROUND', (0, 0), (-1, 0), colors.darkslategray), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.white), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + + # Alternating rows (zebra striping) + ('ROWBACKGROUNDS', (0, 1), (-1, -1), + [colors.white, colors.lightgrey]), + + # Borders + ('BOX', (0, 0), (-1, -1), 2, colors.black), + ('LINEBELOW', (0, 0), (-1, 0), 2, colors.black), + + # Padding + ('LEFTPADDING', (0, 0), (-1, -1), 12), + ('RIGHTPADDING', (0, 0), (-1, -1), 12), + ('TOPPADDING', (0, 0), (-1, -1), 6), + ('BOTTOMPADDING', (0, 0), (-1, -1), 6), +]) + +table.setStyle(style) +``` + +## Best Practices + +1. **Set colWidths explicitly** for consistent layout +2. 
**Use repeatRows** for multi-page tables with headers +3. **Apply padding** for better readability (especially LEFTPADDING and RIGHTPADDING) +4. **Use ROWBACKGROUNDS** for alternating colors instead of styling each row +5. **Put empty strings** in cells that will be spanned +6. **Test page breaks** early with realistic data amounts +7. **Use Paragraph objects** in cells for complex formatted text +8. **Set VALIGN to MIDDLE** for better appearance with varying row heights +9. **Keep tables simple** - complex nested tables are hard to maintain +10. **Use consistent styling** - define once, apply to all tables diff --git a/scientific-packages/reportlab/references/text_and_fonts.md b/scientific-packages/reportlab/references/text_and_fonts.md new file mode 100644 index 0000000..5cdb5db --- /dev/null +++ b/scientific-packages/reportlab/references/text_and_fonts.md @@ -0,0 +1,394 @@ +# Text and Fonts Reference + +Comprehensive guide to text formatting, paragraph styles, and font handling in ReportLab. + +## Text Encoding + +**IMPORTANT:** All text input should be UTF-8 encoded or Python Unicode objects (since ReportLab 2.0). + +```python +# Correct - UTF-8 strings +text = "Hello 世界 مرحبا" +para = Paragraph(text, style) + +# For legacy data, convert first +import codecs +decoded_text = codecs.decode(legacy_bytes, 'latin-1') +``` + +## Paragraph Styles + +### Creating Styles + +```python +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY +from reportlab.lib.colors import black, blue, red +from reportlab.lib.units import inch + +# Get default styles +styles = getSampleStyleSheet() +normal = styles['Normal'] +heading = styles['Heading1'] + +# Create custom style +custom_style = ParagraphStyle( + 'CustomStyle', + parent=normal, # Inherit from another style + + # Font properties + fontName='Helvetica', + fontSize=12, + leading=14, # Line spacing (should be > fontSize) + + # Indentation (in points) + leftIndent=0, + rightIndent=0, + firstLineIndent=0, # Positive = indent, negative = outdent + + # Spacing + spaceBefore=0, + spaceAfter=0, + + # Alignment + alignment=TA_LEFT, # TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY + + # Colors + textColor=black, + backColor=None, # Background color + + # Borders + borderWidth=0, + borderColor=None, + borderPadding=0, + borderRadius=None, + + # Bullets + bulletFontName='Helvetica', + bulletFontSize=12, + bulletIndent=0, + bulletText=None, # Text for bullets (e.g., '•') + + # Advanced + wordWrap=None, # 'CJK' for Asian languages + allowWidows=1, # Allow widow lines + allowOrphans=0, # Prevent orphan lines + endDots=None, # Trailing dots for TOC entries + splitLongWords=1, + hyphenationLang=None, # 'en_US', etc. 
(requires pyphen) +) + +# Add to stylesheet +styles.add(custom_style) +``` + +### Built-in Styles + +```python +styles = getSampleStyleSheet() + +# Common styles +styles['Normal'] # Body text +styles['BodyText'] # Similar to Normal +styles['Heading1'] # Top-level heading +styles['Heading2'] # Second-level heading +styles['Heading3'] # Third-level heading +styles['Title'] # Document title +styles['Bullet'] # Bulleted list items +styles['Definition'] # Definition text +styles['Code'] # Code samples +``` + +## Paragraph Formatting + +### Basic Paragraph + +```python +from reportlab.platypus import Paragraph + +para = Paragraph("This is a paragraph.", style) +story.append(para) +``` + +### Inline Formatting Tags + +```python +text = """ +Bold text +Italic text +Underlined text +Strikethrough text +Strong (bold) text +""" + +para = Paragraph(text, normal_style) +``` + +### Font Control + +```python +text = """ + +Custom font, size, and color + + +Hex color codes work too +""" + +para = Paragraph(text, normal_style) +``` + +### Superscripts and Subscripts + +```python +text = """ +H2O is water. +E=mc2 or E=mc2 +Xi for subscripted variables +""" + +para = Paragraph(text, normal_style) +``` + +### Greek Letters + +```python +text = """ +alpha, beta, gamma +epsilon, pi, omega +""" + +para = Paragraph(text, normal_style) +``` + +### Links + +```python +# External link +text = 'Click here' + +# Internal link (to bookmark) +text = 'Go to Section 1' + +# Anchor for internal links +text = 'Section 1 Heading' + +para = Paragraph(text, normal_style) +``` + +### Inline Images + +```python +text = """ +Here is an inline image: +""" + +para = Paragraph(text, normal_style) +``` + +### Line Breaks + +```python +text = """ +First line
+Second line<br/>
+Third line +""" + +para = Paragraph(text, normal_style) +``` + +## Font Handling + +### Standard Fonts + +ReportLab includes 14 standard PDF fonts (no embedding needed): + +```python +# Helvetica family +'Helvetica' +'Helvetica-Bold' +'Helvetica-Oblique' +'Helvetica-BoldOblique' + +# Times family +'Times-Roman' +'Times-Bold' +'Times-Italic' +'Times-BoldItalic' + +# Courier family +'Courier' +'Courier-Bold' +'Courier-Oblique' +'Courier-BoldOblique' + +# Symbol and Dingbats +'Symbol' +'ZapfDingbats' +``` + +### TrueType Fonts + +```python +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont + +# Register single font +pdfmetrics.registerFont(TTFont('CustomFont', 'CustomFont.ttf')) + +# Use in Canvas +canvas.setFont('CustomFont', 12) + +# Use in Paragraph style +style = ParagraphStyle('Custom', fontName='CustomFont', fontSize=12) +``` + +### Font Families + +Register related fonts as a family for bold/italic support: + +```python +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont +from reportlab.lib.fonts import addMapping + +# Register fonts +pdfmetrics.registerFont(TTFont('Vera', 'Vera.ttf')) +pdfmetrics.registerFont(TTFont('VeraBd', 'VeraBd.ttf')) +pdfmetrics.registerFont(TTFont('VeraIt', 'VeraIt.ttf')) +pdfmetrics.registerFont(TTFont('VeraBI', 'VeraBI.ttf')) + +# Map family (normal, bold, italic, bold-italic) +addMapping('Vera', 0, 0, 'Vera') # normal +addMapping('Vera', 1, 0, 'VeraBd') # bold +addMapping('Vera', 0, 1, 'VeraIt') # italic +addMapping('Vera', 1, 1, 'VeraBI') # bold-italic + +# Now and tags work with this family +style = ParagraphStyle('VeraStyle', fontName='Vera', fontSize=12) +para = Paragraph("Normal Bold Italic Both", style) +``` + +### Font Search Paths + +```python +from reportlab.pdfbase.ttfonts import TTFSearchPath + +# Add custom font directory +TTFSearchPath.append('/path/to/fonts/') + +# Now fonts in this directory can be found by name +pdfmetrics.registerFont(TTFont('MyFont', 'MyFont.ttf')) +``` + +### Asian Language Support + +#### Using Adobe Language Packs (no embedding) + +```python +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.cidfonts import UnicodeCIDFont + +# Register CID fonts +pdfmetrics.registerFont(UnicodeCIDFont('HeiseiMin-W3')) # Japanese +pdfmetrics.registerFont(UnicodeCIDFont('STSong-Light')) # Chinese (Simplified) +pdfmetrics.registerFont(UnicodeCIDFont('MSung-Light')) # Chinese (Traditional) +pdfmetrics.registerFont(UnicodeCIDFont('HYSMyeongJo-Medium')) # Korean + +# Use in styles +style = ParagraphStyle('Japanese', fontName='HeiseiMin-W3', fontSize=12) +para = Paragraph("日本語テキスト", style) +``` + +#### Using TrueType Fonts with Asian Characters + +```python +# Register TrueType font with full Unicode support +pdfmetrics.registerFont(TTFont('SimSun', 'simsun.ttc')) + +style = ParagraphStyle('Chinese', fontName='SimSun', fontSize=12, wordWrap='CJK') +para = Paragraph("中文文本", style) +``` + +Note: Set `wordWrap='CJK'` for proper line breaking in Asian languages. + +## Numbering and Sequences + +Auto-numbering using `` tags: + +```python +# Simple numbering +text = " Introduction" # Outputs: 1 Introduction +text = " Methods" # Outputs: 2 Methods + +# Reset counter +text = "" + +# Formatting templates +text = "Figure " +# Outputs: Figure 1-1, Figure 1-2, etc. 
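+# (These numbering examples rely on ReportLab's <seq> markup inside the text,
+# e.g. '<seq id="section"/> Introduction' to auto-increment a counter and
+# '<seqreset id="section"/>' to reset it.)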
+ +# Multi-level numbering +text = "Section " +``` + +## Bullets and Lists + +### Using Bullet Style + +```python +bullet_style = ParagraphStyle( + 'Bullet', + parent=normal_style, + leftIndent=20, + bulletIndent=10, + bulletText='•', # Unicode bullet + bulletFontName='Helvetica', +) + +story.append(Paragraph("First item", bullet_style)) +story.append(Paragraph("Second item", bullet_style)) +story.append(Paragraph("Third item", bullet_style)) +``` + +### Custom Bullet Characters + +```python +# Different bullet styles +bulletText='•' # Filled circle +bulletText='◦' # Open circle +bulletText='▪' # Square +bulletText='▸' # Triangle +bulletText='→' # Arrow +bulletText='1.' # Numbers +bulletText='a)' # Letters +``` + +## Text Measurement + +```python +from reportlab.pdfbase.pdfmetrics import stringWidth + +# Measure string width +width = stringWidth("Hello World", "Helvetica", 12) + +# Check if text fits in available width +max_width = 200 +if stringWidth(text, font_name, font_size) > max_width: + # Text is too wide + pass +``` + +## Best Practices + +1. **Always use UTF-8** for text input +2. **Set leading > fontSize** for readability (typically fontSize + 2) +3. **Register font families** for proper bold/italic support +4. **Escape HTML** if displaying user content: use `<` for < and `>` for > +5. **Use getSampleStyleSheet()** as a starting point, don't create all styles from scratch +6. **Test Asian fonts** early if supporting multi-language content +7. **Set wordWrap='CJK'** for Chinese/Japanese/Korean text +8. **Use stringWidth()** to check if text fits before rendering +9. **Define styles once** at document start, reuse throughout +10. **Enable hyphenation** for justified text: `hyphenationLang='en_US'` (requires pyphen package) diff --git a/scientific-packages/reportlab/scripts/quick_document.py b/scientific-packages/reportlab/scripts/quick_document.py new file mode 100644 index 0000000..ed69d97 --- /dev/null +++ b/scientific-packages/reportlab/scripts/quick_document.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +""" +Quick Document Generator - Helper for creating simple ReportLab documents + +This script provides utility functions for quickly creating common document types +without writing boilerplate code. +""" + +from reportlab.lib.pagesizes import letter, A4 +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.lib.units import inch +from reportlab.lib import colors +from reportlab.platypus import ( + SimpleDocTemplate, Paragraph, Spacer, PageBreak, + Table, TableStyle, Image, KeepTogether +) +from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY +from datetime import datetime + + +def create_simple_document(filename, title, author="", content_blocks=None, pagesize=letter): + """ + Create a simple document with title and content blocks. 
+ + Args: + filename: Output PDF filename + title: Document title + author: Document author (optional) + content_blocks: List of dicts with 'type' and 'content' keys + type can be: 'heading', 'paragraph', 'bullet', 'space' + pagesize: Page size (default: letter) + + Example content_blocks: + [ + {'type': 'heading', 'content': 'Introduction'}, + {'type': 'paragraph', 'content': 'This is a paragraph.'}, + {'type': 'bullet', 'content': 'Bullet point item'}, + {'type': 'space', 'height': 0.2}, # height in inches + ] + """ + if content_blocks is None: + content_blocks = [] + + # Create document + doc = SimpleDocTemplate( + filename, + pagesize=pagesize, + rightMargin=72, + leftMargin=72, + topMargin=72, + bottomMargin=18, + title=title, + author=author + ) + + # Get styles + styles = getSampleStyleSheet() + story = [] + + # Add title + story.append(Paragraph(title, styles['Title'])) + story.append(Spacer(1, 0.3*inch)) + + # Process content blocks + for block in content_blocks: + block_type = block.get('type', 'paragraph') + content = block.get('content', '') + + if block_type == 'heading': + story.append(Paragraph(content, styles['Heading1'])) + story.append(Spacer(1, 0.1*inch)) + + elif block_type == 'heading2': + story.append(Paragraph(content, styles['Heading2'])) + story.append(Spacer(1, 0.1*inch)) + + elif block_type == 'paragraph': + story.append(Paragraph(content, styles['BodyText'])) + story.append(Spacer(1, 0.1*inch)) + + elif block_type == 'bullet': + story.append(Paragraph(content, styles['Bullet'])) + + elif block_type == 'space': + height = block.get('height', 0.2) + story.append(Spacer(1, height*inch)) + + elif block_type == 'pagebreak': + story.append(PageBreak()) + + # Build PDF + doc.build(story) + return filename + + +def create_styled_table(data, col_widths=None, style_name='default'): + """ + Create a styled table with common styling presets. 
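+    Unrecognized style_name values fall back to the 'default' preset.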
+ + Args: + data: List of lists containing table data + col_widths: List of column widths (None for auto) + style_name: 'default', 'striped', 'minimal', 'report' + + Returns: + Table object ready to add to story + """ + table = Table(data, colWidths=col_widths) + + if style_name == 'striped': + style = TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.darkblue), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 12), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]), + ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), + ]) + + elif style_name == 'minimal': + style = TableStyle([ + ('ALIGN', (0, 0), (-1, -1), 'LEFT'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('LINEABOVE', (0, 0), (-1, 0), 2, colors.black), + ('LINEBELOW', (0, 0), (-1, 0), 1, colors.black), + ('LINEBELOW', (0, -1), (-1, -1), 2, colors.black), + ]) + + elif style_name == 'report': + style = TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.grey), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.black), + ('ALIGN', (0, 0), (-1, 0), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 11), + ('BACKGROUND', (0, 1), (-1, -1), colors.beige), + ('GRID', (0, 0), (-1, -1), 1, colors.grey), + ('LEFTPADDING', (0, 0), (-1, -1), 12), + ('RIGHTPADDING', (0, 0), (-1, -1), 12), + ]) + + else: # default + style = TableStyle([ + ('BACKGROUND', (0, 0), (-1, 0), colors.grey), + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, -1), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'), + ('FONTSIZE', (0, 0), (-1, 0), 12), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('BACKGROUND', (0, 1), (-1, -1), colors.white), + ('GRID', (0, 0), (-1, -1), 1, colors.black), + ]) + + table.setStyle(style) + return table + + +def add_header_footer(canvas, doc, header_text="", footer_text=""): + """ + Callback function to add headers and footers to each page. 
+ + Usage: + from functools import partial + callback = partial(add_header_footer, header_text="My Document", footer_text="Confidential") + template = PageTemplate(id='normal', frames=[frame], onPage=callback) + """ + canvas.saveState() + + # Header + if header_text: + canvas.setFont('Helvetica', 9) + canvas.drawString(inch, doc.pagesize[1] - 0.5*inch, header_text) + + # Footer + if footer_text: + canvas.setFont('Helvetica', 9) + canvas.drawString(inch, 0.5*inch, footer_text) + + # Page number + canvas.drawRightString(doc.pagesize[0] - inch, 0.5*inch, f"Page {doc.page}") + + canvas.restoreState() + + +# Example usage +if __name__ == "__main__": + # Example 1: Simple document + content = [ + {'type': 'heading', 'content': 'Introduction'}, + {'type': 'paragraph', 'content': 'This is a sample paragraph with some text.'}, + {'type': 'space', 'height': 0.2}, + {'type': 'heading', 'content': 'Main Content'}, + {'type': 'paragraph', 'content': 'More content here with bold and italic text.'}, + {'type': 'bullet', 'content': 'First bullet point'}, + {'type': 'bullet', 'content': 'Second bullet point'}, + ] + + create_simple_document( + "example_document.pdf", + "Sample Document", + author="John Doe", + content_blocks=content + ) + + print("Created: example_document.pdf") + + # Example 2: Document with styled table + doc = SimpleDocTemplate("table_example.pdf", pagesize=letter) + story = [] + styles = getSampleStyleSheet() + + story.append(Paragraph("Sales Report", styles['Title'])) + story.append(Spacer(1, 0.3*inch)) + + # Create table + data = [ + ['Product', 'Q1', 'Q2', 'Q3', 'Q4'], + ['Widget A', '100', '150', '130', '180'], + ['Widget B', '80', '120', '110', '160'], + ['Widget C', '90', '110', '100', '140'], + ] + + table = create_styled_table(data, col_widths=[2*inch, 1*inch, 1*inch, 1*inch, 1*inch], style_name='striped') + story.append(table) + + doc.build(story) + print("Created: table_example.pdf") diff --git a/scientific-packages/scanpy/SKILL.md b/scientific-packages/scanpy/SKILL.md new file mode 100644 index 0000000..3274d24 --- /dev/null +++ b/scientific-packages/scanpy/SKILL.md @@ -0,0 +1,380 @@ +--- +name: scanpy +description: This skill should be used when working with single-cell RNA-seq data analysis using scanpy. Use for analyzing .h5ad files, 10X Genomics data, performing quality control, clustering, finding marker genes, creating UMAP/t-SNE visualizations, cell type annotation, trajectory inference, and other single-cell genomics workflows. +--- + +# Scanpy: Single-Cell Analysis + +## Overview + +This skill provides comprehensive support for analyzing single-cell RNA-seq data using scanpy, a scalable Python toolkit built on AnnData. Use this skill for complete single-cell workflows including quality control, normalization, dimensionality reduction, clustering, marker gene identification, visualization, and trajectory analysis. 
+ +## When to Use This Skill + +Activate this skill when: +- Analyzing single-cell RNA-seq data (.h5ad, 10X, CSV formats) +- Performing quality control on scRNA-seq datasets +- Creating UMAP, t-SNE, or PCA visualizations +- Identifying cell clusters and finding marker genes +- Annotating cell types based on gene expression +- Conducting trajectory inference or pseudotime analysis +- Generating publication-quality single-cell plots + +## Quick Start + +### Basic Import and Setup + +```python +import scanpy as sc +import pandas as pd +import numpy as np + +# Configure settings +sc.settings.verbosity = 3 +sc.settings.set_figure_params(dpi=80, facecolor='white') +sc.settings.figdir = './figures/' +``` + +### Loading Data + +```python +# From 10X Genomics +adata = sc.read_10x_mtx('path/to/data/') +adata = sc.read_10x_h5('path/to/data.h5') + +# From h5ad (AnnData format) +adata = sc.read_h5ad('path/to/data.h5ad') + +# From CSV +adata = sc.read_csv('path/to/data.csv') +``` + +### Understanding AnnData Structure + +The AnnData object is the core data structure in scanpy: + +```python +adata.X # Expression matrix (cells × genes) +adata.obs # Cell metadata (DataFrame) +adata.var # Gene metadata (DataFrame) +adata.uns # Unstructured annotations (dict) +adata.obsm # Multi-dimensional cell data (PCA, UMAP) +adata.raw # Raw data backup + +# Access cell and gene names +adata.obs_names # Cell barcodes +adata.var_names # Gene names +``` + +## Standard Analysis Workflow + +### 1. Quality Control + +Identify and filter low-quality cells and genes: + +```python +# Identify mitochondrial genes +adata.var['mt'] = adata.var_names.str.startswith('MT-') + +# Calculate QC metrics +sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True) + +# Visualize QC metrics +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], + jitter=0.4, multi_panel=True) + +# Filter cells and genes +sc.pp.filter_cells(adata, min_genes=200) +sc.pp.filter_genes(adata, min_cells=3) +adata = adata[adata.obs.pct_counts_mt < 5, :] # Remove high MT% cells +``` + +**Use the QC script for automated analysis:** +```bash +python scripts/qc_analysis.py input_file.h5ad --output filtered.h5ad +``` + +### 2. Normalization and Preprocessing + +```python +# Normalize to 10,000 counts per cell +sc.pp.normalize_total(adata, target_sum=1e4) + +# Log-transform +sc.pp.log1p(adata) + +# Save raw counts for later +adata.raw = adata + +# Identify highly variable genes +sc.pp.highly_variable_genes(adata, n_top_genes=2000) +sc.pl.highly_variable_genes(adata) + +# Subset to highly variable genes +adata = adata[:, adata.var.highly_variable] + +# Regress out unwanted variation +sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt']) + +# Scale data +sc.pp.scale(adata, max_value=10) +``` + +### 3. Dimensionality Reduction + +```python +# PCA +sc.tl.pca(adata, svd_solver='arpack') +sc.pl.pca_variance_ratio(adata, log=True) # Check elbow plot + +# Compute neighborhood graph +sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40) + +# UMAP for visualization +sc.tl.umap(adata) +sc.pl.umap(adata, color='leiden') + +# Alternative: t-SNE +sc.tl.tsne(adata) +``` + +### 4. Clustering + +```python +# Leiden clustering (recommended) +sc.tl.leiden(adata, resolution=0.5) +sc.pl.umap(adata, color='leiden', legend_loc='on data') + +# Try multiple resolutions to find optimal granularity +for res in [0.3, 0.5, 0.8, 1.0]: + sc.tl.leiden(adata, resolution=res, key_added=f'leiden_{res}') +``` + +### 5. 
Marker Gene Identification + +```python +# Find marker genes for each cluster +sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon') + +# Visualize results +sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False) +sc.pl.rank_genes_groups_heatmap(adata, n_genes=10) +sc.pl.rank_genes_groups_dotplot(adata, n_genes=5) + +# Get results as DataFrame +markers = sc.get.rank_genes_groups_df(adata, group='0') +``` + +### 6. Cell Type Annotation + +```python +# Define marker genes for known cell types +marker_genes = ['CD3D', 'CD14', 'MS4A1', 'NKG7', 'FCGR3A'] + +# Visualize markers +sc.pl.umap(adata, color=marker_genes, use_raw=True) +sc.pl.dotplot(adata, var_names=marker_genes, groupby='leiden') + +# Manual annotation +cluster_to_celltype = { + '0': 'CD4 T cells', + '1': 'CD14+ Monocytes', + '2': 'B cells', + '3': 'CD8 T cells', +} +adata.obs['cell_type'] = adata.obs['leiden'].map(cluster_to_celltype) + +# Visualize annotated types +sc.pl.umap(adata, color='cell_type', legend_loc='on data') +``` + +### 7. Save Results + +```python +# Save processed data +adata.write('results/processed_data.h5ad') + +# Export metadata +adata.obs.to_csv('results/cell_metadata.csv') +adata.var.to_csv('results/gene_metadata.csv') +``` + +## Common Tasks + +### Creating Publication-Quality Plots + +```python +# Set high-quality defaults +sc.settings.set_figure_params(dpi=300, frameon=False, figsize=(5, 5)) +sc.settings.file_format_figs = 'pdf' + +# UMAP with custom styling +sc.pl.umap(adata, color='cell_type', + palette='Set2', + legend_loc='on data', + legend_fontsize=12, + legend_fontoutline=2, + frameon=False, + save='_publication.pdf') + +# Heatmap of marker genes +sc.pl.heatmap(adata, var_names=genes, groupby='cell_type', + swap_axes=True, show_gene_labels=True, + save='_markers.pdf') + +# Dot plot +sc.pl.dotplot(adata, var_names=genes, groupby='cell_type', + save='_dotplot.pdf') +``` + +Refer to `references/plotting_guide.md` for comprehensive visualization examples. 
+ +### Trajectory Inference + +```python +# PAGA (Partition-based graph abstraction) +sc.tl.paga(adata, groups='leiden') +sc.pl.paga(adata, color='leiden') + +# Diffusion pseudotime +adata.uns['iroot'] = np.flatnonzero(adata.obs['leiden'] == '0')[0] +sc.tl.dpt(adata) +sc.pl.umap(adata, color='dpt_pseudotime') +``` + +### Differential Expression Between Conditions + +```python +# Compare treated vs control within cell types +adata_subset = adata[adata.obs['cell_type'] == 'T cells'] +sc.tl.rank_genes_groups(adata_subset, groupby='condition', + groups=['treated'], reference='control') +sc.pl.rank_genes_groups(adata_subset, groups=['treated']) +``` + +### Gene Set Scoring + +```python +# Score cells for gene set expression +gene_set = ['CD3D', 'CD3E', 'CD3G'] +sc.tl.score_genes(adata, gene_set, score_name='T_cell_score') +sc.pl.umap(adata, color='T_cell_score') +``` + +### Batch Correction + +```python +# ComBat batch correction +sc.pp.combat(adata, key='batch') + +# Alternative: use Harmony or scVI (separate packages) +``` + +## Key Parameters to Adjust + +### Quality Control +- `min_genes`: Minimum genes per cell (typically 200-500) +- `min_cells`: Minimum cells per gene (typically 3-10) +- `pct_counts_mt`: Mitochondrial threshold (typically 5-20%) + +### Normalization +- `target_sum`: Target counts per cell (default 1e4) + +### Feature Selection +- `n_top_genes`: Number of HVGs (typically 2000-3000) +- `min_mean`, `max_mean`, `min_disp`: HVG selection parameters + +### Dimensionality Reduction +- `n_pcs`: Number of principal components (check variance ratio plot) +- `n_neighbors`: Number of neighbors (typically 10-30) + +### Clustering +- `resolution`: Clustering granularity (0.4-1.2, higher = more clusters) + +## Common Pitfalls and Best Practices + +1. **Always save raw counts**: `adata.raw = adata` before filtering genes +2. **Check QC plots carefully**: Adjust thresholds based on dataset quality +3. **Use Leiden over Louvain**: More efficient and better results +4. **Try multiple clustering resolutions**: Find optimal granularity +5. **Validate cell type annotations**: Use multiple marker genes +6. **Use `use_raw=True` for gene expression plots**: Shows original counts +7. **Check PCA variance ratio**: Determine optimal number of PCs +8. **Save intermediate results**: Long workflows can fail partway through + +## Bundled Resources + +### scripts/qc_analysis.py +Automated quality control script that calculates metrics, generates plots, and filters data: + +```bash +python scripts/qc_analysis.py input.h5ad --output filtered.h5ad \ + --mt-threshold 5 --min-genes 200 --min-cells 3 +``` + +### references/standard_workflow.md +Complete step-by-step workflow with detailed explanations and code examples for: +- Data loading and setup +- Quality control with visualization +- Normalization and scaling +- Feature selection +- Dimensionality reduction (PCA, UMAP, t-SNE) +- Clustering (Leiden, Louvain) +- Marker gene identification +- Cell type annotation +- Trajectory inference +- Differential expression + +Read this reference when performing a complete analysis from scratch. + +### references/api_reference.md +Quick reference guide for scanpy functions organized by module: +- Reading/writing data (`sc.read_*`, `adata.write_*`) +- Preprocessing (`sc.pp.*`) +- Tools (`sc.tl.*`) +- Plotting (`sc.pl.*`) +- AnnData structure and manipulation +- Settings and utilities + +Use this for quick lookup of function signatures and common parameters. 
+ +### references/plotting_guide.md +Comprehensive visualization guide including: +- Quality control plots +- Dimensionality reduction visualizations +- Clustering visualizations +- Marker gene plots (heatmaps, dot plots, violin plots) +- Trajectory and pseudotime plots +- Publication-quality customization +- Multi-panel figures +- Color palettes and styling + +Consult this when creating publication-ready figures. + +### assets/analysis_template.py +Complete analysis template providing a full workflow from data loading through cell type annotation. Copy and customize this template for new analyses: + +```bash +cp assets/analysis_template.py my_analysis.py +# Edit parameters and run +python my_analysis.py +``` + +The template includes all standard steps with configurable parameters and helpful comments. + +## Additional Resources + +- **Official scanpy documentation**: https://scanpy.readthedocs.io/ +- **Scanpy tutorials**: https://scanpy-tutorials.readthedocs.io/ +- **scverse ecosystem**: https://scverse.org/ (related tools: squidpy, scvi-tools, cellrank) +- **Best practices**: Luecken & Theis (2019) "Current best practices in single-cell RNA-seq" + +## Tips for Effective Analysis + +1. **Start with the template**: Use `assets/analysis_template.py` as a starting point +2. **Run QC script first**: Use `scripts/qc_analysis.py` for initial filtering +3. **Consult references as needed**: Load workflow and API references into context +4. **Iterate on clustering**: Try multiple resolutions and visualization methods +5. **Validate biologically**: Check marker genes match expected cell types +6. **Document parameters**: Record QC thresholds and analysis settings +7. **Save checkpoints**: Write intermediate results at key steps diff --git a/scientific-packages/scanpy/assets/analysis_template.py b/scientific-packages/scanpy/assets/analysis_template.py new file mode 100644 index 0000000..fa3d34e --- /dev/null +++ b/scientific-packages/scanpy/assets/analysis_template.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +""" +Complete Single-Cell Analysis Template + +This template provides a complete workflow for single-cell RNA-seq analysis +using scanpy, from data loading through clustering and cell type annotation. + +Customize the parameters and sections as needed for your specific dataset. +""" + +import scanpy as sc +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + +# ============================================================================ +# CONFIGURATION +# ============================================================================ + +# File paths +INPUT_FILE = 'data/raw_counts.h5ad' # Change to your input file +OUTPUT_DIR = 'results/' +FIGURES_DIR = 'figures/' + +# QC parameters +MIN_GENES = 200 # Minimum genes per cell +MIN_CELLS = 3 # Minimum cells per gene +MT_THRESHOLD = 5 # Maximum mitochondrial percentage + +# Analysis parameters +N_TOP_GENES = 2000 # Number of highly variable genes +N_PCS = 40 # Number of principal components +N_NEIGHBORS = 10 # Number of neighbors for graph +LEIDEN_RESOLUTION = 0.5 # Clustering resolution + +# Scanpy settings +sc.settings.verbosity = 3 +sc.settings.set_figure_params(dpi=80, facecolor='white') +sc.settings.figdir = FIGURES_DIR + +# ============================================================================ +# 1. 
LOAD DATA +# ============================================================================ + +print("=" * 80) +print("LOADING DATA") +print("=" * 80) + +# Load data (adjust based on your file format) +adata = sc.read_h5ad(INPUT_FILE) +# adata = sc.read_10x_mtx('data/filtered_gene_bc_matrices/') # For 10X data +# adata = sc.read_csv('data/counts.csv') # For CSV data + +print(f"Loaded: {adata.n_obs} cells x {adata.n_vars} genes") + +# ============================================================================ +# 2. QUALITY CONTROL +# ============================================================================ + +print("\n" + "=" * 80) +print("QUALITY CONTROL") +print("=" * 80) + +# Identify mitochondrial genes +adata.var['mt'] = adata.var_names.str.startswith('MT-') + +# Calculate QC metrics +sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, + log1p=False, inplace=True) + +# Visualize QC metrics before filtering +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], + jitter=0.4, multi_panel=True, save='_qc_before_filtering') + +sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt', save='_qc_mt') +sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts', save='_qc_genes') + +# Filter cells and genes +print(f"\nBefore filtering: {adata.n_obs} cells, {adata.n_vars} genes") + +sc.pp.filter_cells(adata, min_genes=MIN_GENES) +sc.pp.filter_genes(adata, min_cells=MIN_CELLS) +adata = adata[adata.obs.pct_counts_mt < MT_THRESHOLD, :] + +print(f"After filtering: {adata.n_obs} cells, {adata.n_vars} genes") + +# ============================================================================ +# 3. NORMALIZATION +# ============================================================================ + +print("\n" + "=" * 80) +print("NORMALIZATION") +print("=" * 80) + +# Normalize to 10,000 counts per cell +sc.pp.normalize_total(adata, target_sum=1e4) + +# Log-transform +sc.pp.log1p(adata) + +# Store normalized data +adata.raw = adata + +# ============================================================================ +# 4. FEATURE SELECTION +# ============================================================================ + +print("\n" + "=" * 80) +print("FEATURE SELECTION") +print("=" * 80) + +# Identify highly variable genes +sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_GENES) + +# Visualize +sc.pl.highly_variable_genes(adata, save='_hvg') + +print(f"Selected {sum(adata.var.highly_variable)} highly variable genes") + +# Subset to highly variable genes +adata = adata[:, adata.var.highly_variable] + +# ============================================================================ +# 5. SCALING AND REGRESSION +# ============================================================================ + +print("\n" + "=" * 80) +print("SCALING AND REGRESSION") +print("=" * 80) + +# Regress out unwanted sources of variation +sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt']) + +# Scale data +sc.pp.scale(adata, max_value=10) + +# ============================================================================ +# 6. 
DIMENSIONALITY REDUCTION +# ============================================================================ + +print("\n" + "=" * 80) +print("DIMENSIONALITY REDUCTION") +print("=" * 80) + +# PCA +sc.tl.pca(adata, svd_solver='arpack') +sc.pl.pca_variance_ratio(adata, log=True, save='_pca_variance') + +# Compute neighborhood graph +sc.pp.neighbors(adata, n_neighbors=N_NEIGHBORS, n_pcs=N_PCS) + +# UMAP +sc.tl.umap(adata) + +# ============================================================================ +# 7. CLUSTERING +# ============================================================================ + +print("\n" + "=" * 80) +print("CLUSTERING") +print("=" * 80) + +# Leiden clustering +sc.tl.leiden(adata, resolution=LEIDEN_RESOLUTION) + +# Visualize +sc.pl.umap(adata, color='leiden', legend_loc='on data', save='_leiden') + +print(f"Identified {len(adata.obs['leiden'].unique())} clusters") + +# ============================================================================ +# 8. MARKER GENE IDENTIFICATION +# ============================================================================ + +print("\n" + "=" * 80) +print("MARKER GENE IDENTIFICATION") +print("=" * 80) + +# Find marker genes +sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon') + +# Visualize top markers +sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False, save='_markers') +sc.pl.rank_genes_groups_heatmap(adata, n_genes=10, save='_markers_heatmap') +sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, save='_markers_dotplot') + +# Get top markers for each cluster +for cluster in adata.obs['leiden'].unique(): + print(f"\nCluster {cluster} top markers:") + markers = sc.get.rank_genes_groups_df(adata, group=cluster).head(10) + print(markers[['names', 'scores', 'pvals_adj']].to_string(index=False)) + +# ============================================================================ +# 9. CELL TYPE ANNOTATION (CUSTOMIZE THIS SECTION) +# ============================================================================ + +print("\n" + "=" * 80) +print("CELL TYPE ANNOTATION") +print("=" * 80) + +# Example marker genes for common cell types (customize for your data) +marker_genes = { + 'T cells': ['CD3D', 'CD3E', 'CD3G'], + 'B cells': ['MS4A1', 'CD79A', 'CD79B'], + 'Monocytes': ['CD14', 'LYZ', 'S100A8'], + 'NK cells': ['NKG7', 'GNLY', 'KLRD1'], + 'Dendritic cells': ['FCER1A', 'CST3'], +} + +# Visualize marker genes +for cell_type, genes in marker_genes.items(): + available_genes = [g for g in genes if g in adata.raw.var_names] + if available_genes: + sc.pl.umap(adata, color=available_genes, use_raw=True, + save=f'_{cell_type.replace(" ", "_")}') + +# Manual annotation based on marker expression (customize this mapping) +cluster_to_celltype = { + '0': 'CD4 T cells', + '1': 'CD14+ Monocytes', + '2': 'B cells', + '3': 'CD8 T cells', + '4': 'NK cells', + # Add more mappings based on your marker analysis +} + +# Apply annotations +adata.obs['cell_type'] = adata.obs['leiden'].map(cluster_to_celltype) +adata.obs['cell_type'] = adata.obs['cell_type'].fillna('Unknown') + +# Visualize annotated cell types +sc.pl.umap(adata, color='cell_type', legend_loc='on data', save='_celltypes') + +# ============================================================================ +# 10. 
ADDITIONAL ANALYSES (OPTIONAL) +# ============================================================================ + +print("\n" + "=" * 80) +print("ADDITIONAL ANALYSES") +print("=" * 80) + +# PAGA trajectory analysis (optional) +sc.tl.paga(adata, groups='leiden') +sc.pl.paga(adata, color='leiden', save='_paga') + +# Gene set scoring (optional) +# example_gene_set = ['CD3D', 'CD3E', 'CD3G'] +# sc.tl.score_genes(adata, example_gene_set, score_name='T_cell_score') +# sc.pl.umap(adata, color='T_cell_score', save='_gene_set_score') + +# ============================================================================ +# 11. SAVE RESULTS +# ============================================================================ + +print("\n" + "=" * 80) +print("SAVING RESULTS") +print("=" * 80) + +import os +os.makedirs(OUTPUT_DIR, exist_ok=True) + +# Save processed AnnData object +adata.write(f'{OUTPUT_DIR}/processed_data.h5ad') +print(f"Saved processed data to {OUTPUT_DIR}/processed_data.h5ad") + +# Export metadata +adata.obs.to_csv(f'{OUTPUT_DIR}/cell_metadata.csv') +adata.var.to_csv(f'{OUTPUT_DIR}/gene_metadata.csv') +print(f"Saved metadata to {OUTPUT_DIR}/") + +# Export marker genes +for cluster in adata.obs['leiden'].unique(): + markers = sc.get.rank_genes_groups_df(adata, group=cluster) + markers.to_csv(f'{OUTPUT_DIR}/markers_cluster_{cluster}.csv', index=False) +print(f"Saved marker genes to {OUTPUT_DIR}/") + +# ============================================================================ +# 12. SUMMARY +# ============================================================================ + +print("\n" + "=" * 80) +print("ANALYSIS SUMMARY") +print("=" * 80) + +print(f"\nFinal dataset:") +print(f" Cells: {adata.n_obs}") +print(f" Genes: {adata.n_vars}") +print(f" Clusters: {len(adata.obs['leiden'].unique())}") + +print(f"\nCell type distribution:") +print(adata.obs['cell_type'].value_counts()) + +print("\n" + "=" * 80) +print("ANALYSIS COMPLETE") +print("=" * 80) diff --git a/scientific-packages/scanpy/references/api_reference.md b/scientific-packages/scanpy/references/api_reference.md new file mode 100644 index 0000000..40f6659 --- /dev/null +++ b/scientific-packages/scanpy/references/api_reference.md @@ -0,0 +1,251 @@ +# Scanpy API Quick Reference + +Quick reference for commonly used scanpy functions organized by module. 
+ +## Import Convention + +```python +import scanpy as sc +``` + +## Reading and Writing Data (sc.read_*) + +### Reading Functions + +```python +sc.read_10x_h5(filename) # Read 10X HDF5 file +sc.read_10x_mtx(path) # Read 10X mtx directory +sc.read_h5ad(filename) # Read h5ad (AnnData) file +sc.read_csv(filename) # Read CSV file +sc.read_excel(filename) # Read Excel file +sc.read_loom(filename) # Read loom file +sc.read_text(filename) # Read text file +sc.read_visium(path) # Read Visium spatial data +``` + +### Writing Functions + +```python +adata.write_h5ad(filename) # Write to h5ad format +adata.write_csvs(dirname) # Write to CSV files +adata.write_loom(filename) # Write to loom format +adata.write_zarr(filename) # Write to zarr format +``` + +## Preprocessing (sc.pp.*) + +### Quality Control + +```python +sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True) +sc.pp.filter_cells(adata, min_genes=200) +sc.pp.filter_genes(adata, min_cells=3) +``` + +### Normalization and Transformation + +```python +sc.pp.normalize_total(adata, target_sum=1e4) # Normalize to target sum +sc.pp.log1p(adata) # Log(x + 1) transformation +sc.pp.sqrt(adata) # Square root transformation +``` + +### Feature Selection + +```python +sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5) +sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=2000) +``` + +### Scaling and Regression + +```python +sc.pp.scale(adata, max_value=10) # Scale to unit variance +sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt']) # Regress out unwanted variation +``` + +### Dimensionality Reduction (Preprocessing) + +```python +sc.pp.pca(adata, n_comps=50) # Principal component analysis +sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40) # Compute neighborhood graph +``` + +### Batch Correction + +```python +sc.pp.combat(adata, key='batch') # ComBat batch correction +``` + +## Tools (sc.tl.*) + +### Dimensionality Reduction + +```python +sc.tl.pca(adata, svd_solver='arpack') # PCA +sc.tl.umap(adata) # UMAP embedding +sc.tl.tsne(adata) # t-SNE embedding +sc.tl.diffmap(adata) # Diffusion map +sc.tl.draw_graph(adata, layout='fa') # Force-directed graph +``` + +### Clustering + +```python +sc.tl.leiden(adata, resolution=0.5) # Leiden clustering (recommended) +sc.tl.louvain(adata, resolution=0.5) # Louvain clustering +sc.tl.kmeans(adata, n_clusters=10) # K-means clustering +``` + +### Marker Genes and Differential Expression + +```python +sc.tl.rank_genes_groups(adata, groupby='leiden', method='wilcoxon') +sc.tl.rank_genes_groups(adata, groupby='leiden', method='t-test') +sc.tl.rank_genes_groups(adata, groupby='leiden', method='logreg') + +# Get results as dataframe +sc.get.rank_genes_groups_df(adata, group='0') +``` + +### Trajectory Inference + +```python +sc.tl.paga(adata, groups='leiden') # PAGA trajectory +sc.tl.dpt(adata) # Diffusion pseudotime +``` + +### Gene Scoring + +```python +sc.tl.score_genes(adata, gene_list, score_name='score') +sc.tl.score_genes_cell_cycle(adata, s_genes, g2m_genes) +``` + +### Embeddings and Projections + +```python +sc.tl.ingest(adata, adata_ref) # Map to reference +sc.tl.embedding_density(adata, basis='umap', groupby='leiden') +``` + +## Plotting (sc.pl.*) + +### Basic Embeddings + +```python +sc.pl.umap(adata, color='leiden') # UMAP plot +sc.pl.tsne(adata, color='gene_name') # t-SNE plot +sc.pl.pca(adata, color='leiden') # PCA plot +sc.pl.diffmap(adata, color='leiden') # Diffusion map plot +``` + +### Heatmaps and Dot Plots + +```python +sc.pl.heatmap(adata, 
var_names=genes, groupby='leiden') +sc.pl.dotplot(adata, var_names=genes, groupby='leiden') +sc.pl.matrixplot(adata, var_names=genes, groupby='leiden') +sc.pl.stacked_violin(adata, var_names=genes, groupby='leiden') +``` + +### Violin and Scatter Plots + +```python +sc.pl.violin(adata, keys=['gene1', 'gene2'], groupby='leiden') +sc.pl.scatter(adata, x='gene1', y='gene2', color='leiden') +``` + +### Marker Gene Visualization + +```python +sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False) +sc.pl.rank_genes_groups_violin(adata, groups='0') +sc.pl.rank_genes_groups_heatmap(adata, n_genes=10) +sc.pl.rank_genes_groups_dotplot(adata, n_genes=5) +``` + +### Trajectory Visualization + +```python +sc.pl.paga(adata, color='leiden') # PAGA graph +sc.pl.dpt_timeseries(adata) # DPT timeseries +``` + +### QC Plots + +```python +sc.pl.highest_expr_genes(adata, n_top=20) +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']) +sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts') +``` + +### Advanced Plots + +```python +sc.pl.dendrogram(adata, groupby='leiden') +sc.pl.correlation_matrix(adata, groupby='leiden') +sc.pl.tracksplot(adata, var_names=genes, groupby='leiden') +``` + +## Common Parameters + +### Color Parameters +- `color`: Variable(s) to color by (gene name, obs column) +- `use_raw`: Use `.raw` attribute of adata +- `palette`: Color palette to use +- `vmin`, `vmax`: Color scale limits + +### Layout Parameters +- `basis`: Embedding basis ('umap', 'tsne', 'pca', etc.) +- `legend_loc`: Legend location ('on data', 'right margin', etc.) +- `size`: Point size +- `alpha`: Point transparency + +### Saving Parameters +- `save`: Filename to save plot +- `show`: Whether to show plot + +## AnnData Structure + +```python +adata.X # Expression matrix (cells × genes) +adata.obs # Cell annotations (DataFrame) +adata.var # Gene annotations (DataFrame) +adata.uns # Unstructured annotations (dict) +adata.obsm # Multi-dimensional cell annotations (e.g., PCA, UMAP) +adata.varm # Multi-dimensional gene annotations +adata.layers # Additional data layers +adata.raw # Raw data backup + +# Access +adata.obs_names # Cell barcodes +adata.var_names # Gene names +adata.shape # (n_cells, n_genes) + +# Slicing +adata[cell_indices, gene_indices] +adata[:, adata.var_names.isin(gene_list)] +adata[adata.obs['leiden'] == '0', :] +``` + +## Settings + +```python +sc.settings.verbosity = 3 # 0=error, 1=warning, 2=info, 3=hint +sc.settings.set_figure_params(dpi=80, facecolor='white') +sc.settings.autoshow = False # Don't show plots automatically +sc.settings.autosave = True # Autosave figures +sc.settings.figdir = './figures/' # Figure directory +sc.settings.cachedir = './cache/' # Cache directory +sc.settings.n_jobs = 8 # Number of parallel jobs +``` + +## Useful Utilities + +```python +sc.logging.print_versions() # Print version information +sc.logging.print_memory_usage() # Print memory usage +adata.copy() # Create a copy of AnnData object +adata.concatenate([adata1, adata2]) # Concatenate AnnData objects +``` diff --git a/scientific-packages/scanpy/references/plotting_guide.md b/scientific-packages/scanpy/references/plotting_guide.md new file mode 100644 index 0000000..3fc4f62 --- /dev/null +++ b/scientific-packages/scanpy/references/plotting_guide.md @@ -0,0 +1,352 @@ +# Scanpy Plotting Guide + +Comprehensive guide for creating publication-quality visualizations with scanpy. 
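+
+The examples in this guide pass a `save` suffix to each plotting call, and the resulting figures land in `sc.settings.figdir`. A minimal setup sketch mirroring the settings used later in this guide (values are illustrative):
+
+```python
+import scanpy as sc
+
+# Directory that the `save` parameter writes into, plus publication-style defaults
+sc.settings.figdir = './figures/'
+sc.settings.set_figure_params(dpi=300, facecolor='white', frameon=False)
+```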
+ +## General Plotting Principles + +All scanpy plotting functions follow consistent patterns: +- Functions in `sc.pl.*` mirror analysis functions in `sc.tl.*` +- Most accept `color` parameter for gene names or metadata columns +- Results are saved via `save` parameter +- Multiple plots can be generated in a single call + +## Essential Quality Control Plots + +### Visualize QC Metrics + +```python +# Violin plots for QC metrics +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], + jitter=0.4, multi_panel=True, save='_qc_violin.pdf') + +# Scatter plots to identify outliers +sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt', save='_qc_mt.pdf') +sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts', save='_qc_genes.pdf') + +# Highest expressing genes +sc.pl.highest_expr_genes(adata, n_top=20, save='_highest_expr.pdf') +``` + +### Post-filtering QC + +```python +# Compare before and after filtering +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts'], + groupby='sample', save='_post_filter.pdf') +``` + +## Dimensionality Reduction Visualizations + +### PCA Plots + +```python +# Basic PCA +sc.pl.pca(adata, color='leiden', save='_pca.pdf') + +# PCA colored by gene expression +sc.pl.pca(adata, color=['gene1', 'gene2', 'gene3'], save='_pca_genes.pdf') + +# Variance ratio plot (elbow plot) +sc.pl.pca_variance_ratio(adata, log=True, n_pcs=50, save='_variance.pdf') + +# PCA loadings +sc.pl.pca_loadings(adata, components=[1, 2, 3], save='_loadings.pdf') +``` + +### UMAP Plots + +```python +# Basic UMAP with clusters +sc.pl.umap(adata, color='leiden', legend_loc='on data', save='_umap_leiden.pdf') + +# UMAP colored by multiple variables +sc.pl.umap(adata, color=['leiden', 'cell_type', 'batch'], + save='_umap_multi.pdf') + +# UMAP with gene expression +sc.pl.umap(adata, color=['CD3D', 'CD14', 'MS4A1'], + use_raw=False, save='_umap_genes.pdf') + +# Customize appearance +sc.pl.umap(adata, color='leiden', + palette='Set2', + size=50, + alpha=0.8, + frameon=False, + title='Cell Types', + save='_umap_custom.pdf') +``` + +### t-SNE Plots + +```python +# t-SNE with clusters +sc.pl.tsne(adata, color='leiden', legend_loc='right margin', save='_tsne.pdf') + +# Multiple t-SNE perplexities (if computed) +sc.pl.tsne(adata, color='leiden', save='_tsne_default.pdf') +``` + +## Clustering Visualizations + +### Basic Cluster Plots + +```python +# UMAP with cluster annotations +sc.pl.umap(adata, color='leiden', add_outline=True, + legend_loc='on data', legend_fontsize=12, + legend_fontoutline=2, frameon=False, + save='_clusters.pdf') + +# Show cluster proportions +sc.pl.umap(adata, color='leiden', size=50, edges=True, + edges_width=0.1, save='_clusters_edges.pdf') +``` + +### Cluster Comparison + +```python +# Compare clustering results +sc.pl.umap(adata, color=['leiden', 'louvain'], + save='_cluster_comparison.pdf') + +# Cluster dendrogram +sc.tl.dendrogram(adata, groupby='leiden') +sc.pl.dendrogram(adata, groupby='leiden', save='_dendrogram.pdf') +``` + +## Marker Gene Visualizations + +### Ranked Marker Genes + +```python +# Overview of top markers per cluster +sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False, + save='_marker_overview.pdf') + +# Heatmap of top markers +sc.pl.rank_genes_groups_heatmap(adata, n_genes=10, groupby='leiden', + show_gene_labels=True, + save='_marker_heatmap.pdf') + +# Dot plot of markers +sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, + save='_marker_dotplot.pdf') + +# Stacked violin plots +sc.pl.rank_genes_groups_stacked_violin(adata, 
n_genes=5, + save='_marker_violin.pdf') + +# Matrix plot +sc.pl.rank_genes_groups_matrixplot(adata, n_genes=5, + save='_marker_matrix.pdf') +``` + +### Specific Gene Expression + +```python +# Violin plots for specific genes +marker_genes = ['CD3D', 'CD14', 'MS4A1', 'NKG7', 'FCGR3A'] +sc.pl.violin(adata, keys=marker_genes, groupby='leiden', + save='_markers_violin.pdf') + +# Dot plot for curated markers +sc.pl.dotplot(adata, var_names=marker_genes, groupby='leiden', + save='_markers_dotplot.pdf') + +# Heatmap for specific genes +sc.pl.heatmap(adata, var_names=marker_genes, groupby='leiden', + swap_axes=True, save='_markers_heatmap.pdf') + +# Stacked violin for gene sets +sc.pl.stacked_violin(adata, var_names=marker_genes, groupby='leiden', + save='_markers_stacked.pdf') +``` + +### Gene Expression on Embeddings + +```python +# Multiple genes on UMAP +genes = ['CD3D', 'CD14', 'MS4A1', 'NKG7'] +sc.pl.umap(adata, color=genes, cmap='viridis', + save='_umap_markers.pdf') + +# Gene expression with custom colormap +sc.pl.umap(adata, color='CD3D', cmap='Reds', + vmin=0, vmax=3, save='_umap_cd3d.pdf') +``` + +## Trajectory and Pseudotime Visualizations + +### PAGA Plots + +```python +# PAGA graph +sc.pl.paga(adata, color='leiden', save='_paga.pdf') + +# PAGA with gene expression +sc.pl.paga(adata, color=['leiden', 'dpt_pseudotime'], + save='_paga_pseudotime.pdf') + +# PAGA overlaid on UMAP +sc.pl.umap(adata, color='leiden', save='_umap_with_paga.pdf', + edges=True, edges_color='gray') +``` + +### Pseudotime Plots + +```python +# DPT pseudotime on UMAP +sc.pl.umap(adata, color='dpt_pseudotime', save='_umap_dpt.pdf') + +# Gene expression along pseudotime +sc.pl.dpt_timeseries(adata, save='_dpt_timeseries.pdf') + +# Heatmap ordered by pseudotime +sc.pl.heatmap(adata, var_names=genes, groupby='leiden', + use_raw=False, show_gene_labels=True, + save='_pseudotime_heatmap.pdf') +``` + +## Advanced Visualizations + +### Tracks Plot (Gene Expression Trends) + +```python +# Show gene expression across cell types +sc.pl.tracksplot(adata, var_names=marker_genes, groupby='leiden', + save='_tracks.pdf') +``` + +### Correlation Matrix + +```python +# Correlation between clusters +sc.pl.correlation_matrix(adata, groupby='leiden', + save='_correlation.pdf') +``` + +### Embedding Density + +```python +# Cell density on UMAP +sc.tl.embedding_density(adata, basis='umap', groupby='cell_type') +sc.pl.embedding_density(adata, basis='umap', key='umap_density_cell_type', + save='_density.pdf') +``` + +## Multi-Panel Figures + +### Creating Panel Figures + +```python +import matplotlib.pyplot as plt + +# Create multi-panel figure +fig, axes = plt.subplots(2, 2, figsize=(12, 12)) + +# Plot on specific axes +sc.pl.umap(adata, color='leiden', ax=axes[0, 0], show=False) +sc.pl.umap(adata, color='CD3D', ax=axes[0, 1], show=False) +sc.pl.umap(adata, color='CD14', ax=axes[1, 0], show=False) +sc.pl.umap(adata, color='MS4A1', ax=axes[1, 1], show=False) + +plt.tight_layout() +plt.savefig('figures/multi_panel.pdf') +plt.show() +``` + +## Publication-Quality Customization + +### High-Quality Settings + +```python +# Set publication-quality defaults +sc.settings.set_figure_params(dpi=300, frameon=False, figsize=(5, 5), + facecolor='white') + +# Vector graphics output +sc.settings.figdir = './figures/' +sc.settings.file_format_figs = 'pdf' # or 'svg' +``` + +### Custom Color Palettes + +```python +# Use custom colors +custom_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'] +sc.pl.umap(adata, color='leiden', palette=custom_colors, + 
save='_custom_colors.pdf') + +# Continuous color maps +sc.pl.umap(adata, color='CD3D', cmap='viridis', save='_viridis.pdf') +sc.pl.umap(adata, color='CD3D', cmap='RdBu_r', save='_rdbu.pdf') +``` + +### Remove Axes and Frames + +```python +# Clean plot without axes +sc.pl.umap(adata, color='leiden', frameon=False, + save='_clean.pdf') + +# No legend +sc.pl.umap(adata, color='leiden', legend_loc=None, + save='_no_legend.pdf') +``` + +## Exporting Plots + +### Save Individual Plots + +```python +# Automatic saving with save parameter +sc.pl.umap(adata, color='leiden', save='_leiden.pdf') +# Saves to: sc.settings.figdir + 'umap_leiden.pdf' + +# Manual saving +import matplotlib.pyplot as plt +fig = sc.pl.umap(adata, color='leiden', show=False, return_fig=True) +fig.savefig('figures/my_umap.pdf', dpi=300, bbox_inches='tight') +``` + +### Batch Export + +```python +# Save multiple versions +for gene in ['CD3D', 'CD14', 'MS4A1']: + sc.pl.umap(adata, color=gene, save=f'_{gene}.pdf') +``` + +## Common Customization Parameters + +### Layout Parameters +- `figsize`: Figure size (width, height) +- `frameon`: Show frame around plot +- `title`: Plot title +- `legend_loc`: 'right margin', 'on data', 'best', or None +- `legend_fontsize`: Font size for legend +- `size`: Point size + +### Color Parameters +- `color`: Variable(s) to color by +- `palette`: Color palette (e.g., 'Set1', 'viridis') +- `cmap`: Colormap for continuous variables +- `vmin`, `vmax`: Color scale limits +- `use_raw`: Use raw counts for gene expression + +### Saving Parameters +- `save`: Filename suffix for saving +- `show`: Whether to display plot +- `dpi`: Resolution for raster formats + +## Tips for Publication Figures + +1. **Use vector formats**: PDF or SVG for scalable graphics +2. **High DPI**: Set dpi=300 or higher for raster images +3. **Consistent styling**: Use the same color palette across figures +4. **Clear labels**: Ensure gene names and cell types are readable +5. **White background**: Use `facecolor='white'` for publications +6. **Remove clutter**: Set `frameon=False` for cleaner appearance +7. **Legend placement**: Use 'on data' for compact figures +8. **Color blind friendly**: Consider palettes like 'colorblind' or 'Set2' diff --git a/scientific-packages/scanpy/references/standard_workflow.md b/scientific-packages/scanpy/references/standard_workflow.md new file mode 100644 index 0000000..7184ee9 --- /dev/null +++ b/scientific-packages/scanpy/references/standard_workflow.md @@ -0,0 +1,206 @@ +# Standard Scanpy Workflow for Single-Cell Analysis + +This document outlines the standard workflow for analyzing single-cell RNA-seq data using scanpy. + +## Complete Analysis Pipeline + +### 1. Data Loading and Initial Setup + +```python +import scanpy as sc +import pandas as pd +import numpy as np + +# Configure scanpy settings +sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3) +sc.settings.set_figure_params(dpi=80, facecolor='white') + +# Load data (various formats) +adata = sc.read_10x_mtx('path/to/data/') # For 10X data +# adata = sc.read_h5ad('path/to/data.h5ad') # For h5ad format +# adata = sc.read_csv('path/to/data.csv') # For CSV format +``` + +### 2. 
Quality Control (QC) + +```python +# Calculate QC metrics +sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True) + +# Common filtering thresholds (adjust based on dataset) +sc.pp.filter_cells(adata, min_genes=200) +sc.pp.filter_genes(adata, min_cells=3) + +# Remove cells with high mitochondrial content +adata = adata[adata.obs.pct_counts_mt < 5, :] + +# Visualize QC metrics +sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], + jitter=0.4, multi_panel=True) +sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt') +sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts') +``` + +### 3. Normalization + +```python +# Normalize to 10,000 counts per cell +sc.pp.normalize_total(adata, target_sum=1e4) + +# Log-transform the data +sc.pp.log1p(adata) + +# Store normalized data in raw for later use +adata.raw = adata +``` + +### 4. Feature Selection + +```python +# Identify highly variable genes +sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5) + +# Visualize highly variable genes +sc.pl.highly_variable_genes(adata) + +# Subset to highly variable genes +adata = adata[:, adata.var.highly_variable] +``` + +### 5. Scaling and Regression + +```python +# Regress out effects of total counts per cell and percent mitochondrial genes +sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt']) + +# Scale data to unit variance and zero mean +sc.pp.scale(adata, max_value=10) +``` + +### 6. Dimensionality Reduction + +```python +# Principal Component Analysis (PCA) +sc.tl.pca(adata, svd_solver='arpack') + +# Visualize PCA results +sc.pl.pca(adata, color='CST3') +sc.pl.pca_variance_ratio(adata, log=True) + +# Computing neighborhood graph +sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40) + +# UMAP for visualization +sc.tl.umap(adata) + +# t-SNE (alternative to UMAP) +# sc.tl.tsne(adata) +``` + +### 7. Clustering + +```python +# Leiden clustering (recommended) +sc.tl.leiden(adata, resolution=0.5) + +# Alternative: Louvain clustering +# sc.tl.louvain(adata, resolution=0.5) + +# Visualize clustering results +sc.pl.umap(adata, color=['leiden'], legend_loc='on data') +``` + +### 8. Marker Gene Identification + +```python +# Find marker genes for each cluster +sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon') + +# Visualize top marker genes +sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False) + +# Get marker gene dataframe +marker_genes = sc.get.rank_genes_groups_df(adata, group='0') + +# Visualize specific markers +sc.pl.umap(adata, color=['leiden', 'CST3', 'NKG7']) +``` + +### 9. Cell Type Annotation + +```python +# Manual annotation based on marker genes +cluster_annotations = { + '0': 'CD4 T cells', + '1': 'CD14+ Monocytes', + '2': 'B cells', + '3': 'CD8 T cells', + # ... add more annotations +} +adata.obs['cell_type'] = adata.obs['leiden'].map(cluster_annotations) + +# Visualize annotated cell types +sc.pl.umap(adata, color='cell_type', legend_loc='on data') +``` + +### 10. 
Saving Results + +```python +# Save the processed AnnData object +adata.write('results/processed_data.h5ad') + +# Export results to CSV +adata.obs.to_csv('results/cell_metadata.csv') +adata.var.to_csv('results/gene_metadata.csv') +``` + +## Additional Analysis Options + +### Trajectory Inference + +```python +# PAGA (Partition-based graph abstraction) +sc.tl.paga(adata, groups='leiden') +sc.pl.paga(adata, color=['leiden']) + +# Diffusion pseudotime (DPT) +adata.uns['iroot'] = np.flatnonzero(adata.obs['leiden'] == '0')[0] +sc.tl.dpt(adata) +sc.pl.umap(adata, color=['dpt_pseudotime']) +``` + +### Differential Expression Between Conditions + +```python +# Compare conditions within a cell type +sc.tl.rank_genes_groups(adata, groupby='condition', groups=['treated'], + reference='control', method='wilcoxon') +sc.pl.rank_genes_groups(adata, groups=['treated']) +``` + +### Gene Set Scoring + +```python +# Score cells for gene set expression +gene_set = ['CD3D', 'CD3E', 'CD3G'] +sc.tl.score_genes(adata, gene_set, score_name='T_cell_score') +sc.pl.umap(adata, color='T_cell_score') +``` + +## Common Parameters to Adjust + +- **QC thresholds**: `min_genes`, `min_cells`, `pct_counts_mt` - depends on dataset quality +- **Normalization target**: Usually 1e4, but can be adjusted +- **HVG parameters**: Affects feature selection stringency +- **PCA components**: Check variance ratio plot to determine optimal number +- **Clustering resolution**: Higher values give more clusters (typically 0.4-1.2) +- **n_neighbors**: Affects granularity of UMAP and clustering (typically 10-30) + +## Best Practices + +1. Always visualize QC metrics before filtering +2. Save raw counts before normalization (`adata.raw = adata`) +3. Use Leiden instead of Louvain for clustering (more efficient) +4. Try multiple clustering resolutions to find optimal granularity +5. Validate cell type annotations with known marker genes +6. Save intermediate results at key steps diff --git a/scientific-packages/scanpy/scripts/qc_analysis.py b/scientific-packages/scanpy/scripts/qc_analysis.py new file mode 100755 index 0000000..45fccd9 --- /dev/null +++ b/scientific-packages/scanpy/scripts/qc_analysis.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +Quality Control Analysis Script for Scanpy + +Performs comprehensive quality control on single-cell RNA-seq data, +including calculating metrics, generating QC plots, and filtering cells. + +Usage: + python qc_analysis.py [--output ] +""" + +import argparse +import scanpy as sc +import matplotlib.pyplot as plt + + +def calculate_qc_metrics(adata, mt_threshold=5, min_genes=200, min_cells=3): + """ + Calculate QC metrics and filter cells/genes. 
+ + Parameters: + ----------- + adata : AnnData + Annotated data matrix + mt_threshold : float + Maximum percentage of mitochondrial genes (default: 5) + min_genes : int + Minimum number of genes per cell (default: 200) + min_cells : int + Minimum number of cells per gene (default: 3) + + Returns: + -------- + AnnData + Filtered annotated data matrix + """ + # Identify mitochondrial genes (assumes gene names follow standard conventions) + adata.var['mt'] = adata.var_names.str.startswith(('MT-', 'mt-', 'Mt-')) + + # Calculate QC metrics + sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, + log1p=False, inplace=True) + + print("\n=== QC Metrics Summary ===") + print(f"Total cells: {adata.n_obs}") + print(f"Total genes: {adata.n_vars}") + print(f"Mean genes per cell: {adata.obs['n_genes_by_counts'].mean():.2f}") + print(f"Mean counts per cell: {adata.obs['total_counts'].mean():.2f}") + print(f"Mean mitochondrial %: {adata.obs['pct_counts_mt'].mean():.2f}") + + return adata + + +def generate_qc_plots(adata, output_prefix='qc'): + """ + Generate comprehensive QC plots. + + Parameters: + ----------- + adata : AnnData + Annotated data matrix + output_prefix : str + Prefix for saved figure files + """ + # Create figure directory if it doesn't exist + import os + os.makedirs('figures', exist_ok=True) + + # Violin plots for QC metrics + sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], + jitter=0.4, multi_panel=True, save=f'_{output_prefix}_violin.pdf') + + # Scatter plots + sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt', + save=f'_{output_prefix}_mt_scatter.pdf') + sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts', + save=f'_{output_prefix}_genes_scatter.pdf') + + # Highest expressing genes + sc.pl.highest_expr_genes(adata, n_top=20, + save=f'_{output_prefix}_highest_expr.pdf') + + print(f"\nQC plots saved to figures/ directory with prefix '{output_prefix}'") + + +def filter_data(adata, mt_threshold=5, min_genes=200, max_genes=None, + min_counts=None, max_counts=None, min_cells=3): + """ + Filter cells and genes based on QC thresholds. 
+ + Parameters: + ----------- + adata : AnnData + Annotated data matrix + mt_threshold : float + Maximum percentage of mitochondrial genes + min_genes : int + Minimum number of genes per cell + max_genes : int, optional + Maximum number of genes per cell + min_counts : int, optional + Minimum number of counts per cell + max_counts : int, optional + Maximum number of counts per cell + min_cells : int + Minimum number of cells per gene + + Returns: + -------- + AnnData + Filtered annotated data matrix + """ + n_cells_before = adata.n_obs + n_genes_before = adata.n_vars + + # Filter cells + sc.pp.filter_cells(adata, min_genes=min_genes) + if max_genes: + adata = adata[adata.obs['n_genes_by_counts'] < max_genes, :] + if min_counts: + adata = adata[adata.obs['total_counts'] >= min_counts, :] + if max_counts: + adata = adata[adata.obs['total_counts'] < max_counts, :] + + # Filter by mitochondrial percentage + adata = adata[adata.obs['pct_counts_mt'] < mt_threshold, :] + + # Filter genes + sc.pp.filter_genes(adata, min_cells=min_cells) + + print(f"\n=== Filtering Results ===") + print(f"Cells: {n_cells_before} -> {adata.n_obs} ({adata.n_obs/n_cells_before*100:.1f}% retained)") + print(f"Genes: {n_genes_before} -> {adata.n_vars} ({adata.n_vars/n_genes_before*100:.1f}% retained)") + + return adata + + +def main(): + parser = argparse.ArgumentParser(description='QC analysis for single-cell data') + parser.add_argument('input', help='Input file (h5ad, 10X mtx, csv, etc.)') + parser.add_argument('--output', default='qc_filtered.h5ad', + help='Output file name (default: qc_filtered.h5ad)') + parser.add_argument('--mt-threshold', type=float, default=5, + help='Max mitochondrial percentage (default: 5)') + parser.add_argument('--min-genes', type=int, default=200, + help='Min genes per cell (default: 200)') + parser.add_argument('--min-cells', type=int, default=3, + help='Min cells per gene (default: 3)') + parser.add_argument('--skip-plots', action='store_true', + help='Skip generating QC plots') + + args = parser.parse_args() + + # Configure scanpy + sc.settings.verbosity = 2 + sc.settings.set_figure_params(dpi=300, facecolor='white') + sc.settings.figdir = './figures/' + + print(f"Loading data from: {args.input}") + + # Load data based on file extension + if args.input.endswith('.h5ad'): + adata = sc.read_h5ad(args.input) + elif args.input.endswith('.h5'): + adata = sc.read_10x_h5(args.input) + elif args.input.endswith('.csv'): + adata = sc.read_csv(args.input) + else: + # Try reading as 10X mtx directory + adata = sc.read_10x_mtx(args.input) + + print(f"Loaded data: {adata.n_obs} cells x {adata.n_vars} genes") + + # Calculate QC metrics + adata = calculate_qc_metrics(adata, mt_threshold=args.mt_threshold, + min_genes=args.min_genes, min_cells=args.min_cells) + + # Generate QC plots (before filtering) + if not args.skip_plots: + print("\nGenerating QC plots (before filtering)...") + generate_qc_plots(adata, output_prefix='qc_before') + + # Filter data + adata = filter_data(adata, mt_threshold=args.mt_threshold, + min_genes=args.min_genes, min_cells=args.min_cells) + + # Generate QC plots (after filtering) + if not args.skip_plots: + print("\nGenerating QC plots (after filtering)...") + generate_qc_plots(adata, output_prefix='qc_after') + + # Save filtered data + print(f"\nSaving filtered data to: {args.output}") + adata.write_h5ad(args.output) + + print("\n=== QC Analysis Complete ===") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/scikit-bio/SKILL.md 
b/scientific-packages/scikit-bio/SKILL.md new file mode 100644 index 0000000..1b06fd0 --- /dev/null +++ b/scientific-packages/scikit-bio/SKILL.md @@ -0,0 +1,435 @@ +--- +name: scikit-bio +description: Comprehensive toolkit for biological data analysis in Python including DNA/RNA/protein sequence manipulation, sequence alignments, phylogenetic tree construction and analysis, microbial diversity metrics (alpha/beta diversity, UniFrac), ordination methods (PCoA, CCA, RDA), and statistical hypothesis testing (PERMANOVA, ANOSIM, Mantel). Use this skill when working with FASTA/FASTQ files, biological sequences, phylogenetic trees, microbiome data, ecological community analysis, or any bioinformatics workflow requiring sequence analysis, alignment, diversity calculations, or multivariate statistics on biological data. +--- + +# scikit-bio + +## Overview + +scikit-bio is a comprehensive Python library for working with biological data. Provide assistance with bioinformatics analyses spanning sequence manipulation, alignment, phylogenetics, microbial ecology, and multivariate statistics. This skill enables efficient work with common biological file formats and computational workflows in genomics, metagenomics, and ecological research. + +**Key applications:** Sequence analysis, phylogenetic tree construction, microbiome diversity analysis, ecological statistics, biological data manipulation, and format conversion. + +## When to Use This Skill + +Invoke this skill when the user: +- Works with biological sequences (DNA, RNA, protein) +- Needs to read/write biological file formats (FASTA, FASTQ, GenBank, Newick, BIOM, etc.) +- Performs sequence alignments or searches for motifs +- Constructs or analyzes phylogenetic trees +- Calculates diversity metrics (alpha/beta diversity, UniFrac distances) +- Performs ordination analysis (PCoA, CCA, RDA) +- Runs statistical tests on biological/ecological data (PERMANOVA, ANOSIM, Mantel) +- Analyzes microbiome or community ecology data +- Works with protein embeddings from language models +- Needs to manipulate biological data tables + +## Core Capabilities + +### 1. Sequence Manipulation + +Work with biological sequences using specialized classes for DNA, RNA, and protein data. + +**Key operations:** +- Read/write sequences from FASTA, FASTQ, GenBank, EMBL formats +- Sequence slicing, concatenation, and searching +- Reverse complement, transcription (DNA→RNA), and translation (RNA→protein) +- Find motifs and patterns using regex +- Calculate distances (Hamming, k-mer based) +- Handle sequence quality scores and metadata + +**Common patterns:** +```python +import skbio + +# Read sequences from file +seq = skbio.DNA.read('input.fasta') + +# Sequence operations +rc = seq.reverse_complement() +rna = seq.transcribe() +protein = rna.translate() + +# Find motifs +motif_positions = seq.find_with_regex('ATG[ACGT]{3}') + +# Check for properties +has_degens = seq.has_degenerates() +seq_no_gaps = seq.degap() +``` + +**Important notes:** +- Use `DNA`, `RNA`, `Protein` classes for grammared sequences with validation +- Use `Sequence` class for generic sequences without alphabet restrictions +- Quality scores automatically loaded from FASTQ files into positional metadata +- Metadata types: sequence-level (ID, description), positional (per-base), interval (regions/features) + +### 2. Sequence Alignment + +Perform pairwise and multiple sequence alignments using dynamic programming algorithms. 
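+
+One practical note on the pairwise helpers shown under "Common patterns" below: in the scikit-bio 0.5.x line they return a plain tuple rather than an object with attributes. A minimal sketch of unpacking that tuple, assuming `local_pairwise_align_ssw` is available in the installed release:
+
+```python
+from skbio import DNA
+from skbio.alignment import local_pairwise_align_ssw
+
+seq1 = DNA('ATCGATCGATCG')
+seq2 = DNA('ATCGGGGATCG')
+
+# 0.5.x API: returns (TabularMSA, score, start_end_positions)
+msa, score, positions = local_pairwise_align_ssw(seq1, seq2)
+print(score)      # alignment score
+print(msa)        # the two aligned sequences as a TabularMSA
+print(positions)  # [(start, end), (start, end)] for each input sequence
+```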
+ +**Key capabilities:** +- Global alignment (Needleman-Wunsch with semi-global variant) +- Local alignment (Smith-Waterman) +- Configurable scoring schemes (match/mismatch, gap penalties, substitution matrices) +- CIGAR string conversion +- Multiple sequence alignment storage and manipulation with `TabularMSA` + +**Common patterns:** +```python +from skbio.alignment import local_pairwise_align_ssw, TabularMSA + +# Pairwise alignment +alignment = local_pairwise_align_ssw(seq1, seq2) + +# Access aligned sequences +msa = alignment.aligned_sequences + +# Read multiple alignment from file +msa = TabularMSA.read('alignment.fasta', constructor=skbio.DNA) + +# Calculate consensus +consensus = msa.consensus() +``` + +**Important notes:** +- Use `local_pairwise_align_ssw` for local alignments (faster, SSW-based) +- Use `StripedSmithWaterman` for protein alignments +- Affine gap penalties recommended for biological sequences +- Can convert between scikit-bio, BioPython, and Biotite alignment formats + +### 3. Phylogenetic Trees + +Construct, manipulate, and analyze phylogenetic trees representing evolutionary relationships. + +**Key capabilities:** +- Tree construction from distance matrices (UPGMA, WPGMA, Neighbor Joining, GME, BME) +- Tree manipulation (pruning, rerooting, traversal) +- Distance calculations (patristic, cophenetic, Robinson-Foulds) +- ASCII visualization +- Newick format I/O + +**Common patterns:** +```python +from skbio import TreeNode +from skbio.tree import nj + +# Read tree from file +tree = TreeNode.read('tree.nwk') + +# Construct tree from distance matrix +tree = nj(distance_matrix) + +# Tree operations +subtree = tree.shear(['taxon1', 'taxon2', 'taxon3']) +tips = [node for node in tree.tips()] +lca = tree.lowest_common_ancestor(['taxon1', 'taxon2']) + +# Calculate distances +patristic_dist = tree.find('taxon1').distance(tree.find('taxon2')) +cophenetic_matrix = tree.cophenetic_matrix() + +# Compare trees +rf_distance = tree.robinson_foulds(other_tree) +``` + +**Important notes:** +- Use `nj()` for neighbor joining (classic phylogenetic method) +- Use `upgma()` for UPGMA (assumes molecular clock) +- GME and BME are highly scalable for large trees +- Trees can be rooted or unrooted; some metrics require specific rooting + +### 4. Diversity Analysis + +Calculate alpha and beta diversity metrics for microbial ecology and community analysis. 
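+
+As a quick illustration of the calls listed below, a toy example with made-up integer counts (4 samples × 3 features); note that raw counts, not relative abundances, are expected:
+
+```python
+import numpy as np
+from skbio.diversity import alpha_diversity, beta_diversity
+
+counts = np.array([[10, 2, 0],
+                   [8, 3, 1],
+                   [0, 9, 7],
+                   [1, 8, 9]])
+ids = ['S1', 'S2', 'S3', 'S4']
+
+shannon = alpha_diversity('shannon', counts, ids=ids)  # pandas Series, one value per sample
+bc = beta_diversity('braycurtis', counts, ids=ids)     # skbio DistanceMatrix
+print(shannon)
+print(bc)
+```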
+ +**Key capabilities:** +- Alpha diversity: richness, Shannon entropy, Simpson index, Faith's PD, Pielou's evenness +- Beta diversity: Bray-Curtis, Jaccard, weighted/unweighted UniFrac, Euclidean distances +- Phylogenetic diversity metrics (require tree input) +- Rarefaction and subsampling +- Integration with ordination and statistical tests + +**Common patterns:** +```python +from skbio.diversity import alpha_diversity, beta_diversity +import skbio + +# Alpha diversity +alpha = alpha_diversity('shannon', counts_matrix, ids=sample_ids) +faith_pd = alpha_diversity('faith_pd', counts_matrix, ids=sample_ids, + tree=tree, otu_ids=feature_ids) + +# Beta diversity +bc_dm = beta_diversity('braycurtis', counts_matrix, ids=sample_ids) +unifrac_dm = beta_diversity('unweighted_unifrac', counts_matrix, + ids=sample_ids, tree=tree, otu_ids=feature_ids) + +# Get available metrics +from skbio.diversity import get_alpha_diversity_metrics +print(get_alpha_diversity_metrics()) +``` + +**Important notes:** +- Counts must be integers representing abundances, not relative frequencies +- Phylogenetic metrics (Faith's PD, UniFrac) require tree and OTU ID mapping +- Use `partial_beta_diversity()` for computing specific sample pairs only +- Alpha diversity returns Series, beta diversity returns DistanceMatrix + +### 5. Ordination Methods + +Reduce high-dimensional biological data to visualizable lower-dimensional spaces. + +**Key capabilities:** +- PCoA (Principal Coordinate Analysis) from distance matrices +- CA (Correspondence Analysis) for contingency tables +- CCA (Canonical Correspondence Analysis) with environmental constraints +- RDA (Redundancy Analysis) for linear relationships +- Biplot projection for feature interpretation + +**Common patterns:** +```python +from skbio.stats.ordination import pcoa, cca + +# PCoA from distance matrix +pcoa_results = pcoa(distance_matrix) +pc1 = pcoa_results.samples['PC1'] +pc2 = pcoa_results.samples['PC2'] + +# CCA with environmental variables +cca_results = cca(species_matrix, environmental_matrix) + +# Save/load ordination results +pcoa_results.write('ordination.txt') +results = skbio.OrdinationResults.read('ordination.txt') +``` + +**Important notes:** +- PCoA works with any distance/dissimilarity matrix +- CCA reveals environmental drivers of community composition +- Ordination results include eigenvalues, proportion explained, and sample/feature coordinates +- Results integrate with plotting libraries (matplotlib, seaborn, plotly) + +### 6. Statistical Testing + +Perform hypothesis tests specific to ecological and biological data. 
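+
+Because PERMANOVA can be confounded by unequal group dispersions, it is worth running PERMDISP on the same distance matrix; a minimal sketch with a hand-built toy matrix and grouping:
+
+```python
+from skbio import DistanceMatrix
+from skbio.stats.distance import permanova, permdisp
+
+dm = DistanceMatrix([[0.0, 0.2, 0.7, 0.8],
+                     [0.2, 0.0, 0.6, 0.7],
+                     [0.7, 0.6, 0.0, 0.3],
+                     [0.8, 0.7, 0.3, 0.0]],
+                    ids=['S1', 'S2', 'S3', 'S4'])
+grouping = ['A', 'A', 'B', 'B']
+
+print(permanova(dm, grouping, permutations=999))  # do groups differ in location?
+print(permdisp(dm, grouping, permutations=999))   # or only in dispersion?
+```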
+ +**Key capabilities:** +- PERMANOVA: test group differences using distance matrices +- ANOSIM: alternative test for group differences +- PERMDISP: test homogeneity of group dispersions +- Mantel test: correlation between distance matrices +- Bioenv: find environmental variables correlated with distances + +**Common patterns:** +```python +from skbio.stats.distance import permanova, anosim, mantel + +# Test if groups differ significantly +permanova_results = permanova(distance_matrix, grouping, permutations=999) +print(f"p-value: {permanova_results['p-value']}") + +# ANOSIM test +anosim_results = anosim(distance_matrix, grouping, permutations=999) + +# Mantel test between two distance matrices +mantel_results = mantel(dm1, dm2, method='pearson', permutations=999) +print(f"Correlation: {mantel_results[0]}, p-value: {mantel_results[1]}") +``` + +**Important notes:** +- Permutation tests provide non-parametric significance testing +- Use 999+ permutations for robust p-values +- PERMANOVA sensitive to dispersion differences; pair with PERMDISP +- Mantel tests assess matrix correlation (e.g., geographic vs genetic distance) + +### 7. File I/O and Format Conversion + +Read and write 19+ biological file formats with automatic format detection. + +**Supported formats:** +- Sequences: FASTA, FASTQ, GenBank, EMBL, QSeq +- Alignments: Clustal, PHYLIP, Stockholm +- Trees: Newick +- Tables: BIOM (HDF5 and JSON) +- Distances: delimited square matrices +- Analysis: BLAST+6/7, GFF3, Ordination results +- Metadata: TSV/CSV with validation + +**Common patterns:** +```python +import skbio + +# Read with automatic format detection +seq = skbio.DNA.read('file.fasta', format='fasta') +tree = skbio.TreeNode.read('tree.nwk') + +# Write to file +seq.write('output.fasta', format='fasta') + +# Generator for large files (memory efficient) +for seq in skbio.io.read('large.fasta', format='fasta', constructor=skbio.DNA): + process(seq) + +# Convert formats +seqs = list(skbio.io.read('input.fastq', format='fastq', constructor=skbio.DNA)) +skbio.io.write(seqs, format='fasta', into='output.fasta') +``` + +**Important notes:** +- Use generators for large files to avoid memory issues +- Format can be auto-detected when `into` parameter specified +- Some objects can be written to multiple formats +- Support for stdin/stdout piping with `verify=False` + +### 8. Distance Matrices + +Create and manipulate distance/dissimilarity matrices with statistical methods. + +**Key capabilities:** +- Store symmetric (DistanceMatrix) or asymmetric (DissimilarityMatrix) data +- ID-based indexing and slicing +- Integration with diversity, ordination, and statistical tests +- Read/write delimited text format + +**Common patterns:** +```python +from skbio import DistanceMatrix +import numpy as np + +# Create from array +data = np.array([[0, 1, 2], [1, 0, 3], [2, 3, 0]]) +dm = DistanceMatrix(data, ids=['A', 'B', 'C']) + +# Access distances +dist_ab = dm['A', 'B'] +row_a = dm['A'] + +# Read from file +dm = DistanceMatrix.read('distances.txt') + +# Use in downstream analyses +pcoa_results = pcoa(dm) +permanova_results = permanova(dm, grouping) +``` + +**Important notes:** +- DistanceMatrix enforces symmetry and zero diagonal +- DissimilarityMatrix allows asymmetric values +- IDs enable integration with metadata and biological knowledge +- Compatible with pandas, numpy, and scikit-learn + +### 9. Biological Tables + +Work with feature tables (OTU/ASV tables) common in microbiome research. 
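+
+A recurring chore when combining feature tables, sample metadata, and distance matrices is keeping sample IDs aligned across objects. A small sketch of one way to do it with pandas plus `DistanceMatrix.filter` (the metadata column and IDs are made up):
+
+```python
+import pandas as pd
+from skbio import DistanceMatrix
+
+metadata = pd.DataFrame({'group': ['A', 'A', 'B']}, index=['S1', 'S2', 'S3'])
+dm = DistanceMatrix([[0.0, 0.2, 0.7],
+                     [0.2, 0.0, 0.6],
+                     [0.7, 0.6, 0.0]],
+                    ids=['S1', 'S2', 'S3'])
+
+# Keep only samples present in both objects, preserving distance-matrix order
+shared = [s for s in dm.ids if s in metadata.index]
+dm_shared = dm.filter(shared)
+grouping = metadata.loc[list(dm_shared.ids), 'group']  # ready for permanova/anosim
+```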
+ +**Key capabilities:** +- BIOM format I/O (HDF5 and JSON) +- Integration with pandas, polars, AnnData, numpy +- Data augmentation techniques (phylomix, mixup, compositional methods) +- Sample/feature filtering and normalization +- Metadata integration + +**Common patterns:** +```python +from skbio import Table + +# Read BIOM table +table = Table.read('table.biom') + +# Access data +sample_ids = table.ids(axis='sample') +feature_ids = table.ids(axis='observation') +counts = table.matrix_data + +# Filter +filtered = table.filter(sample_ids_to_keep, axis='sample') + +# Convert to/from pandas +df = table.to_dataframe() +table = Table.from_dataframe(df) +``` + +**Important notes:** +- BIOM tables are standard in QIIME 2 workflows +- Rows typically represent samples, columns represent features (OTUs/ASVs) +- Supports sparse and dense representations +- Output format configurable (pandas/polars/numpy) + +### 10. Protein Embeddings + +Work with protein language model embeddings for downstream analysis. + +**Key capabilities:** +- Store embeddings from protein language models (ESM, ProtTrans, etc.) +- Convert embeddings to distance matrices +- Generate ordination objects for visualization +- Export to numpy/pandas for ML workflows + +**Common patterns:** +```python +from skbio.embedding import ProteinEmbedding, ProteinVector + +# Create embedding from array +embedding = ProteinEmbedding(embedding_array, sequence_ids) + +# Convert to distance matrix for analysis +dm = embedding.to_distances(metric='euclidean') + +# PCoA visualization of embedding space +pcoa_results = embedding.to_ordination(metric='euclidean', method='pcoa') + +# Export for machine learning +array = embedding.to_array() +df = embedding.to_dataframe() +``` + +**Important notes:** +- Embeddings bridge protein language models with traditional bioinformatics +- Compatible with scikit-bio's distance/ordination/statistics ecosystem +- SequenceEmbedding and ProteinEmbedding provide specialized functionality +- Useful for sequence clustering, classification, and visualization + +## Best Practices + +### Installation +```bash +pip install scikit-bio +# Or with conda: +conda install -c conda-forge scikit-bio +``` + +### Performance Considerations +- Use generators for large sequence files to minimize memory usage +- For massive phylogenetic trees, prefer GME or BME over NJ +- Beta diversity calculations can be parallelized with `partial_beta_diversity()` +- BIOM format (HDF5) more efficient than JSON for large tables + +### Integration with Ecosystem +- Sequences interoperate with Biopython via standard formats +- Tables integrate with pandas, polars, and AnnData +- Distance matrices compatible with scikit-learn +- Ordination results visualizable with matplotlib/seaborn/plotly +- Works seamlessly with QIIME 2 artifacts (BIOM, trees, distance matrices) + +### Common Workflows +1. **Microbiome diversity analysis**: Read BIOM table → Calculate alpha/beta diversity → Ordination (PCoA) → Statistical testing (PERMANOVA) +2. **Phylogenetic analysis**: Read sequences → Align → Build distance matrix → Construct tree → Calculate phylogenetic distances +3. **Sequence processing**: Read FASTQ → Quality filter → Trim/clean → Find motifs → Translate → Write FASTA +4. 
**Comparative genomics**: Read sequences → Pairwise alignment → Calculate distances → Build tree → Analyze clades + +## Reference Documentation + +For detailed API information, parameter specifications, and advanced usage examples, refer to `references/api_reference.md` which contains comprehensive documentation on: +- Complete method signatures and parameters for all capabilities +- Extended code examples for complex workflows +- Troubleshooting common issues +- Performance optimization tips +- Integration patterns with other libraries + +## Additional Resources + +- Official documentation: https://scikit.bio/docs/latest/ +- GitHub repository: https://github.com/scikit-bio/scikit-bio +- Forum support: https://forum.qiime2.org (scikit-bio is part of QIIME 2 ecosystem) diff --git a/scientific-packages/scikit-bio/references/api_reference.md b/scientific-packages/scikit-bio/references/api_reference.md new file mode 100644 index 0000000..dbd95bb --- /dev/null +++ b/scientific-packages/scikit-bio/references/api_reference.md @@ -0,0 +1,749 @@ +# scikit-bio API Reference + +This document provides detailed API information, advanced examples, and troubleshooting guidance for working with scikit-bio. + +## Table of Contents +1. [Sequence Classes](#sequence-classes) +2. [Alignment Methods](#alignment-methods) +3. [Phylogenetic Trees](#phylogenetic-trees) +4. [Diversity Metrics](#diversity-metrics) +5. [Ordination](#ordination) +6. [Statistical Tests](#statistical-tests) +7. [Distance Matrices](#distance-matrices) +8. [File I/O](#file-io) +9. [Troubleshooting](#troubleshooting) + +## Sequence Classes + +### DNA, RNA, and Protein Classes + +```python +from skbio import DNA, RNA, Protein, Sequence + +# Creating sequences +dna = DNA('ATCGATCG', metadata={'id': 'seq1', 'description': 'Example'}) +rna = RNA('AUCGAUCG') +protein = Protein('ACDEFGHIKLMNPQRSTVWY') + +# Sequence operations +dna_rc = dna.reverse_complement() # Reverse complement +rna = dna.transcribe() # DNA -> RNA +protein = rna.translate() # RNA -> Protein + +# Using genetic code tables +protein = rna.translate(genetic_code=11) # Bacterial code +``` + +### Sequence Searching and Pattern Matching + +```python +# Find motifs using regex +dna = DNA('ATGCGATCGATGCATCG') +motif_locs = dna.find_with_regex('ATG.{3}') # Start codons + +# Find all positions +import re +for match in re.finditer('ATG', str(dna)): + print(f"ATG found at position {match.start()}") + +# k-mer counting +from skbio.sequence import _motifs +kmers = dna.kmer_frequencies(k=3) +``` + +### Handling Sequence Metadata + +```python +# Sequence-level metadata +dna = DNA('ATCG', metadata={'id': 'seq1', 'source': 'E. 
coli'}) +print(dna.metadata['id']) + +# Positional metadata (per-base quality scores from FASTQ) +from skbio import DNA +seqs = DNA.read('reads.fastq', format='fastq', phred_offset=33) +quality_scores = seqs.positional_metadata['quality'] + +# Interval metadata (features/annotations) +dna.interval_metadata.add([(5, 15)], metadata={'type': 'gene', 'name': 'geneA'}) +``` + +### Distance Calculations + +```python +from skbio import DNA + +seq1 = DNA('ATCGATCG') +seq2 = DNA('ATCG--CG') + +# Hamming distance (default) +dist = seq1.distance(seq2) + +# Custom distance function +from skbio.sequence.distance import kmer_distance +dist = seq1.distance(seq2, metric=kmer_distance) +``` + +## Alignment Methods + +### Pairwise Alignment + +```python +from skbio.alignment import local_pairwise_align_ssw, global_pairwise_align +from skbio import DNA, Protein + +# Local alignment (Smith-Waterman via SSW) +seq1 = DNA('ATCGATCGATCG') +seq2 = DNA('ATCGGGGATCG') +alignment = local_pairwise_align_ssw(seq1, seq2) + +# Access alignment details +print(f"Score: {alignment.score}") +print(f"Start position: {alignment.target_begin}") +aligned_seqs = alignment.aligned_sequences + +# Global alignment with custom scoring +from skbio.alignment import AlignScorer + +scorer = AlignScorer( + match_score=2, + mismatch_score=-3, + gap_open_penalty=5, + gap_extend_penalty=2 +) + +alignment = global_pairwise_align(seq1, seq2, scorer=scorer) + +# Protein alignment with substitution matrix +from skbio.alignment import StripedSmithWaterman + +protein_query = Protein('ACDEFGHIKLMNPQRSTVWY') +protein_target = Protein('ACDEFMNPQRSTVWY') + +aligner = StripedSmithWaterman( + str(protein_query), + gap_open_penalty=11, + gap_extend_penalty=1, + substitution_matrix='blosum62' +) +alignment = aligner(str(protein_target)) +``` + +### Multiple Sequence Alignment + +```python +from skbio.alignment import TabularMSA +from skbio import DNA + +# Read MSA from file +msa = TabularMSA.read('alignment.fasta', constructor=DNA) + +# Create MSA manually +seqs = [ + DNA('ATCG--'), + DNA('ATGG--'), + DNA('ATCGAT') +] +msa = TabularMSA(seqs) + +# MSA operations +consensus = msa.consensus() +majority_consensus = msa.majority_consensus() + +# Calculate conservation +conservation = msa.conservation() + +# Access sequences +first_seq = msa[0] +column = msa[:, 2] # Third column + +# Filter gaps +degapped_msa = msa.omit_gap_positions(maximum_gap_frequency=0.5) + +# Calculate position-specific scores +position_entropies = msa.position_entropies() +``` + +### CIGAR String Handling + +```python +from skbio.alignment import AlignPath + +# Parse CIGAR string +cigar = "10M2I5M3D10M" +align_path = AlignPath.from_cigar(cigar, target_length=100, query_length=50) + +# Convert alignment to CIGAR +alignment = local_pairwise_align_ssw(seq1, seq2) +cigar_string = alignment.to_cigar() +``` + +## Phylogenetic Trees + +### Tree Construction + +```python +from skbio import TreeNode, DistanceMatrix +from skbio.tree import nj, upgma + +# Distance matrix +dm = DistanceMatrix([[0, 5, 9, 9], + [5, 0, 10, 10], + [9, 10, 0, 8], + [9, 10, 8, 0]], + ids=['A', 'B', 'C', 'D']) + +# Neighbor joining +nj_tree = nj(dm) + +# UPGMA (assumes molecular clock) +upgma_tree = upgma(dm) + +# Balanced Minimum Evolution (scalable for large trees) +from skbio.tree import bme +bme_tree = bme(dm) +``` + +### Tree Manipulation + +```python +from skbio import TreeNode + +# Read tree +tree = TreeNode.read('tree.nwk', format='newick') + +# Traversal +for node in tree.traverse(): + print(node.name) + +# 
Preorder, postorder, levelorder +for node in tree.preorder(): + print(node.name) + +# Get tips only +tips = list(tree.tips()) + +# Find specific node +node = tree.find('taxon_name') + +# Root tree at midpoint +rooted_tree = tree.root_at_midpoint() + +# Prune tree to specific taxa +pruned = tree.shear(['taxon1', 'taxon2', 'taxon3']) + +# Get subtree +lca = tree.lowest_common_ancestor(['taxon1', 'taxon2']) +subtree = lca.copy() + +# Add/remove nodes +parent = tree.find('parent_name') +child = TreeNode(name='new_child', length=0.5) +parent.append(child) + +# Remove node +node_to_remove = tree.find('taxon_to_remove') +node_to_remove.parent.remove(node_to_remove) +``` + +### Tree Distances and Comparisons + +```python +# Patristic distance (branch-length distance) +node1 = tree.find('taxon1') +node2 = tree.find('taxon2') +patristic = node1.distance(node2) + +# Cophenetic matrix (all pairwise distances) +cophenetic_dm = tree.cophenetic_matrix() + +# Robinson-Foulds distance (topology comparison) +rf_dist = tree.robinson_foulds(other_tree) + +# Compare with unweighted RF +rf_dist, max_rf = tree.robinson_foulds(other_tree, proportion=False) + +# Tip-to-tip distances +tip_distances = tree.tip_tip_distances() +``` + +### Tree Visualization + +```python +# ASCII art visualization +print(tree.ascii_art()) + +# For advanced visualization, export to external tools +tree.write('tree.nwk', format='newick') + +# Then use ete3, toytree, or ggtree for publication-quality figures +``` + +## Diversity Metrics + +### Alpha Diversity + +```python +from skbio.diversity import alpha_diversity, get_alpha_diversity_metrics +import numpy as np + +# Sample count data (samples x features) +counts = np.array([ + [10, 5, 0, 3], + [2, 0, 8, 4], + [5, 5, 5, 5] +]) +sample_ids = ['Sample1', 'Sample2', 'Sample3'] + +# List available metrics +print(get_alpha_diversity_metrics()) + +# Calculate various alpha diversity metrics +shannon = alpha_diversity('shannon', counts, ids=sample_ids) +simpson = alpha_diversity('simpson', counts, ids=sample_ids) +observed_otus = alpha_diversity('observed_otus', counts, ids=sample_ids) +chao1 = alpha_diversity('chao1', counts, ids=sample_ids) + +# Phylogenetic alpha diversity (requires tree) +from skbio import TreeNode + +tree = TreeNode.read('tree.nwk') +feature_ids = ['OTU1', 'OTU2', 'OTU3', 'OTU4'] + +faith_pd = alpha_diversity('faith_pd', counts, ids=sample_ids, + tree=tree, otu_ids=feature_ids) +``` + +### Beta Diversity + +```python +from skbio.diversity import beta_diversity, partial_beta_diversity + +# Beta diversity (all pairwise comparisons) +bc_dm = beta_diversity('braycurtis', counts, ids=sample_ids) + +# Jaccard (presence/absence) +jaccard_dm = beta_diversity('jaccard', counts, ids=sample_ids) + +# Phylogenetic beta diversity +unifrac_dm = beta_diversity('unweighted_unifrac', counts, + ids=sample_ids, + tree=tree, + otu_ids=feature_ids) + +weighted_unifrac_dm = beta_diversity('weighted_unifrac', counts, + ids=sample_ids, + tree=tree, + otu_ids=feature_ids) + +# Compute only specific pairs (more efficient) +pairs = [('Sample1', 'Sample2'), ('Sample1', 'Sample3')] +partial_dm = partial_beta_diversity('braycurtis', counts, + ids=sample_ids, + id_pairs=pairs) +``` + +### Rarefaction and Subsampling + +```python +from skbio.diversity import subsample_counts + +# Rarefy to minimum depth +min_depth = counts.min(axis=1).max() +rarefied = [subsample_counts(row, n=min_depth) for row in counts] + +# Multiple rarefactions for confidence intervals +import numpy as np +rarefactions = [] +for i 
in range(100): + rarefied_counts = np.array([subsample_counts(row, n=1000) for row in counts]) + shannon_rare = alpha_diversity('shannon', rarefied_counts) + rarefactions.append(shannon_rare) + +# Calculate mean and std +mean_shannon = np.mean(rarefactions, axis=0) +std_shannon = np.std(rarefactions, axis=0) +``` + +## Ordination + +### Principal Coordinate Analysis (PCoA) + +```python +from skbio.stats.ordination import pcoa +from skbio import DistanceMatrix +import numpy as np + +# PCoA from distance matrix +dm = DistanceMatrix(...) +pcoa_results = pcoa(dm) + +# Access coordinates +pc1 = pcoa_results.samples['PC1'] +pc2 = pcoa_results.samples['PC2'] + +# Proportion explained +prop_explained = pcoa_results.proportion_explained + +# Eigenvalues +eigenvalues = pcoa_results.eigvals + +# Save results +pcoa_results.write('pcoa_results.txt') + +# Plot with matplotlib +import matplotlib.pyplot as plt +plt.scatter(pc1, pc2) +plt.xlabel(f'PC1 ({prop_explained[0]*100:.1f}%)') +plt.ylabel(f'PC2 ({prop_explained[1]*100:.1f}%)') +``` + +### Canonical Correspondence Analysis (CCA) + +```python +from skbio.stats.ordination import cca +import pandas as pd +import numpy as np + +# Species abundance matrix (samples x species) +species = np.array([ + [10, 5, 3], + [2, 8, 4], + [5, 5, 5] +]) + +# Environmental variables (samples x variables) +env = pd.DataFrame({ + 'pH': [6.5, 7.0, 6.8], + 'temperature': [20, 25, 22], + 'depth': [10, 15, 12] +}) + +# CCA +cca_results = cca(species, env, + sample_ids=['Site1', 'Site2', 'Site3'], + species_ids=['SpeciesA', 'SpeciesB', 'SpeciesC']) + +# Access constrained axes +cca1 = cca_results.samples['CCA1'] +cca2 = cca_results.samples['CCA2'] + +# Biplot scores for environmental variables +env_scores = cca_results.biplot_scores +``` + +### Redundancy Analysis (RDA) + +```python +from skbio.stats.ordination import rda + +# Similar to CCA but for linear relationships +rda_results = rda(species, env, + sample_ids=['Site1', 'Site2', 'Site3'], + species_ids=['SpeciesA', 'SpeciesB', 'SpeciesC']) +``` + +## Statistical Tests + +### PERMANOVA + +```python +from skbio.stats.distance import permanova +from skbio import DistanceMatrix +import numpy as np + +# Distance matrix +dm = DistanceMatrix(...) + +# Grouping variable +grouping = ['Group1', 'Group1', 'Group2', 'Group2', 'Group3', 'Group3'] + +# Run PERMANOVA +results = permanova(dm, grouping, permutations=999) + +print(f"Test statistic: {results['test statistic']}") +print(f"p-value: {results['p-value']}") +print(f"Sample size: {results['sample size']}") +print(f"Number of groups: {results['number of groups']}") +``` + +### ANOSIM + +```python +from skbio.stats.distance import anosim + +# ANOSIM test +results = anosim(dm, grouping, permutations=999) + +print(f"R statistic: {results['test statistic']}") +print(f"p-value: {results['p-value']}") +``` + +### PERMDISP + +```python +from skbio.stats.distance import permdisp + +# Test homogeneity of dispersions +results = permdisp(dm, grouping, permutations=999) + +print(f"F statistic: {results['test statistic']}") +print(f"p-value: {results['p-value']}") +``` + +### Mantel Test + +```python +from skbio.stats.distance import mantel +from skbio import DistanceMatrix + +# Two distance matrices to compare +dm1 = DistanceMatrix(...) # e.g., genetic distance +dm2 = DistanceMatrix(...) 
# e.g., geographic distance + +# Mantel test +r, p_value, n = mantel(dm1, dm2, method='pearson', permutations=999) + +print(f"Correlation: {r}") +print(f"p-value: {p_value}") +print(f"Sample size: {n}") + +# Spearman correlation +r_spearman, p, n = mantel(dm1, dm2, method='spearman', permutations=999) +``` + +### Partial Mantel Test + +```python +from skbio.stats.distance import mantel + +# Control for a third matrix +dm3 = DistanceMatrix(...) # controlling variable + +r_partial, p_value, n = mantel(dm1, dm2, method='pearson', + permutations=999, alternative='two-sided') +``` + +## Distance Matrices + +### Creating and Manipulating Distance Matrices + +```python +from skbio import DistanceMatrix, DissimilarityMatrix +import numpy as np + +# Create from array +data = np.array([[0, 1, 2], + [1, 0, 3], + [2, 3, 0]]) +dm = DistanceMatrix(data, ids=['A', 'B', 'C']) + +# Access elements +dist_ab = dm['A', 'B'] +row_a = dm['A'] + +# Slicing +subset_dm = dm.filter(['A', 'C']) + +# Asymmetric dissimilarity matrix +asym_data = np.array([[0, 1, 2], + [3, 0, 4], + [5, 6, 0]]) +dissim = DissimilarityMatrix(asym_data, ids=['X', 'Y', 'Z']) + +# Read/write +dm.write('distances.txt') +dm2 = DistanceMatrix.read('distances.txt') + +# Convert to condensed form (for scipy) +condensed = dm.condensed_form() + +# Convert to dataframe +df = dm.to_data_frame() +``` + +## File I/O + +### Reading Sequences + +```python +import skbio + +# Read single sequence +dna = skbio.DNA.read('sequence.fasta', format='fasta') + +# Read multiple sequences (generator) +for seq in skbio.io.read('sequences.fasta', format='fasta', constructor=skbio.DNA): + print(seq.metadata['id'], len(seq)) + +# Read into list +sequences = list(skbio.io.read('sequences.fasta', format='fasta', + constructor=skbio.DNA)) + +# Read FASTQ with quality scores +for seq in skbio.io.read('reads.fastq', format='fastq', constructor=skbio.DNA): + quality = seq.positional_metadata['quality'] + print(f"Mean quality: {quality.mean()}") +``` + +### Writing Sequences + +```python +# Write single sequence +dna.write('output.fasta', format='fasta') + +# Write multiple sequences +sequences = [dna1, dna2, dna3] +skbio.io.write(sequences, format='fasta', into='output.fasta') + +# Write with custom line wrapping +dna.write('output.fasta', format='fasta', max_width=60) +``` + +### BIOM Tables + +```python +from skbio import Table + +# Read BIOM table +table = Table.read('table.biom', format='hdf5') + +# Access data +sample_ids = table.ids(axis='sample') +feature_ids = table.ids(axis='observation') +matrix = table.matrix_data.toarray() # if sparse + +# Filter samples +abundant_samples = table.filter(lambda row, id_, md: row.sum() > 1000, axis='sample') + +# Filter features (OTUs/ASVs) +prevalent_features = table.filter(lambda col, id_, md: (col > 0).sum() >= 3, + axis='observation') + +# Normalize +relative_abundance = table.norm(axis='sample', inplace=False) + +# Write +table.write('filtered_table.biom', format='hdf5') +``` + +### Format Conversion + +```python +# FASTQ to FASTA +seqs = skbio.io.read('input.fastq', format='fastq', constructor=skbio.DNA) +skbio.io.write(seqs, format='fasta', into='output.fasta') + +# GenBank to FASTA +seqs = skbio.io.read('genes.gb', format='genbank', constructor=skbio.DNA) +skbio.io.write(seqs, format='fasta', into='genes.fasta') +``` + +## Troubleshooting + +### Common Issues and Solutions + +#### Issue: "ValueError: Ids must be unique" +```python +# Problem: Duplicate sequence IDs +# Solution: Make IDs unique or filter duplicates +seen = 
set() +unique_seqs = [] +for seq in sequences: + if seq.metadata['id'] not in seen: + unique_seqs.append(seq) + seen.add(seq.metadata['id']) +``` + +#### Issue: "ValueError: Counts must be integers" +```python +# Problem: Relative abundances instead of counts +# Solution: Convert to integer counts or use appropriate metrics +counts_int = (abundance_table * 1000).astype(int) +``` + +#### Issue: Memory error with large files +```python +# Problem: Loading entire file into memory +# Solution: Use generators +for seq in skbio.io.read('huge.fasta', format='fasta', constructor=skbio.DNA): + # Process one at a time + process(seq) +``` + +#### Issue: Tree tips don't match OTU IDs +```python +# Problem: Mismatch between tree tip names and feature IDs +# Solution: Verify and align IDs +tree_tips = {tip.name for tip in tree.tips()} +feature_ids = set(feature_ids) +missing_in_tree = feature_ids - tree_tips +missing_in_table = tree_tips - feature_ids + +# Prune tree to match table +tree_pruned = tree.shear(feature_ids) +``` + +#### Issue: Alignment fails with sequences of different lengths +```python +# Problem: Trying to align pre-aligned sequences +# Solution: Degap sequences first or ensure sequences are unaligned +seq1_degapped = seq1.degap() +seq2_degapped = seq2.degap() +alignment = local_pairwise_align_ssw(seq1_degapped, seq2_degapped) +``` + +### Performance Tips + +1. **Use appropriate data structures**: BIOM HDF5 for large tables, generators for large sequence files +2. **Parallel processing**: Use `partial_beta_diversity()` for subset calculations that can be parallelized +3. **Subsample large datasets**: For exploratory analysis, work with subsampled data first +4. **Cache results**: Save distance matrices and ordination results to avoid recomputation + +### Integration Examples + +#### With pandas +```python +import pandas as pd +from skbio import DistanceMatrix + +# Distance matrix to DataFrame +dm = DistanceMatrix(...) +df = dm.to_data_frame() + +# Alpha diversity to DataFrame +alpha = alpha_diversity('shannon', counts, ids=sample_ids) +alpha_df = pd.DataFrame({'shannon': alpha}) +``` + +#### With matplotlib/seaborn +```python +import matplotlib.pyplot as plt +import seaborn as sns + +# PCoA plot +fig, ax = plt.subplots() +scatter = ax.scatter(pc1, pc2, c=grouping, cmap='viridis') +ax.set_xlabel(f'PC1 ({prop_explained[0]*100:.1f}%)') +ax.set_ylabel(f'PC2 ({prop_explained[1]*100:.1f}%)') +plt.colorbar(scatter) + +# Heatmap of distance matrix +sns.heatmap(dm.to_data_frame(), cmap='viridis') +``` + +#### With QIIME 2 +```python +# scikit-bio objects are compatible with QIIME 2 +# Export from QIIME 2 +# qiime tools export --input-path table.qza --output-path exported/ + +# Read in scikit-bio +table = Table.read('exported/feature-table.biom') + +# Process with scikit-bio +# ... + +# Import back to QIIME 2 if needed +table.write('processed-table.biom') +# qiime tools import --input-path processed-table.biom --output-path processed.qza +``` diff --git a/scientific-packages/scikit-learn/SKILL.md b/scientific-packages/scikit-learn/SKILL.md new file mode 100644 index 0000000..855eeb2 --- /dev/null +++ b/scientific-packages/scikit-learn/SKILL.md @@ -0,0 +1,780 @@ +--- +name: scikit-learn +description: Comprehensive guide for scikit-learn, Python's machine learning library. 
This skill should be used when building classification or regression models, performing clustering analysis, reducing dimensionality, preprocessing data (scaling, encoding, imputation), evaluating models with cross-validation and metrics, tuning hyperparameters, creating ML pipelines, detecting anomalies, or implementing any supervised or unsupervised learning tasks. Provides algorithm selection guidance, best practices for preventing data leakage, handling imbalanced data, and working with mixed data types. +--- + +# Scikit-learn: Machine Learning in Python + +## Overview + +This skill provides comprehensive guidance for using scikit-learn, Python's premier machine learning library. Scikit-learn offers simple, efficient tools for predictive data analysis, including classification, regression, clustering, dimensionality reduction, model selection, and preprocessing. This skill should be used when implementing machine learning workflows, building predictive models, analyzing datasets using supervised or unsupervised learning, preprocessing data for ML tasks, evaluating model performance, or optimizing hyperparameters. + +## When to Use This Skill + +Activate this skill when: +- Building classification models (spam detection, image recognition, medical diagnosis) +- Creating regression models (price prediction, forecasting, trend analysis) +- Performing clustering analysis (customer segmentation, pattern discovery) +- Reducing dimensionality (PCA, t-SNE for visualization) +- Preprocessing data (scaling, encoding, imputation) +- Evaluating model performance (cross-validation, metrics) +- Tuning hyperparameters (grid search, random search) +- Creating machine learning pipelines +- Detecting anomalies or outliers +- Implementing ensemble methods + +## Core Machine Learning Workflow + +### Standard ML Pipeline + +Follow this general workflow for supervised learning tasks: + +1. **Data Preparation** + - Load and explore data + - Split into train/test sets + - Handle missing values + - Encode categorical features + - Scale/normalize features + +2. **Model Selection** + - Start with baseline model + - Try more complex models + - Use domain knowledge to guide selection + +3. **Model Training** + - Fit model on training data + - Use pipelines to prevent data leakage + - Apply cross-validation + +4. **Model Evaluation** + - Evaluate on test set + - Use appropriate metrics + - Analyze errors + +5. **Model Optimization** + - Tune hyperparameters + - Feature engineering + - Ensemble methods + +6. 
**Deployment** + - Save model using joblib + - Create prediction pipeline + - Monitor performance + +### Classification Quick Start + +```python +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import classification_report +from sklearn.pipeline import Pipeline + +# Create pipeline (prevents data leakage) +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('classifier', RandomForestClassifier(n_estimators=100, random_state=42)) +]) + +# Split data (use stratify for imbalanced classes) +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) + +# Train +pipeline.fit(X_train, y_train) + +# Evaluate +y_pred = pipeline.predict(X_test) +print(classification_report(y_test, y_pred)) + +# Cross-validation for robust evaluation +from sklearn.model_selection import cross_val_score +scores = cross_val_score(pipeline, X_train, y_train, cv=5) +print(f"CV Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +### Regression Quick Start + +```python +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.pipeline import Pipeline + +# Create pipeline +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('regressor', RandomForestRegressor(n_estimators=100, random_state=42)) +]) + +# Split data +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +# Train +pipeline.fit(X_train, y_train) + +# Evaluate +y_pred = pipeline.predict(X_test) +rmse = mean_squared_error(y_test, y_pred, squared=False) +r2 = r2_score(y_test, y_pred) +print(f"RMSE: {rmse:.3f}, R²: {r2:.3f}") +``` + +## Algorithm Selection Guide + +### Classification Algorithms + +**Start with baseline**: LogisticRegression +- Fast, interpretable, works well for linearly separable data +- Good for high-dimensional data (text classification) + +**General-purpose**: RandomForestClassifier +- Handles non-linear relationships +- Robust to outliers +- Provides feature importance +- Good default choice + +**Best performance**: HistGradientBoostingClassifier +- State-of-the-art for tabular data +- Fast on large datasets (>10K samples) +- Often wins Kaggle competitions + +**Special cases**: +- **Small datasets (<1K)**: SVC with RBF kernel +- **Very large datasets (>100K)**: SGDClassifier or LinearSVC +- **Interpretability critical**: LogisticRegression or DecisionTreeClassifier +- **Probabilistic predictions**: GaussianNB or calibrated models +- **Text classification**: LogisticRegression with TfidfVectorizer + +### Regression Algorithms + +**Start with baseline**: LinearRegression or Ridge +- Fast, interpretable +- Works well when relationships are linear + +**General-purpose**: RandomForestRegressor +- Handles non-linear relationships +- Robust to outliers +- Good default choice + +**Best performance**: HistGradientBoostingRegressor +- State-of-the-art for tabular data +- Fast on large datasets + +**Special cases**: +- **Regularization needed**: Ridge (L2) or Lasso (L1 + feature selection) +- **Very large datasets**: SGDRegressor +- **Outliers present**: HuberRegressor or RANSAC + +### Clustering Algorithms + +**Known number of clusters**: KMeans +- Fast and scalable +- Assumes spherical clusters + +**Unknown number of clusters**: DBSCAN or HDBSCAN 
+- Handles arbitrary shapes +- Automatic outlier detection + +**Hierarchical relationships**: AgglomerativeClustering +- Creates hierarchy of clusters +- Good for visualization (dendrograms) + +**Soft clustering (probabilities)**: GaussianMixture +- Provides cluster probabilities +- Handles elliptical clusters + +### Dimensionality Reduction + +**Preprocessing/feature extraction**: PCA +- Fast and efficient +- Linear transformation +- ALWAYS standardize first + +**Visualization only**: t-SNE or UMAP +- Preserves local structure +- Non-linear +- DO NOT use for preprocessing + +**Sparse data (text)**: TruncatedSVD +- Works with sparse matrices +- Latent Semantic Analysis + +**Non-negative data**: NMF +- Interpretable components +- Topic modeling + +## Working with Different Data Types + +### Numeric Features + +**Continuous features**: +1. Check distribution +2. Handle outliers (remove, clip, or use RobustScaler) +3. Scale using StandardScaler (most algorithms) or MinMaxScaler (neural networks) + +**Count data**: +1. Consider log transformation or sqrt +2. Scale after transformation + +**Skewed data**: +1. Use PowerTransformer (Yeo-Johnson or Box-Cox) +2. Or QuantileTransformer for stronger normalization + +### Categorical Features + +**Low cardinality (<10 categories)**: +```python +from sklearn.preprocessing import OneHotEncoder +encoder = OneHotEncoder(drop='first', sparse_output=True) +``` + +**High cardinality (>10 categories)**: +```python +from sklearn.preprocessing import TargetEncoder +encoder = TargetEncoder() +# Uses target statistics, prevents leakage with cross-fitting +``` + +**Ordinal relationships**: +```python +from sklearn.preprocessing import OrdinalEncoder +encoder = OrdinalEncoder(categories=[['small', 'medium', 'large']]) +``` + +### Text Data + +```python +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.naive_bayes import MultinomialNB +from sklearn.pipeline import Pipeline + +text_pipeline = Pipeline([ + ('tfidf', TfidfVectorizer(max_features=1000, stop_words='english')), + ('classifier', MultinomialNB()) +]) + +text_pipeline.fit(X_train_text, y_train) +``` + +### Mixed Data Types + +```python +from sklearn.compose import ColumnTransformer +from sklearn.preprocessing import StandardScaler, OneHotEncoder +from sklearn.impute import SimpleImputer +from sklearn.pipeline import Pipeline + +# Define feature types +numeric_features = ['age', 'income', 'credit_score'] +categorical_features = ['country', 'occupation'] + +# Separate preprocessing pipelines +numeric_transformer = Pipeline([ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()) +]) + +categorical_transformer = Pipeline([ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=True)) +]) + +# Combine with ColumnTransformer +preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, numeric_features), + ('cat', categorical_transformer, categorical_features) + ]) + +# Complete pipeline +from sklearn.ensemble import RandomForestClassifier +pipeline = Pipeline([ + ('preprocessor', preprocessor), + ('classifier', RandomForestClassifier()) +]) + +pipeline.fit(X_train, y_train) +``` + +## Model Evaluation + +### Classification Metrics + +**Balanced datasets**: Use accuracy or F1-score + +**Imbalanced datasets**: Use balanced_accuracy, F1-weighted, or ROC-AUC +```python +from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score + +balanced_acc = 
balanced_accuracy_score(y_true, y_pred) +f1 = f1_score(y_true, y_pred, average='weighted') + +# ROC-AUC requires probabilities +y_proba = model.predict_proba(X_test) +auc = roc_auc_score(y_true, y_proba, multi_class='ovr') +``` + +**Cost-sensitive**: Define custom scorer or adjust decision threshold + +**Comprehensive report**: +```python +from sklearn.metrics import classification_report, confusion_matrix + +print(classification_report(y_true, y_pred)) +print(confusion_matrix(y_true, y_pred)) +``` + +### Regression Metrics + +**Standard use**: RMSE and R² +```python +from sklearn.metrics import mean_squared_error, r2_score + +rmse = mean_squared_error(y_true, y_pred, squared=False) +r2 = r2_score(y_true, y_pred) +``` + +**Outliers present**: Use MAE (robust to outliers) +```python +from sklearn.metrics import mean_absolute_error +mae = mean_absolute_error(y_true, y_pred) +``` + +**Percentage errors matter**: Use MAPE +```python +from sklearn.metrics import mean_absolute_percentage_error +mape = mean_absolute_percentage_error(y_true, y_pred) +``` + +### Cross-Validation + +**Standard approach** (5-10 folds): +```python +from sklearn.model_selection import cross_val_score + +scores = cross_val_score(model, X, y, cv=5, scoring='accuracy') +print(f"CV Score: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +**Imbalanced classes** (use stratification): +```python +from sklearn.model_selection import StratifiedKFold + +cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) +scores = cross_val_score(model, X, y, cv=cv) +``` + +**Time series** (respect temporal order): +```python +from sklearn.model_selection import TimeSeriesSplit + +cv = TimeSeriesSplit(n_splits=5) +scores = cross_val_score(model, X, y, cv=cv) +``` + +**Multiple metrics**: +```python +from sklearn.model_selection import cross_validate + +scoring = ['accuracy', 'precision_weighted', 'recall_weighted', 'f1_weighted'] +results = cross_validate(model, X, y, cv=5, scoring=scoring) + +for metric in scoring: + scores = results[f'test_{metric}'] + print(f"{metric}: {scores.mean():.3f}") +``` + +## Hyperparameter Tuning + +### Grid Search (Exhaustive) + +```python +from sklearn.model_selection import GridSearchCV + +param_grid = { + 'n_estimators': [100, 200, 500], + 'max_depth': [10, 20, 30, None], + 'min_samples_split': [2, 5, 10] +} + +grid_search = GridSearchCV( + RandomForestClassifier(random_state=42), + param_grid, + cv=5, + scoring='f1_weighted', + n_jobs=-1, # Use all CPU cores + verbose=1 +) + +grid_search.fit(X_train, y_train) + +print(f"Best parameters: {grid_search.best_params_}") +print(f"Best CV score: {grid_search.best_score_:.3f}") + +# Use best model +best_model = grid_search.best_estimator_ +test_score = best_model.score(X_test, y_test) +``` + +### Random Search (Faster) + +```python +from sklearn.model_selection import RandomizedSearchCV +from scipy.stats import randint, uniform + +param_distributions = { + 'n_estimators': randint(100, 1000), + 'max_depth': randint(5, 50), + 'min_samples_split': randint(2, 20), + 'max_features': uniform(0.1, 0.9) +} + +random_search = RandomizedSearchCV( + RandomForestClassifier(random_state=42), + param_distributions, + n_iter=100, # Number of combinations to try + cv=5, + scoring='f1_weighted', + n_jobs=-1, + random_state=42 +) + +random_search.fit(X_train, y_train) +``` + +### Pipeline Hyperparameter Tuning + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC + +pipeline = Pipeline([ + 
('scaler', StandardScaler()), + ('svm', SVC()) +]) + +# Use double underscore for nested parameters +param_grid = { + 'svm__C': [0.1, 1, 10, 100], + 'svm__kernel': ['rbf', 'linear'], + 'svm__gamma': ['scale', 'auto', 0.001, 0.01] +} + +grid_search = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1) +grid_search.fit(X_train, y_train) +``` + +## Feature Engineering and Selection + +### Feature Importance + +```python +# Tree-based models have built-in feature importance +from sklearn.ensemble import RandomForestClassifier + +model = RandomForestClassifier(n_estimators=100) +model.fit(X_train, y_train) + +importances = model.feature_importances_ +feature_importance_df = pd.DataFrame({ + 'feature': feature_names, + 'importance': importances +}).sort_values('importance', ascending=False) + +# Permutation importance (works for any model) +from sklearn.inspection import permutation_importance + +result = permutation_importance( + model, X_test, y_test, + n_repeats=10, + random_state=42, + n_jobs=-1 +) + +importance_df = pd.DataFrame({ + 'feature': feature_names, + 'importance': result.importances_mean, + 'std': result.importances_std +}).sort_values('importance', ascending=False) +``` + +### Feature Selection Methods + +**Univariate selection**: +```python +from sklearn.feature_selection import SelectKBest, f_classif + +selector = SelectKBest(f_classif, k=10) +X_selected = selector.fit_transform(X, y) +selected_features = selector.get_support(indices=True) +``` + +**Recursive Feature Elimination**: +```python +from sklearn.feature_selection import RFECV +from sklearn.ensemble import RandomForestClassifier + +selector = RFECV( + RandomForestClassifier(n_estimators=100), + step=1, + cv=5, + n_jobs=-1 +) +X_selected = selector.fit_transform(X, y) +print(f"Optimal features: {selector.n_features_}") +``` + +**Model-based selection**: +```python +from sklearn.feature_selection import SelectFromModel + +selector = SelectFromModel( + RandomForestClassifier(n_estimators=100), + threshold='median' # or '0.5*mean', or specific value +) +X_selected = selector.fit_transform(X, y) +``` + +### Polynomial Features + +```python +from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import Ridge +from sklearn.pipeline import Pipeline + +pipeline = Pipeline([ + ('poly', PolynomialFeatures(degree=2, include_bias=False)), + ('scaler', StandardScaler()), + ('ridge', Ridge()) +]) + +pipeline.fit(X_train, y_train) +``` + +## Common Patterns and Best Practices + +### Always Use Pipelines + +Pipelines prevent data leakage and ensure proper workflow: + +✅ **Correct**: +```python +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('model', LogisticRegression()) +]) +pipeline.fit(X_train, y_train) +y_pred = pipeline.predict(X_test) +``` + +❌ **Wrong** (data leakage): +```python +scaler = StandardScaler().fit(X) # Fit on all data! 
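+# leak: the scaler's mean/std were computed on all rows, including the
+# future test rows, so any score measured on X_test is optimistically biased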
+X_train, X_test = train_test_split(scaler.transform(X)) +``` + +### Stratify for Imbalanced Classes + +```python +# Always use stratify for classification with imbalanced classes +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, stratify=y, random_state=42 +) +``` + +### Scale When Necessary + +**Scale for**: SVM, Neural Networks, KNN, Linear Models with regularization, PCA, Gradient Descent + +**Don't scale for**: Tree-based models (Random Forest, Gradient Boosting), Naive Bayes + +### Handle Missing Values + +```python +from sklearn.impute import SimpleImputer + +# Numeric: use median (robust to outliers) +imputer = SimpleImputer(strategy='median') + +# Categorical: use constant value or most_frequent +imputer = SimpleImputer(strategy='constant', fill_value='missing') +``` + +### Use Appropriate Metrics + +- **Balanced classification**: accuracy, F1 +- **Imbalanced classification**: balanced_accuracy, F1-weighted, ROC-AUC +- **Regression with outliers**: MAE instead of RMSE +- **Cost-sensitive**: custom scorer + +### Set Random States + +```python +# For reproducibility +model = RandomForestClassifier(random_state=42) +X_train, X_test, y_train, y_test = train_test_split( + X, y, random_state=42 +) +``` + +### Use Parallel Processing + +```python +# Use all CPU cores +model = RandomForestClassifier(n_jobs=-1) +grid_search = GridSearchCV(model, param_grid, n_jobs=-1) +``` + +## Unsupervised Learning + +### Clustering Workflow + +```python +from sklearn.preprocessing import StandardScaler +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score + +# Always scale for clustering +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# Elbow method to find optimal k +inertias = [] +silhouette_scores = [] +K_range = range(2, 11) + +for k in K_range: + kmeans = KMeans(n_clusters=k, random_state=42) + labels = kmeans.fit_predict(X_scaled) + inertias.append(kmeans.inertia_) + silhouette_scores.append(silhouette_score(X_scaled, labels)) + +# Plot and choose k +import matplotlib.pyplot as plt +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) +ax1.plot(K_range, inertias, 'bo-') +ax1.set_xlabel('k') +ax1.set_ylabel('Inertia') +ax2.plot(K_range, silhouette_scores, 'ro-') +ax2.set_xlabel('k') +ax2.set_ylabel('Silhouette Score') +plt.show() + +# Fit final model +optimal_k = 5 # Based on elbow/silhouette +kmeans = KMeans(n_clusters=optimal_k, random_state=42) +labels = kmeans.fit_predict(X_scaled) +``` + +### Dimensionality Reduction + +```python +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +# ALWAYS scale before PCA +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# Specify variance to retain +pca = PCA(n_components=0.95) # Keep 95% of variance +X_pca = pca.fit_transform(X_scaled) + +print(f"Original features: {X.shape[1]}") +print(f"Reduced features: {pca.n_components_}") +print(f"Variance explained: {pca.explained_variance_ratio_.sum():.3f}") + +# Visualize explained variance +import matplotlib.pyplot as plt +plt.plot(np.cumsum(pca.explained_variance_ratio_)) +plt.xlabel('Number of components') +plt.ylabel('Cumulative explained variance') +plt.show() +``` + +### Visualization with t-SNE + +```python +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA + +# Reduce to 50 dimensions with PCA first (faster) +pca = PCA(n_components=min(50, X.shape[1])) +X_pca = pca.fit_transform(X_scaled) + +# Apply t-SNE (only for visualization!) 
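+# perplexity (roughly the effective neighborhood size, typically 5-50) and the
+# random initialization both change the layout; keep random_state fixed for
+# reproducible plots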
+tsne = TSNE(n_components=2, random_state=42, perplexity=30) +X_tsne = tsne.fit_transform(X_pca) + +# Visualize +plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='viridis', alpha=0.6) +plt.colorbar() +plt.title('t-SNE Visualization') +plt.show() +``` + +## Saving and Loading Models + +```python +import joblib + +# Save model or pipeline +joblib.dump(model, 'model.pkl') +joblib.dump(pipeline, 'pipeline.pkl') + +# Load +loaded_model = joblib.load('model.pkl') +loaded_pipeline = joblib.load('pipeline.pkl') + +# Use loaded model +predictions = loaded_model.predict(X_new) +``` + +## Reference Documentation + +This skill includes comprehensive reference files: + +- **`references/supervised_learning.md`**: Detailed coverage of all classification and regression algorithms, parameters, use cases, and selection guidelines +- **`references/preprocessing.md`**: Complete guide to data preprocessing including scaling, encoding, imputation, transformations, and best practices +- **`references/model_evaluation.md`**: In-depth coverage of cross-validation strategies, metrics, hyperparameter tuning, and validation techniques +- **`references/unsupervised_learning.md`**: Comprehensive guide to clustering, dimensionality reduction, anomaly detection, and evaluation methods +- **`references/pipelines_and_composition.md`**: Complete guide to Pipeline, ColumnTransformer, FeatureUnion, custom transformers, and composition patterns +- **`references/quick_reference.md`**: Quick lookup guide with code snippets, common patterns, and decision trees for algorithm selection + +Read these files when: +- Need detailed parameter explanations for specific algorithms +- Comparing multiple algorithms for a task +- Understanding evaluation metrics in depth +- Building complex preprocessing workflows +- Troubleshooting common issues + +Example search patterns: +```python +# To find information about specific algorithms +grep -r "GradientBoosting" references/ + +# To find preprocessing techniques +grep -r "OneHotEncoder" references/preprocessing.md + +# To find evaluation metrics +grep -r "f1_score" references/model_evaluation.md +``` + +## Common Pitfalls to Avoid + +1. **Data leakage**: Always use pipelines, fit only on training data +2. **Not scaling**: Scale for distance-based algorithms (SVM, KNN, Neural Networks) +3. **Wrong metrics**: Use appropriate metrics for imbalanced data +4. **Not using cross-validation**: Single train-test split can be misleading +5. **Forgetting stratification**: Stratify for imbalanced classification +6. **Using t-SNE for preprocessing**: t-SNE is for visualization only! +7. **Not setting random_state**: Results won't be reproducible +8. **Ignoring class imbalance**: Use stratification, appropriate metrics, or resampling +9. **PCA without scaling**: Components will be dominated by high-variance features +10. **Testing on training data**: Always evaluate on held-out test set diff --git a/scientific-packages/scikit-learn/references/model_evaluation.md b/scientific-packages/scikit-learn/references/model_evaluation.md new file mode 100644 index 0000000..5543fcf --- /dev/null +++ b/scientific-packages/scikit-learn/references/model_evaluation.md @@ -0,0 +1,601 @@ +# Model Evaluation and Selection in scikit-learn + +## Overview +Model evaluation assesses how well models generalize to unseen data. Scikit-learn provides three main APIs for evaluation: +1. **Estimator score methods**: Built-in evaluation (accuracy for classifiers, R² for regressors) +2. 
**Scoring parameter**: Used in cross-validation and hyperparameter tuning +3. **Metric functions**: Specialized evaluation in `sklearn.metrics` + +## Cross-Validation + +Cross-validation evaluates model performance by splitting data into multiple train/test sets. This addresses overfitting: "a model that would just repeat the labels of the samples that it has just seen would have a perfect score but would fail to predict anything useful on yet-unseen data." + +### Basic Cross-Validation + +```python +from sklearn.model_selection import cross_val_score +from sklearn.linear_model import LogisticRegression + +model = LogisticRegression() +scores = cross_val_score(model, X, y, cv=5, scoring='accuracy') +print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +### Cross-Validation Strategies + +#### For i.i.d. Data + +**KFold**: Standard k-fold cross-validation +- Splits data into k equal folds +- Each fold used once as test set +- `n_splits`: Number of folds (typically 5 or 10) + +```python +from sklearn.model_selection import KFold +cv = KFold(n_splits=5, shuffle=True, random_state=42) +``` + +**RepeatedKFold**: Repeats KFold with different randomization +- More robust estimation +- Computationally expensive + +**LeaveOneOut (LOO)**: Each sample is a test set +- Maximum training data usage +- Very computationally expensive +- High variance in estimates +- Use only for small datasets (<1000 samples) + +**ShuffleSplit**: Random train/test splits +- Flexible train/test sizes +- Can control number of iterations +- Good for quick evaluation + +```python +from sklearn.model_selection import ShuffleSplit +cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42) +``` + +#### For Imbalanced Classes + +**StratifiedKFold**: Preserves class proportions in each fold +- Essential for imbalanced datasets +- Default for classification in cross_val_score() + +```python +from sklearn.model_selection import StratifiedKFold +cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) +``` + +**StratifiedShuffleSplit**: Stratified random splits + +#### For Grouped Data + +Use when samples are not independent (e.g., multiple measurements from same subject). 
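+
+For example, when each subject contributes several rows, `groups` is simply the per-row subject label. A minimal sketch (the `subject_id` column and toy values are hypothetical) showing that the group-aware splitters below never split a subject across train and test:
+
+```python
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import GroupKFold
+
+# toy data: three measurements per subject
+df = pd.DataFrame({
+    "subject_id": np.repeat(["s1", "s2", "s3", "s4"], 3),
+    "feature": np.random.rand(12),
+    "target": np.random.randint(0, 2, 12),
+})
+
+X, y, groups = df[["feature"]], df["target"], df["subject_id"].to_numpy()
+
+# each subject ends up entirely in the train folds or entirely in the test fold
+for train_idx, test_idx in GroupKFold(n_splits=4).split(X, y, groups=groups):
+    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
+```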
+ +**GroupKFold**: Groups don't appear in both train and test +```python +from sklearn.model_selection import GroupKFold +cv = GroupKFold(n_splits=5) +scores = cross_val_score(model, X, y, groups=groups, cv=cv) +``` + +**StratifiedGroupKFold**: Combines stratification with group separation + +**LeaveOneGroupOut**: Each group becomes a test set + +#### For Time Series + +**TimeSeriesSplit**: Expanding window approach +- Successive training sets are supersets +- Respects temporal ordering +- No data leakage from future to past + +```python +from sklearn.model_selection import TimeSeriesSplit +cv = TimeSeriesSplit(n_splits=5) +for train_idx, test_idx in cv.split(X): + # Train on indices 0 to t, test on t+1 to t+k + pass +``` + +### Cross-Validation Functions + +**cross_val_score**: Returns array of scores +```python +scores = cross_val_score(model, X, y, cv=5, scoring='f1_weighted') +``` + +**cross_validate**: Returns multiple metrics and timing +```python +results = cross_validate( + model, X, y, cv=5, + scoring=['accuracy', 'f1_weighted', 'roc_auc'], + return_train_score=True, + return_estimator=True # Returns fitted estimators +) +print(results['test_accuracy']) +print(results['fit_time']) +``` + +**cross_val_predict**: Returns predictions for model blending/visualization +```python +from sklearn.model_selection import cross_val_predict +y_pred = cross_val_predict(model, X, y, cv=5) +# Use for confusion matrix, error analysis, etc. +``` + +## Hyperparameter Tuning + +### GridSearchCV +Exhaustively searches all parameter combinations. + +```python +from sklearn.model_selection import GridSearchCV +from sklearn.ensemble import RandomForestClassifier + +param_grid = { + 'n_estimators': [100, 200, 500], + 'max_depth': [10, 20, 30, None], + 'min_samples_split': [2, 5, 10], + 'min_samples_leaf': [1, 2, 4] +} + +grid_search = GridSearchCV( + RandomForestClassifier(random_state=42), + param_grid, + cv=5, + scoring='f1_weighted', + n_jobs=-1, # Use all CPU cores + verbose=2 +) + +grid_search.fit(X_train, y_train) +print("Best parameters:", grid_search.best_params_) +print("Best score:", grid_search.best_score_) + +# Use best model +best_model = grid_search.best_estimator_ +``` + +**When to use**: +- Small parameter spaces +- When computational resources allow +- When exhaustive search is desired + +### RandomizedSearchCV +Samples parameter combinations from distributions. + +```python +from sklearn.model_selection import RandomizedSearchCV +from scipy.stats import randint, uniform + +param_distributions = { + 'n_estimators': randint(100, 1000), + 'max_depth': randint(5, 50), + 'min_samples_split': randint(2, 20), + 'min_samples_leaf': randint(1, 10), + 'max_features': uniform(0.1, 0.9) +} + +random_search = RandomizedSearchCV( + RandomForestClassifier(random_state=42), + param_distributions, + n_iter=100, # Number of parameter settings sampled + cv=5, + scoring='f1_weighted', + n_jobs=-1, + random_state=42 +) + +random_search.fit(X_train, y_train) +``` + +**When to use**: +- Large parameter spaces +- When budget is limited +- Often finds good parameters faster than GridSearchCV + +**Advantage**: "Budget can be chosen independent of the number of parameters and possible values" + +### Successive Halving + +**HalvingGridSearchCV** and **HalvingRandomSearchCV**: Tournament-style selection + +**How it works**: +1. Start with many candidates, minimal resources +2. Eliminate poor performers +3. Increase resources for remaining candidates +4. 
Repeat until best candidates found + +**When to use**: +- Large parameter spaces +- Expensive model training +- When many parameter combinations are clearly inferior + +```python +from sklearn.experimental import enable_halving_search_cv +from sklearn.model_selection import HalvingGridSearchCV + +halving_search = HalvingGridSearchCV( + estimator, + param_grid, + factor=3, # Proportion of candidates eliminated each round + cv=5 +) +``` + +## Classification Metrics + +### Accuracy-Based Metrics + +**Accuracy**: Proportion of correct predictions +```python +from sklearn.metrics import accuracy_score +accuracy = accuracy_score(y_true, y_pred) +``` + +**When to use**: Balanced datasets only +**When NOT to use**: Imbalanced datasets (misleading) + +**Balanced Accuracy**: Average recall per class +```python +from sklearn.metrics import balanced_accuracy_score +bal_acc = balanced_accuracy_score(y_true, y_pred) +``` + +**When to use**: Imbalanced datasets, ensures all classes matter equally + +### Precision, Recall, F-Score + +**Precision**: Of predicted positives, how many are actually positive +- Formula: TP / (TP + FP) +- Answers: "How reliable are positive predictions?" + +**Recall** (Sensitivity): Of actual positives, how many are predicted positive +- Formula: TP / (TP + FN) +- Answers: "How complete is positive detection?" + +**F1-Score**: Harmonic mean of precision and recall +- Formula: 2 * (precision * recall) / (precision + recall) +- Balanced measure when both precision and recall are important + +```python +from sklearn.metrics import precision_recall_fscore_support, f1_score + +precision, recall, f1, support = precision_recall_fscore_support( + y_true, y_pred, average='weighted' +) + +# Or individually +f1 = f1_score(y_true, y_pred, average='weighted') +``` + +**Averaging strategies for multiclass**: +- `binary`: Binary classification only +- `micro`: Calculate globally (total TP, FP, FN) +- `macro`: Calculate per class, unweighted mean (all classes equal) +- `weighted`: Calculate per class, weighted by support (class frequency) +- `samples`: For multilabel classification + +**When to use**: +- `macro`: When all classes equally important (even rare ones) +- `weighted`: When class frequency matters +- `micro`: When overall performance across all samples matters + +### Confusion Matrix + +Shows true positives, false positives, true negatives, false negatives. 
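+
+In scikit-learn's convention the rows are true classes and the columns are predicted classes, so `cm[i, j]` counts samples of true class `i` predicted as class `j`; a tiny sketch to make the orientation concrete:
+
+```python
+from sklearn.metrics import confusion_matrix
+
+# rows = true class, columns = predicted class
+print(confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1]))
+# [[1 1]
+#  [0 2]]
+```
+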
+ +```python +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay +import matplotlib.pyplot as plt + +cm = confusion_matrix(y_true, y_pred) +disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Class 0', 'Class 1']) +disp.plot() +plt.show() +``` + +### ROC Curve and AUC + +**ROC (Receiver Operating Characteristic)**: Plot of true positive rate vs false positive rate at different thresholds + +**AUC (Area Under Curve)**: Measures overall ability to discriminate between classes +- 1.0 = perfect classifier +- 0.5 = random classifier +- <0.5 = worse than random + +```python +from sklearn.metrics import roc_auc_score, roc_curve +import matplotlib.pyplot as plt + +# Requires probability predictions +y_proba = model.predict_proba(X_test)[:, 1] # Probabilities for positive class + +auc = roc_auc_score(y_true, y_proba) +fpr, tpr, thresholds = roc_curve(y_true, y_proba) + +plt.plot(fpr, tpr, label=f'AUC = {auc:.3f}') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.legend() +plt.show() +``` + +**Multiclass ROC**: Use `multi_class='ovr'` (one-vs-rest) or `'ovo'` (one-vs-one) + +```python +auc = roc_auc_score(y_true, y_proba, multi_class='ovr') +``` + +### Log Loss + +Measures probability calibration quality. + +```python +from sklearn.metrics import log_loss +loss = log_loss(y_true, y_proba) +``` + +**When to use**: When probability quality matters, not just class predictions +**Lower is better**: Perfect predictions have log loss of 0 + +### Classification Report + +Comprehensive summary of precision, recall, f1-score per class. + +```python +from sklearn.metrics import classification_report + +print(classification_report(y_true, y_pred, target_names=['Class 0', 'Class 1'])) +``` + +## Regression Metrics + +### Mean Squared Error (MSE) +Average squared difference between predictions and true values. + +```python +from sklearn.metrics import mean_squared_error +mse = mean_squared_error(y_true, y_pred) +rmse = mean_squared_error(y_true, y_pred, squared=False) # Root MSE +``` + +**Characteristics**: +- Penalizes large errors heavily (squared term) +- Same units as target² (use RMSE for same units as target) +- Lower is better + +### Mean Absolute Error (MAE) +Average absolute difference between predictions and true values. + +```python +from sklearn.metrics import mean_absolute_error +mae = mean_absolute_error(y_true, y_pred) +``` + +**Characteristics**: +- More robust to outliers than MSE +- Same units as target +- More interpretable +- Lower is better + +**MSE vs MAE**: Use MAE when outliers shouldn't dominate the metric + +### R² Score (Coefficient of Determination) +Proportion of variance explained by the model. + +```python +from sklearn.metrics import r2_score +r2 = r2_score(y_true, y_pred) +``` + +**Interpretation**: +- 1.0 = perfect predictions +- 0.0 = model as good as mean +- <0.0 = model worse than mean (possible!) +- Higher is better + +**Note**: Can be negative for models that perform worse than predicting the mean. + +### Mean Absolute Percentage Error (MAPE) +Percentage-based error metric. + +```python +from sklearn.metrics import mean_absolute_percentage_error +mape = mean_absolute_percentage_error(y_true, y_pred) +``` + +**When to use**: When relative errors matter more than absolute errors +**Warning**: Undefined when true values are zero + +### Median Absolute Error +Median of absolute errors (robust to outliers). 
+ +```python +from sklearn.metrics import median_absolute_error +med_ae = median_absolute_error(y_true, y_pred) +``` + +### Max Error +Maximum residual error. + +```python +from sklearn.metrics import max_error +max_err = max_error(y_true, y_pred) +``` + +**When to use**: When worst-case performance matters + +## Custom Scoring Functions + +Create custom scorers for GridSearchCV and cross_val_score: + +```python +from sklearn.metrics import make_scorer, fbeta_score + +# F2 score (weights recall higher than precision) +f2_scorer = make_scorer(fbeta_score, beta=2) + +# Custom function +def custom_metric(y_true, y_pred): + # Your custom logic + return score + +custom_scorer = make_scorer(custom_metric, greater_is_better=True) + +# Use in cross-validation or grid search +scores = cross_val_score(model, X, y, cv=5, scoring=custom_scorer) +``` + +## Scoring Parameter Options + +Common scoring strings for `scoring` parameter: + +**Classification**: +- `'accuracy'`, `'balanced_accuracy'` +- `'precision'`, `'recall'`, `'f1'` (add `_macro`, `_micro`, `_weighted` for multiclass) +- `'roc_auc'`, `'roc_auc_ovr'`, `'roc_auc_ovo'` +- `'log_loss'` (lower is better, negate for maximization) +- `'jaccard'` (Jaccard similarity) + +**Regression**: +- `'r2'` +- `'neg_mean_squared_error'`, `'neg_root_mean_squared_error'` +- `'neg_mean_absolute_error'` +- `'neg_mean_absolute_percentage_error'` +- `'neg_median_absolute_error'` + +**Note**: Many metrics are negated (neg_*) so GridSearchCV can maximize them. + +## Validation Strategies + +### Train-Test Split +Simple single split. + +```python +from sklearn.model_selection import train_test_split + +X_train, X_test, y_train, y_test = train_test_split( + X, y, + test_size=0.2, + random_state=42, + stratify=y # For classification with imbalanced classes +) +``` + +**When to use**: Large datasets, quick evaluation +**Parameters**: +- `test_size`: Proportion for test (typically 0.2-0.3) +- `stratify`: Preserves class proportions +- `random_state`: Reproducibility + +### Train-Validation-Test Split +Three-way split for hyperparameter tuning. + +```python +# First split: train+val and test +X_trainval, X_test, y_trainval, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +# Second split: train and validation +X_train, X_val, y_train, y_val = train_test_split( + X_trainval, y_trainval, test_size=0.2, random_state=42 +) + +# Or use GridSearchCV with train+val, then evaluate on test +``` + +**When to use**: Model selection and final evaluation +**Strategy**: +1. Train: Model training +2. Validation: Hyperparameter tuning +3. Test: Final, unbiased evaluation (touch only once!) + +### Learning Curves + +Diagnose bias vs variance issues. 
+ +```python +from sklearn.model_selection import learning_curve +import matplotlib.pyplot as plt + +train_sizes, train_scores, val_scores = learning_curve( + model, X, y, + cv=5, + train_sizes=np.linspace(0.1, 1.0, 10), + scoring='accuracy', + n_jobs=-1 +) + +plt.plot(train_sizes, train_scores.mean(axis=1), label='Training score') +plt.plot(train_sizes, val_scores.mean(axis=1), label='Validation score') +plt.xlabel('Training set size') +plt.ylabel('Score') +plt.legend() +plt.show() +``` + +**Interpretation**: +- Large gap between train and validation: **Overfitting** (high variance) +- Both scores low: **Underfitting** (high bias) +- Scores converging but low: Need better features or more complex model +- Validation score still improving: More data would help + +## Best Practices + +### Metric Selection Guidelines + +**Classification - Balanced classes**: +- Accuracy or F1-score + +**Classification - Imbalanced classes**: +- Balanced accuracy +- F1-score (weighted or macro) +- ROC-AUC +- Precision-Recall curve + +**Classification - Cost-sensitive**: +- Custom scorer with cost matrix +- Adjust threshold on probabilities + +**Regression - Typical use**: +- RMSE (sensitive to outliers) +- R² (proportion of variance explained) + +**Regression - Outliers present**: +- MAE (robust to outliers) +- Median absolute error + +**Regression - Percentage errors matter**: +- MAPE + +### Cross-Validation Guidelines + +**Number of folds**: +- 5-10 folds typical +- More folds = more computation, less variance in estimate +- LeaveOneOut only for small datasets + +**Stratification**: +- Always use for classification with imbalanced classes +- Use StratifiedKFold by default for classification + +**Grouping**: +- Always use when samples are not independent +- Time series: Always use TimeSeriesSplit + +**Nested cross-validation**: +- For unbiased performance estimate when doing hyperparameter tuning +- Outer loop: Performance estimation +- Inner loop: Hyperparameter selection + +### Avoiding Common Pitfalls + +1. **Data leakage**: Fit preprocessors only on training data within each CV fold (use Pipeline!) +2. **Test set leakage**: Never use test set for model selection +3. **Improper metric**: Use metrics appropriate for problem (balanced_accuracy for imbalanced data) +4. **Multiple testing**: More models evaluated = higher chance of random good results +5. **Temporal leakage**: For time series, use TimeSeriesSplit, not random splits +6. **Target leakage**: Features shouldn't contain information not available at prediction time diff --git a/scientific-packages/scikit-learn/references/pipelines_and_composition.md b/scientific-packages/scikit-learn/references/pipelines_and_composition.md new file mode 100644 index 0000000..bcf898f --- /dev/null +++ b/scientific-packages/scikit-learn/references/pipelines_and_composition.md @@ -0,0 +1,679 @@ +# Pipelines and Composite Estimators in scikit-learn + +## Overview +Pipelines chain multiple estimators into a single unit, ensuring proper workflow sequencing and preventing data leakage. As the documentation states: "Pipeline can be used to chain multiple estimators into one. This is useful as there is often a fixed sequence of steps in processing the data, for example feature selection, normalization and classification." 
+ +## Pipeline Basics + +### Creating Pipelines + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression + +# Method 1: List of (name, estimator) tuples +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=10)), + ('classifier', LogisticRegression()) +]) + +# Method 2: Using make_pipeline (auto-generates names) +from sklearn.pipeline import make_pipeline +pipeline = make_pipeline( + StandardScaler(), + PCA(n_components=10), + LogisticRegression() +) +``` + +### Using Pipelines + +```python +# Fit and predict like any estimator +pipeline.fit(X_train, y_train) +y_pred = pipeline.predict(X_test) +score = pipeline.score(X_test, y_test) + +# Access steps +pipeline.named_steps['scaler'] +pipeline.steps[0] # Returns ('scaler', StandardScaler(...)) +pipeline[0] # Returns StandardScaler(...) object +pipeline['scaler'] # Returns StandardScaler(...) object + +# Get final estimator +pipeline[-1] # Returns LogisticRegression(...) object +``` + +### Pipeline Rules + +**All steps except the last must be transformers** (have `fit()` and `transform()` methods). + +**The final step** can be: +- Predictor (classifier/regressor) with `fit()` and `predict()` +- Transformer with `fit()` and `transform()` +- Any estimator with at least `fit()` + +### Pipeline Benefits + +1. **Convenience**: Single `fit()` and `predict()` call +2. **Prevents data leakage**: Ensures proper fit/transform on train/test +3. **Joint parameter selection**: Tune all steps together with GridSearchCV +4. **Reproducibility**: Encapsulates entire workflow + +## Accessing and Setting Parameters + +### Nested Parameters + +Access step parameters using `stepname__parameter` syntax: + +```python +from sklearn.model_selection import GridSearchCV + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('clf', LogisticRegression()) +]) + +# Grid search over pipeline parameters +param_grid = { + 'scaler__with_mean': [True, False], + 'clf__C': [0.1, 1.0, 10.0], + 'clf__penalty': ['l1', 'l2'] +} + +grid_search = GridSearchCV(pipeline, param_grid, cv=5) +grid_search.fit(X_train, y_train) +``` + +### Setting Parameters + +```python +# Set parameters +pipeline.set_params(clf__C=10.0, scaler__with_std=False) + +# Get parameters +params = pipeline.get_params() +``` + +## Caching Intermediate Results + +Cache fitted transformers to avoid recomputation: + +```python +from tempfile import mkdtemp +from shutil import rmtree + +# Create cache directory +cachedir = mkdtemp() + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=10)), + ('clf', LogisticRegression()) +], memory=cachedir) + +# When doing grid search, scaler and PCA only fit once per fold +grid_search = GridSearchCV(pipeline, param_grid, cv=5) +grid_search.fit(X_train, y_train) + +# Clean up cache +rmtree(cachedir) + +# Or use joblib for persistent caching +from joblib import Memory +memory = Memory(location='./cache', verbose=0) +pipeline = Pipeline([...], memory=memory) +``` + +**When to use caching**: +- Expensive transformations (PCA, feature selection) +- Grid search over final estimator parameters only +- Multiple experiments with same preprocessing + +## ColumnTransformer + +Apply different transformations to different columns (essential for heterogeneous data). 
+ +### Basic Usage + +```python +from sklearn.compose import ColumnTransformer +from sklearn.preprocessing import StandardScaler, OneHotEncoder + +# Define which transformations for which columns +preprocessor = ColumnTransformer( + transformers=[ + ('num', StandardScaler(), ['age', 'income', 'credit_score']), + ('cat', OneHotEncoder(), ['country', 'occupation']) + ], + remainder='drop' # What to do with remaining columns +) + +X_transformed = preprocessor.fit_transform(X) +``` + +### Column Selection Methods + +```python +# Method 1: Column names (list of strings) +('num', StandardScaler(), ['age', 'income']) + +# Method 2: Column indices (list of integers) +('num', StandardScaler(), [0, 1, 2]) + +# Method 3: Boolean mask +('num', StandardScaler(), [True, True, False, True, False]) + +# Method 4: Slice +('num', StandardScaler(), slice(0, 3)) + +# Method 5: make_column_selector (by dtype or pattern) +from sklearn.compose import make_column_selector as selector + +preprocessor = ColumnTransformer([ + ('num', StandardScaler(), selector(dtype_include='number')), + ('cat', OneHotEncoder(), selector(dtype_include='object')) +]) + +# Select by pattern +selector(pattern='.*_score$') # All columns ending with '_score' +``` + +### Remainder Parameter + +Controls what happens to columns not specified: + +```python +# Drop remaining columns (default) +remainder='drop' + +# Pass through remaining columns unchanged +remainder='passthrough' + +# Apply transformer to remaining columns +remainder=StandardScaler() +``` + +### Full Pipeline with ColumnTransformer + +```python +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import StandardScaler, OneHotEncoder +from sklearn.ensemble import RandomForestClassifier + +# Separate preprocessing for numeric and categorical +numeric_features = ['age', 'income', 'credit_score'] +categorical_features = ['country', 'occupation', 'education'] + +numeric_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()) +]) + +categorical_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore')) +]) + +preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, numeric_features), + ('cat', categorical_transformer, categorical_features) + ]) + +# Complete pipeline +clf = Pipeline(steps=[ + ('preprocessor', preprocessor), + ('classifier', RandomForestClassifier()) +]) + +clf.fit(X_train, y_train) +y_pred = clf.predict(X_test) + +# Grid search over preprocessing and model parameters +param_grid = { + 'preprocessor__num__imputer__strategy': ['mean', 'median'], + 'preprocessor__cat__onehot__max_categories': [10, 20, None], + 'classifier__n_estimators': [100, 200], + 'classifier__max_depth': [10, 20, None] +} + +grid_search = GridSearchCV(clf, param_grid, cv=5) +grid_search.fit(X_train, y_train) +``` + +## FeatureUnion + +Combine multiple transformer outputs by concatenating features side-by-side. 
+ +```python +from sklearn.pipeline import FeatureUnion +from sklearn.decomposition import PCA +from sklearn.feature_selection import SelectKBest + +# Combine PCA and feature selection +combined_features = FeatureUnion([ + ('pca', PCA(n_components=10)), + ('univ_select', SelectKBest(k=5)) +]) + +X_features = combined_features.fit_transform(X, y) +# Result: 15 features (10 from PCA + 5 from SelectKBest) + +# In a pipeline +pipeline = Pipeline([ + ('features', combined_features), + ('classifier', LogisticRegression()) +]) +``` + +### FeatureUnion with Transformers on Different Data + +```python +from sklearn.pipeline import FeatureUnion +from sklearn.preprocessing import FunctionTransformer +import numpy as np + +def get_numeric_data(X): + return X[:, :3] # First 3 columns + +def get_text_data(X): + return X[:, 3] # 4th column (text) + +from sklearn.feature_extraction.text import TfidfVectorizer + +combined = FeatureUnion([ + ('numeric_features', Pipeline([ + ('selector', FunctionTransformer(get_numeric_data)), + ('scaler', StandardScaler()) + ])), + ('text_features', Pipeline([ + ('selector', FunctionTransformer(get_text_data)), + ('tfidf', TfidfVectorizer()) + ])) +]) +``` + +**Note**: ColumnTransformer is usually more convenient than FeatureUnion for heterogeneous data. + +## Common Pipeline Patterns + +### Classification Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.feature_selection import SelectKBest, f_classif +from sklearn.svm import SVC + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('feature_selection', SelectKBest(f_classif, k=10)), + ('classifier', SVC(kernel='rbf')) +]) +``` + +### Regression Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler, PolynomialFeatures +from sklearn.linear_model import Ridge + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('poly', PolynomialFeatures(degree=2)), + ('ridge', Ridge(alpha=1.0)) +]) +``` + +### Text Classification Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.naive_bayes import MultinomialNB + +pipeline = Pipeline([ + ('tfidf', TfidfVectorizer(max_features=1000)), + ('classifier', MultinomialNB()) +]) + +# Works directly with text +pipeline.fit(X_train_text, y_train) +y_pred = pipeline.predict(X_test_text) +``` + +### Image Processing Pipeline + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.neural_network import MLPClassifier + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=100)), + ('mlp', MLPClassifier(hidden_layer_sizes=(100, 50))) +]) +``` + +### Dimensionality Reduction + Clustering + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +from sklearn.cluster import KMeans + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=10)), + ('kmeans', KMeans(n_clusters=5)) +]) + +labels = pipeline.fit_predict(X) +``` + +## Custom Transformers + +### Using FunctionTransformer + +```python +from sklearn.preprocessing import FunctionTransformer +import numpy as np + +# Log transformation +log_transformer = FunctionTransformer(np.log1p) + +# Custom function +def custom_transform(X): + # Your transformation logic + return X_transformed + 
+custom_transformer = FunctionTransformer(custom_transform) + +# In pipeline +pipeline = Pipeline([ + ('log', log_transformer), + ('scaler', StandardScaler()), + ('model', LinearRegression()) +]) +``` + +### Creating Custom Transformer Class + +```python +from sklearn.base import BaseEstimator, TransformerMixin + +class CustomTransformer(BaseEstimator, TransformerMixin): + def __init__(self, parameter=1.0): + self.parameter = parameter + + def fit(self, X, y=None): + # Learn parameters from X + self.learned_param_ = X.mean() # Example + return self + + def transform(self, X): + # Transform X using learned parameters + return X * self.parameter - self.learned_param_ + + # Optional: for pipelines that need inverse transform + def inverse_transform(self, X): + return (X + self.learned_param_) / self.parameter + +# Use in pipeline +pipeline = Pipeline([ + ('custom', CustomTransformer(parameter=2.0)), + ('model', LinearRegression()) +]) +``` + +**Key requirements**: +- Inherit from `BaseEstimator` and `TransformerMixin` +- Implement `fit()` and `transform()` methods +- `fit()` must return `self` +- Use trailing underscore for learned attributes (`learned_param_`) +- Constructor parameters should be stored as attributes + +### Transformer for Pandas DataFrames + +```python +from sklearn.base import BaseEstimator, TransformerMixin +import pandas as pd + +class DataFrameTransformer(BaseEstimator, TransformerMixin): + def __init__(self, columns=None): + self.columns = columns + + def fit(self, X, y=None): + return self + + def transform(self, X): + if isinstance(X, pd.DataFrame): + if self.columns: + return X[self.columns].values + return X.values + return X +``` + +## Visualization + +### Display Pipeline in Jupyter + +```python +from sklearn import set_config + +# Enable HTML display +set_config(display='diagram') + +# Now displaying the pipeline shows interactive diagram +pipeline +``` + +### Print Pipeline Structure + +```python +from sklearn.utils import estimator_html_repr + +# Get HTML representation +html = estimator_html_repr(pipeline) + +# Or just print +print(pipeline) +``` + +## Advanced Patterns + +### Conditional Transformations + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler, FunctionTransformer + +def conditional_scale(X, scale=True): + if scale: + return StandardScaler().fit_transform(X) + return X + +pipeline = Pipeline([ + ('conditional_scaler', FunctionTransformer( + conditional_scale, + kw_args={'scale': True} + )), + ('model', LogisticRegression()) +]) +``` + +### Multiple Preprocessing Paths + +```python +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline + +# Different preprocessing for different feature types +preprocessor = ColumnTransformer([ + # Numeric: impute + scale + ('num_standard', Pipeline([ + ('imputer', SimpleImputer(strategy='mean')), + ('scaler', StandardScaler()) + ]), ['age', 'income']), + + # Numeric: impute + log + scale + ('num_skewed', Pipeline([ + ('imputer', SimpleImputer(strategy='median')), + ('log', FunctionTransformer(np.log1p)), + ('scaler', StandardScaler()) + ]), ['price', 'revenue']), + + # Categorical: impute + one-hot + ('cat', Pipeline([ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore')) + ]), ['category', 'region']), + + # Text: TF-IDF + ('text', TfidfVectorizer(), 'description') +]) +``` + +### Feature Engineering Pipeline + +```python +from sklearn.base import BaseEstimator, 
TransformerMixin + +class FeatureEngineer(BaseEstimator, TransformerMixin): + def fit(self, X, y=None): + return self + + def transform(self, X): + X = X.copy() + # Add engineered features + X['age_income_ratio'] = X['age'] / (X['income'] + 1) + X['total_score'] = X['score1'] + X['score2'] + X['score3'] + return X + +pipeline = Pipeline([ + ('engineer', FeatureEngineer()), + ('preprocessor', preprocessor), + ('model', RandomForestClassifier()) +]) +``` + +## Best Practices + +### Always Use Pipelines When + +1. **Preprocessing is needed**: Scaling, encoding, imputation +2. **Cross-validation**: Ensures proper fit/transform split +3. **Hyperparameter tuning**: Joint optimization of preprocessing and model +4. **Production deployment**: Single object to serialize +5. **Multiple steps**: Any workflow with >1 step + +### Pipeline Do's + +- ✅ Fit pipeline only on training data +- ✅ Use ColumnTransformer for heterogeneous data +- ✅ Cache expensive transformations during grid search +- ✅ Use make_pipeline for simple cases +- ✅ Set verbose=True to debug issues +- ✅ Use remainder='passthrough' when appropriate + +### Pipeline Don'ts + +- ❌ Fit preprocessing on full dataset before split (data leakage!) +- ❌ Manually transform test data (use pipeline.predict()) +- ❌ Forget to handle missing values before scaling +- ❌ Mix pandas DataFrames and arrays inconsistently +- ❌ Skip using pipelines for "just one preprocessing step" + +### Data Leakage Prevention + +```python +# ❌ WRONG - Data leakage +scaler = StandardScaler().fit(X) # Fit on all data +X_train, X_test, y_train, y_test = train_test_split(X, y) +X_train_scaled = scaler.transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# ✅ CORRECT - No leakage with pipeline +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('model', LogisticRegression()) +]) + +X_train, X_test, y_train, y_test = train_test_split(X, y) +pipeline.fit(X_train, y_train) # Scaler fits only on train +y_pred = pipeline.predict(X_test) # Scaler transforms only on test + +# ✅ CORRECT - No leakage in cross-validation +scores = cross_val_score(pipeline, X, y, cv=5) +# Each fold: scaler fits on train folds, transforms on test fold +``` + +### Debugging Pipelines + +```python +# Examine intermediate outputs +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('pca', PCA(n_components=10)), + ('model', LogisticRegression()) +]) + +# Fit pipeline +pipeline.fit(X_train, y_train) + +# Get output after scaling +X_scaled = pipeline.named_steps['scaler'].transform(X_train) + +# Get output after PCA +X_pca = pipeline[:-1].transform(X_train) # All steps except last + +# Or build partial pipeline +partial_pipeline = Pipeline(pipeline.steps[:-1]) +X_transformed = partial_pipeline.transform(X_train) +``` + +### Saving and Loading Pipelines + +```python +import joblib + +# Save pipeline +joblib.dump(pipeline, 'model_pipeline.pkl') + +# Load pipeline +pipeline = joblib.load('model_pipeline.pkl') + +# Use loaded pipeline +y_pred = pipeline.predict(X_new) +``` + +## Common Errors and Solutions + +**Error**: `ValueError: could not convert string to float` +- **Cause**: Categorical features not encoded +- **Solution**: Add OneHotEncoder or OrdinalEncoder to pipeline + +**Error**: `All intermediate steps should be transformers` +- **Cause**: Non-transformer in non-final position +- **Solution**: Ensure only last step is predictor + +**Error**: `X has different number of features than during fitting` +- **Cause**: Different columns in train and test +- **Solution**: Ensure consistent 
column handling, use `handle_unknown='ignore'` in OneHotEncoder + +**Error**: Different results in cross-validation vs train-test split +- **Cause**: Data leakage (fitting preprocessing on all data) +- **Solution**: Always use Pipeline for preprocessing + +**Error**: Pipeline too slow during grid search +- **Solution**: Use caching with `memory` parameter diff --git a/scientific-packages/scikit-learn/references/preprocessing.md b/scientific-packages/scikit-learn/references/preprocessing.md new file mode 100644 index 0000000..f718e67 --- /dev/null +++ b/scientific-packages/scikit-learn/references/preprocessing.md @@ -0,0 +1,413 @@ +# Data Preprocessing in scikit-learn + +## Overview +Preprocessing transforms raw data into a format suitable for machine learning algorithms. Many algorithms require standardized or normalized data to perform well. + +## Standardization and Scaling + +### StandardScaler +Removes mean and scales to unit variance (z-score normalization). + +**Formula**: `z = (x - μ) / σ` + +**Use cases**: +- Most ML algorithms (especially SVM, neural networks, PCA) +- When features have different units or scales +- When assuming Gaussian-like distribution + +**Important**: Fit only on training data, then transform both train and test sets. + +```python +from sklearn.preprocessing import StandardScaler +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) # Use same parameters +``` + +### MinMaxScaler +Scales features to a specified range, typically [0, 1]. + +**Formula**: `X_scaled = (X - X_min) / (X_max - X_min)` + +**Use cases**: +- When bounded range is needed +- Neural networks (often prefer [0, 1] range) +- When distribution is not Gaussian +- Image pixel values + +**Parameters**: +- `feature_range`: Tuple (min, max), default (0, 1) + +**Warning**: Sensitive to outliers since it uses min/max. + +### MaxAbsScaler +Scales to [-1, 1] by dividing by maximum absolute value. + +**Use cases**: +- Sparse data (preserves sparsity) +- Data already centered at zero +- When sign of values is meaningful + +**Advantage**: Doesn't shift/center the data, preserves zero entries. + +### RobustScaler +Uses median and interquartile range (IQR) instead of mean and standard deviation. + +**Formula**: `X_scaled = (X - median) / IQR` + +**Use cases**: +- When outliers are present +- When StandardScaler produces skewed results +- Robust statistics preferred + +**Parameters**: +- `quantile_range`: Tuple (q_min, q_max), default (25.0, 75.0) + +## Normalization + +### normalize() function and Normalizer +Scales individual samples (rows) to unit norm, not features (columns). + +**Use cases**: +- Text classification (TF-IDF vectors) +- When similarity metrics (dot product, cosine) are used +- When each sample should have equal weight + +**Norms**: +- `l1`: Manhattan norm (sum of absolutes = 1) +- `l2`: Euclidean norm (sum of squares = 1) - **most common** +- `max`: Maximum absolute value = 1 + +**Key difference from scalers**: Operates on rows (samples), not columns (features). + +```python +from sklearn.preprocessing import Normalizer +normalizer = Normalizer(norm='l2') +X_normalized = normalizer.transform(X) +``` + +## Encoding Categorical Features + +### OrdinalEncoder +Converts categories to integers (0 to n_categories - 1). 
+ +**Use cases**: +- Ordinal relationships exist (small < medium < large) +- Preprocessing before other transformations +- Tree-based algorithms (which can handle integers) + +**Parameters**: +- `handle_unknown`: 'error' or 'use_encoded_value' +- `unknown_value`: Value for unknown categories +- `encoded_missing_value`: Value for missing data + +```python +from sklearn.preprocessing import OrdinalEncoder +encoder = OrdinalEncoder() +X_encoded = encoder.fit_transform(X_categorical) +``` + +### OneHotEncoder +Creates binary columns for each category. + +**Use cases**: +- Nominal categories (no order) +- Linear models, neural networks +- When category relationships shouldn't be assumed + +**Parameters**: +- `drop`: 'first', 'if_binary', array-like (prevents multicollinearity) +- `sparse_output`: True (default, memory efficient) or False +- `handle_unknown`: 'error', 'ignore', 'infrequent_if_exist' +- `min_frequency`: Group infrequent categories +- `max_categories`: Limit number of categories + +**High cardinality handling**: +```python +encoder = OneHotEncoder(min_frequency=100, handle_unknown='infrequent_if_exist') +# Groups categories appearing < 100 times into 'infrequent' category +``` + +**Memory tip**: Use `sparse_output=True` (default) for high-cardinality features. + +### TargetEncoder +Uses target statistics to encode categories. + +**Use cases**: +- High-cardinality categorical features (zip codes, user IDs) +- When linear relationships with target are expected +- Often improves performance over one-hot encoding + +**How it works**: +- Replaces category with mean of target for that category +- Uses cross-fitting during fit_transform() to prevent target leakage +- Applies smoothing to handle rare categories + +**Parameters**: +- `smooth`: Smoothing parameter for rare categories +- `cv`: Cross-validation strategy + +**Warning**: Only for supervised learning. Requires target variable. + +```python +from sklearn.preprocessing import TargetEncoder +encoder = TargetEncoder() +X_encoded = encoder.fit_transform(X_categorical, y) +``` + +### LabelEncoder +Encodes target labels into integers 0 to n_classes - 1. + +**Use cases**: Encoding target variable for classification (not features!) + +**Important**: Use `LabelEncoder` for targets, not features. For features, use OrdinalEncoder or OneHotEncoder. + +### Binarizer +Converts numeric values to binary (0 or 1) based on threshold. + +**Use cases**: Creating binary features from continuous values + +## Non-linear Transformations + +### QuantileTransformer +Maps features to uniform or normal distribution using rank transformation. + +**Use cases**: +- Unusual distributions (bimodal, heavy tails) +- Reducing outlier impact +- When normal distribution is desired + +**Parameters**: +- `output_distribution`: 'uniform' (default) or 'normal' +- `n_quantiles`: Number of quantiles (default: min(1000, n_samples)) + +**Effect**: Strong transformation that reduces outlier influence and makes data more Gaussian-like. + +### PowerTransformer +Applies parametric monotonic transformation to make data more Gaussian. + +**Methods**: +- `yeo-johnson`: Works with positive and negative values (default) +- `box-cox`: Only positive values + +**Use cases**: +- Skewed distributions +- When Gaussian assumption is important +- Variance stabilization + +**Advantage**: Less radical than QuantileTransformer, preserves more of original relationships. + +## Discretization + +### KBinsDiscretizer +Bins continuous features into discrete intervals. 
+ +**Strategies**: +- `uniform`: Equal-width bins +- `quantile`: Equal-frequency bins +- `kmeans`: K-means clustering to determine bins + +**Encoding**: +- `ordinal`: Integer encoding (0 to n_bins - 1) +- `onehot`: One-hot encoding +- `onehot-dense`: Dense one-hot encoding + +**Use cases**: +- Making linear models handle non-linear relationships +- Reducing noise in features +- Making features more interpretable + +```python +from sklearn.preprocessing import KBinsDiscretizer +disc = KBinsDiscretizer(n_bins=5, encode='onehot', strategy='quantile') +X_binned = disc.fit_transform(X) +``` + +## Feature Generation + +### PolynomialFeatures +Generates polynomial and interaction features. + +**Parameters**: +- `degree`: Polynomial degree +- `interaction_only`: Only multiplicative interactions (no x²) +- `include_bias`: Include constant feature + +**Use cases**: +- Adding non-linearity to linear models +- Feature engineering +- Polynomial regression + +**Warning**: Number of features grows rapidly: (n+d)!/d!n! for degree d. + +```python +from sklearn.preprocessing import PolynomialFeatures +poly = PolynomialFeatures(degree=2, include_bias=False) +X_poly = poly.fit_transform(X) +# [x1, x2] → [x1, x2, x1², x1·x2, x2²] +``` + +### SplineTransformer +Generates B-spline basis functions. + +**Use cases**: +- Smooth non-linear transformations +- Alternative to PolynomialFeatures (less oscillation at boundaries) +- Generalized additive models (GAMs) + +**Parameters**: +- `n_knots`: Number of knots +- `degree`: Spline degree +- `knots`: Knot positions ('uniform', 'quantile', or array) + +## Missing Value Handling + +### SimpleImputer +Imputes missing values with various strategies. + +**Strategies**: +- `mean`: Mean of column (numeric only) +- `median`: Median of column (numeric only) +- `most_frequent`: Mode (numeric or categorical) +- `constant`: Fill with constant value + +**Parameters**: +- `strategy`: Imputation strategy +- `fill_value`: Value when strategy='constant' +- `missing_values`: What represents missing (np.nan, None, specific value) + +```python +from sklearn.impute import SimpleImputer +imputer = SimpleImputer(strategy='median') +X_imputed = imputer.fit_transform(X) +``` + +### KNNImputer +Imputes using k-nearest neighbors. + +**Use cases**: When relationships between features should inform imputation + +**Parameters**: +- `n_neighbors`: Number of neighbors +- `weights`: 'uniform' or 'distance' + +### IterativeImputer +Models each feature with missing values as function of other features. + +**Use cases**: +- Complex relationships between features +- When multiple features have missing values +- Higher quality imputation (but slower) + +**Parameters**: +- `estimator`: Estimator for regression (default: BayesianRidge) +- `max_iter`: Maximum iterations + +## Function Transformers + +### FunctionTransformer +Applies custom function to data. + +**Use cases**: +- Custom transformations in pipelines +- Log transformation, square root, etc. 
+- Domain-specific preprocessing + +```python +from sklearn.preprocessing import FunctionTransformer +import numpy as np + +log_transformer = FunctionTransformer(np.log1p, validate=True) +X_log = log_transformer.transform(X) +``` + +## Best Practices + +### Feature Scaling Guidelines + +**Always scale**: +- SVM, neural networks +- K-nearest neighbors +- Linear/Logistic regression with regularization +- PCA, LDA +- Gradient descent-based algorithms + +**Don't need to scale**: +- Tree-based algorithms (Decision Trees, Random Forests, Gradient Boosting) +- Naive Bayes + +### Pipeline Integration + +Always use preprocessing within pipelines to prevent data leakage: + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.linear_model import LogisticRegression + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('classifier', LogisticRegression()) +]) + +pipeline.fit(X_train, y_train) # Scaler fit only on train data +y_pred = pipeline.predict(X_test) # Scaler transform only on test data +``` + +### Common Transformations by Data Type + +**Numeric - Continuous**: +- StandardScaler (most common) +- MinMaxScaler (neural networks) +- RobustScaler (outliers present) +- PowerTransformer (skewed data) + +**Numeric - Count Data**: +- sqrt or log transformation +- QuantileTransformer +- StandardScaler after transformation + +**Categorical - Low Cardinality (<10 categories)**: +- OneHotEncoder + +**Categorical - High Cardinality (>10 categories)**: +- TargetEncoder (supervised) +- Frequency encoding +- OneHotEncoder with min_frequency parameter + +**Categorical - Ordinal**: +- OrdinalEncoder + +**Text**: +- CountVectorizer or TfidfVectorizer +- Normalizer after vectorization + +### Data Leakage Prevention + +1. **Fit only on training data**: Never include test data when fitting preprocessors +2. **Use pipelines**: Ensures proper fit/transform separation +3. **Cross-validation**: Use Pipeline with cross_val_score() for proper evaluation +4. **Target encoding**: Use cv parameter in TargetEncoder for cross-fitting + +```python +# WRONG - data leakage +scaler = StandardScaler().fit(X_full) +X_train_scaled = scaler.transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# CORRECT - no leakage +scaler = StandardScaler().fit(X_train) +X_train_scaled = scaler.transform(X_train) +X_test_scaled = scaler.transform(X_test) +``` + +## Preprocessing Checklist + +Before modeling: +1. Handle missing values (imputation or removal) +2. Encode categorical variables appropriately +3. Scale/normalize numeric features (if needed for algorithm) +4. Handle outliers (RobustScaler, clipping, removal) +5. Create additional features if beneficial (PolynomialFeatures, domain knowledge) +6. Check for data leakage in preprocessing steps +7. 
Wrap everything in a Pipeline diff --git a/scientific-packages/scikit-learn/references/quick_reference.md b/scientific-packages/scikit-learn/references/quick_reference.md new file mode 100644 index 0000000..97adc71 --- /dev/null +++ b/scientific-packages/scikit-learn/references/quick_reference.md @@ -0,0 +1,625 @@ +# Scikit-learn Quick Reference + +## Essential Imports + +```python +# Core +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.compose import ColumnTransformer + +# Preprocessing +from sklearn.preprocessing import ( + StandardScaler, MinMaxScaler, RobustScaler, + OneHotEncoder, OrdinalEncoder, LabelEncoder, + PolynomialFeatures +) +from sklearn.impute import SimpleImputer + +# Models - Classification +from sklearn.linear_model import LogisticRegression +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble import ( + RandomForestClassifier, + GradientBoostingClassifier, + HistGradientBoostingClassifier +) +from sklearn.svm import SVC +from sklearn.neighbors import KNeighborsClassifier + +# Models - Regression +from sklearn.linear_model import LinearRegression, Ridge, Lasso +from sklearn.ensemble import ( + RandomForestRegressor, + GradientBoostingRegressor, + HistGradientBoostingRegressor +) + +# Clustering +from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering +from sklearn.mixture import GaussianMixture + +# Dimensionality Reduction +from sklearn.decomposition import PCA, NMF, TruncatedSVD +from sklearn.manifold import TSNE + +# Metrics +from sklearn.metrics import ( + accuracy_score, precision_score, recall_score, f1_score, + confusion_matrix, classification_report, + mean_squared_error, r2_score, mean_absolute_error +) +``` + +## Basic Workflow Template + +### Classification + +```python +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import classification_report + +# Split data +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y +) + +# Scale features +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# Train model +model = RandomForestClassifier(n_estimators=100, random_state=42) +model.fit(X_train_scaled, y_train) + +# Predict and evaluate +y_pred = model.predict(X_test_scaled) +print(classification_report(y_test, y_pred)) +``` + +### Regression + +```python +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import mean_squared_error, r2_score + +# Split data +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +# Scale features +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# Train model +model = RandomForestRegressor(n_estimators=100, random_state=42) +model.fit(X_train_scaled, y_train) + +# Predict and evaluate +y_pred = model.predict(X_test_scaled) +print(f"RMSE: {mean_squared_error(y_test, y_pred, squared=False):.3f}") +print(f"R²: {r2_score(y_test, y_pred):.3f}") +``` + +### With Pipeline (Recommended) + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler 
+from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split, cross_val_score + +# Create pipeline +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('classifier', RandomForestClassifier(n_estimators=100, random_state=42)) +]) + +# Split and train +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) +pipeline.fit(X_train, y_train) + +# Evaluate +score = pipeline.score(X_test, y_test) +cv_scores = cross_val_score(pipeline, X_train, y_train, cv=5) +print(f"Test accuracy: {score:.3f}") +print(f"CV accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})") +``` + +## Common Preprocessing Patterns + +### Numeric Data + +```python +from sklearn.preprocessing import StandardScaler +from sklearn.impute import SimpleImputer +from sklearn.pipeline import Pipeline + +numeric_transformer = Pipeline([ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()) +]) +``` + +### Categorical Data + +```python +from sklearn.preprocessing import OneHotEncoder +from sklearn.impute import SimpleImputer +from sklearn.pipeline import Pipeline + +categorical_transformer = Pipeline([ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore')) +]) +``` + +### Mixed Data with ColumnTransformer + +```python +from sklearn.compose import ColumnTransformer + +numeric_features = ['age', 'income', 'credit_score'] +categorical_features = ['country', 'occupation'] + +preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, numeric_features), + ('cat', categorical_transformer, categorical_features) + ]) + +# Complete pipeline +from sklearn.ensemble import RandomForestClassifier +pipeline = Pipeline([ + ('preprocessor', preprocessor), + ('classifier', RandomForestClassifier()) +]) +``` + +## Model Selection Cheat Sheet + +### Quick Decision Tree + +``` +Is it supervised? +├─ Yes +│ ├─ Predicting categories? → Classification +│ │ ├─ Start with: LogisticRegression (baseline) +│ │ ├─ Then try: RandomForestClassifier +│ │ └─ Best performance: HistGradientBoostingClassifier +│ └─ Predicting numbers? → Regression +│ ├─ Start with: LinearRegression/Ridge (baseline) +│ ├─ Then try: RandomForestRegressor +│ └─ Best performance: HistGradientBoostingRegressor +└─ No + ├─ Grouping similar items? → Clustering + │ ├─ Know # clusters: KMeans + │ └─ Unknown # clusters: DBSCAN or HDBSCAN + ├─ Reducing dimensions? + │ ├─ For preprocessing: PCA + │ └─ For visualization: t-SNE or UMAP + └─ Finding outliers? 
→ IsolationForest or LocalOutlierFactor +``` + +### Algorithm Selection by Data Size + +- **Small (<1K samples)**: Any algorithm +- **Medium (1K-100K)**: Random Forests, Gradient Boosting, Neural Networks +- **Large (>100K)**: SGDClassifier/Regressor, HistGradientBoosting, LinearSVC + +### When to Scale Features + +**Always scale**: +- SVM, Neural Networks +- K-Nearest Neighbors +- Linear/Logistic Regression (with regularization) +- PCA, LDA +- Any gradient descent algorithm + +**Don't need to scale**: +- Tree-based (Decision Trees, Random Forests, Gradient Boosting) +- Naive Bayes + +## Hyperparameter Tuning + +### GridSearchCV + +```python +from sklearn.model_selection import GridSearchCV + +param_grid = { + 'n_estimators': [100, 200, 500], + 'max_depth': [10, 20, None], + 'min_samples_split': [2, 5, 10] +} + +grid_search = GridSearchCV( + RandomForestClassifier(random_state=42), + param_grid, + cv=5, + scoring='f1_weighted', + n_jobs=-1 +) + +grid_search.fit(X_train, y_train) +best_model = grid_search.best_estimator_ +print(f"Best params: {grid_search.best_params_}") +``` + +### RandomizedSearchCV (Faster) + +```python +from sklearn.model_selection import RandomizedSearchCV +from scipy.stats import randint, uniform + +param_distributions = { + 'n_estimators': randint(100, 1000), + 'max_depth': randint(5, 50), + 'min_samples_split': randint(2, 20) +} + +random_search = RandomizedSearchCV( + RandomForestClassifier(random_state=42), + param_distributions, + n_iter=50, # Number of combinations to try + cv=5, + n_jobs=-1, + random_state=42 +) + +random_search.fit(X_train, y_train) +``` + +### Pipeline with GridSearchCV + +```python +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC +from sklearn.model_selection import GridSearchCV + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('svm', SVC()) +]) + +param_grid = { + 'svm__C': [0.1, 1, 10], + 'svm__kernel': ['rbf', 'linear'], + 'svm__gamma': ['scale', 'auto'] +} + +grid = GridSearchCV(pipeline, param_grid, cv=5) +grid.fit(X_train, y_train) +``` + +## Cross-Validation + +### Basic Cross-Validation + +```python +from sklearn.model_selection import cross_val_score + +scores = cross_val_score(model, X, y, cv=5, scoring='accuracy') +print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +### Multiple Metrics + +```python +from sklearn.model_selection import cross_validate + +scoring = ['accuracy', 'precision_weighted', 'recall_weighted', 'f1_weighted'] +results = cross_validate(model, X, y, cv=5, scoring=scoring) + +for metric in scoring: + scores = results[f'test_{metric}'] + print(f"{metric}: {scores.mean():.3f} (+/- {scores.std():.3f})") +``` + +### Custom CV Strategies + +```python +from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit + +# For imbalanced classification +cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) + +# For time series +cv = TimeSeriesSplit(n_splits=5) + +scores = cross_val_score(model, X, y, cv=cv) +``` + +## Common Metrics + +### Classification + +```python +from sklearn.metrics import ( + accuracy_score, balanced_accuracy_score, + precision_score, recall_score, f1_score, + confusion_matrix, classification_report, + roc_auc_score +) + +# Basic metrics +accuracy = accuracy_score(y_true, y_pred) +f1 = f1_score(y_true, y_pred, average='weighted') + +# Comprehensive report +print(classification_report(y_true, y_pred)) + +# ROC AUC (requires probabilities) +y_proba = model.predict_proba(X_test)[:, 1] 
+auc = roc_auc_score(y_true, y_proba) +``` + +### Regression + +```python +from sklearn.metrics import ( + mean_squared_error, + mean_absolute_error, + r2_score +) + +mse = mean_squared_error(y_true, y_pred) +rmse = mean_squared_error(y_true, y_pred, squared=False) +mae = mean_absolute_error(y_true, y_pred) +r2 = r2_score(y_true, y_pred) + +print(f"RMSE: {rmse:.3f}") +print(f"MAE: {mae:.3f}") +print(f"R²: {r2:.3f}") +``` + +## Feature Engineering + +### Polynomial Features + +```python +from sklearn.preprocessing import PolynomialFeatures + +poly = PolynomialFeatures(degree=2, include_bias=False) +X_poly = poly.fit_transform(X) +# [x1, x2] → [x1, x2, x1², x1·x2, x2²] +``` + +### Feature Selection + +```python +from sklearn.feature_selection import ( + SelectKBest, f_classif, + RFE, + SelectFromModel +) + +# Univariate selection +selector = SelectKBest(f_classif, k=10) +X_selected = selector.fit_transform(X, y) + +# Recursive feature elimination +from sklearn.ensemble import RandomForestClassifier +rfe = RFE(RandomForestClassifier(), n_features_to_select=10) +X_selected = rfe.fit_transform(X, y) + +# Model-based selection +selector = SelectFromModel( + RandomForestClassifier(n_estimators=100), + threshold='median' +) +X_selected = selector.fit_transform(X, y) +``` + +### Feature Importance + +```python +# Tree-based models +model = RandomForestClassifier() +model.fit(X_train, y_train) +importances = model.feature_importances_ + +# Visualize +import matplotlib.pyplot as plt +indices = np.argsort(importances)[::-1] +plt.bar(range(X.shape[1]), importances[indices]) +plt.xticks(range(X.shape[1]), feature_names[indices], rotation=90) +plt.show() + +# Permutation importance (works for any model) +from sklearn.inspection import permutation_importance +result = permutation_importance(model, X_test, y_test, n_repeats=10) +importances = result.importances_mean +``` + +## Clustering + +### K-Means + +```python +from sklearn.cluster import KMeans +from sklearn.preprocessing import StandardScaler + +# Always scale for k-means +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# Fit k-means +kmeans = KMeans(n_clusters=3, random_state=42) +labels = kmeans.fit_predict(X_scaled) + +# Evaluate +from sklearn.metrics import silhouette_score +score = silhouette_score(X_scaled, labels) +print(f"Silhouette score: {score:.3f}") +``` + +### Elbow Method + +```python +inertias = [] +K_range = range(2, 11) + +for k in K_range: + kmeans = KMeans(n_clusters=k, random_state=42) + kmeans.fit(X_scaled) + inertias.append(kmeans.inertia_) + +plt.plot(K_range, inertias, 'bo-') +plt.xlabel('k') +plt.ylabel('Inertia') +plt.show() +``` + +### DBSCAN + +```python +from sklearn.cluster import DBSCAN + +dbscan = DBSCAN(eps=0.5, min_samples=5) +labels = dbscan.fit_predict(X_scaled) + +# -1 indicates noise/outliers +n_clusters = len(set(labels)) - (1 if -1 in labels else 0) +n_noise = list(labels).count(-1) +print(f"Clusters: {n_clusters}, Noise points: {n_noise}") +``` + +## Dimensionality Reduction + +### PCA + +```python +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +# Always scale before PCA +scaler = StandardScaler() +X_scaled = scaler.fit_transform(X) + +# Specify n_components +pca = PCA(n_components=2) +X_pca = pca.fit_transform(X_scaled) + +# Or specify variance to retain +pca = PCA(n_components=0.95) # Keep 95% variance +X_pca = pca.fit_transform(X_scaled) + +print(f"Explained variance: {pca.explained_variance_ratio_}") +print(f"Components needed: 
{pca.n_components_}") +``` + +### t-SNE (Visualization Only) + +```python +from sklearn.manifold import TSNE + +# Reduce to 50 dimensions with PCA first (recommended) +pca = PCA(n_components=50) +X_pca = pca.fit_transform(X_scaled) + +# Apply t-SNE +tsne = TSNE(n_components=2, random_state=42, perplexity=30) +X_tsne = tsne.fit_transform(X_pca) + +# Visualize +plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='viridis') +plt.colorbar() +plt.show() +``` + +## Saving and Loading Models + +```python +import joblib + +# Save model +joblib.dump(model, 'model.pkl') + +# Save pipeline +joblib.dump(pipeline, 'pipeline.pkl') + +# Load +model = joblib.load('model.pkl') +pipeline = joblib.load('pipeline.pkl') + +# Use loaded model +y_pred = model.predict(X_new) +``` + +## Common Pitfalls and Solutions + +### Data Leakage +❌ **Wrong**: Fit on all data before split +```python +scaler = StandardScaler().fit(X) +X_train, X_test = train_test_split(scaler.transform(X)) +``` + +✅ **Correct**: Use pipeline or fit only on train +```python +X_train, X_test = train_test_split(X) +pipeline = Pipeline([('scaler', StandardScaler()), ('model', model)]) +pipeline.fit(X_train, y_train) +``` + +### Not Scaling +❌ **Wrong**: Using SVM without scaling +```python +svm = SVC() +svm.fit(X_train, y_train) +``` + +✅ **Correct**: Scale for SVM +```python +pipeline = Pipeline([('scaler', StandardScaler()), ('svm', SVC())]) +pipeline.fit(X_train, y_train) +``` + +### Wrong Metric for Imbalanced Data +❌ **Wrong**: Using accuracy for 99:1 imbalance +```python +accuracy = accuracy_score(y_true, y_pred) # Can be misleading +``` + +✅ **Correct**: Use appropriate metrics +```python +f1 = f1_score(y_true, y_pred, average='weighted') +balanced_acc = balanced_accuracy_score(y_true, y_pred) +``` + +### Not Using Stratification +❌ **Wrong**: Random split for imbalanced data +```python +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) +``` + +✅ **Correct**: Stratify for imbalanced classes +```python +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, stratify=y +) +``` + +## Performance Tips + +1. **Use n_jobs=-1** for parallel processing (RandomForest, GridSearchCV) +2. **Use HistGradientBoosting** for large datasets (>10K samples) +3. **Use MiniBatchKMeans** for large clustering tasks +4. **Use IncrementalPCA** for data that doesn't fit in memory +5. **Use sparse matrices** for high-dimensional sparse data (text) +6. **Cache transformers** in pipelines during grid search +7. **Use RandomizedSearchCV** instead of GridSearchCV for large parameter spaces +8. **Reduce dimensionality** with PCA before applying expensive algorithms diff --git a/scientific-packages/scikit-learn/references/supervised_learning.md b/scientific-packages/scikit-learn/references/supervised_learning.md new file mode 100644 index 0000000..a424313 --- /dev/null +++ b/scientific-packages/scikit-learn/references/supervised_learning.md @@ -0,0 +1,261 @@ +# Supervised Learning in scikit-learn + +## Overview +Supervised learning algorithms learn patterns from labeled training data to make predictions on new data. Scikit-learn organizes supervised learning into 17 major categories. 
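+
+The predictors below all share scikit-learn's `fit`/`predict`/`score` interface; as a minimal orientation sketch (the synthetic data and the choice of `LogisticRegression` are illustrative only):
+
+```python
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import train_test_split
+
+# Synthetic data purely for illustration
+X, y = make_classification(n_samples=500, n_features=10, random_state=42)
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
+
+model = LogisticRegression(max_iter=1000)  # any predictor below follows this pattern
+model.fit(X_train, y_train)                # learn from labeled training data
+print(model.score(X_test, y_test))         # mean accuracy on held-out data
+```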
+ +## Linear Models + +### Regression +- **LinearRegression**: Ordinary least squares regression +- **Ridge**: L2-regularized regression, good for multicollinearity +- **Lasso**: L1-regularized regression, performs feature selection +- **ElasticNet**: Combined L1/L2 regularization +- **LassoLars**: Lasso using Least Angle Regression algorithm +- **BayesianRidge**: Bayesian approach with automatic relevance determination + +### Classification +- **LogisticRegression**: Binary and multiclass classification +- **RidgeClassifier**: Ridge regression for classification +- **SGDClassifier**: Linear classifiers with SGD training + +**Use cases**: Baseline models, interpretable predictions, high-dimensional data, when linear relationships are expected + +**Key parameters**: +- `alpha`: Regularization strength (higher = more regularization) +- `fit_intercept`: Whether to calculate intercept +- `solver`: Optimization algorithm ('lbfgs', 'saga', 'liblinear') + +## Support Vector Machines (SVM) + +- **SVC**: Support Vector Classification +- **SVR**: Support Vector Regression +- **LinearSVC**: Linear SVM using liblinear (faster for large datasets) +- **OneClassSVM**: Unsupervised outlier detection + +**Use cases**: Complex non-linear decision boundaries, high-dimensional spaces, when clear margin of separation exists + +**Key parameters**: +- `kernel`: 'linear', 'poly', 'rbf', 'sigmoid' +- `C`: Regularization parameter (lower = more regularization) +- `gamma`: Kernel coefficient ('scale', 'auto', or float) +- `degree`: Polynomial degree (for poly kernel) + +**Performance tip**: SVMs don't scale well beyond tens of thousands of samples. Use LinearSVC for large datasets with linear kernel. + +## Decision Trees + +- **DecisionTreeClassifier**: Classification tree +- **DecisionTreeRegressor**: Regression tree +- **ExtraTreeClassifier/Regressor**: Extremely randomized tree + +**Use cases**: Non-linear relationships, feature importance analysis, interpretable rules, handling mixed data types + +**Key parameters**: +- `max_depth`: Maximum tree depth (controls overfitting) +- `min_samples_split`: Minimum samples to split a node +- `min_samples_leaf`: Minimum samples in leaf node +- `max_features`: Number of features to consider for splits +- `criterion`: 'gini', 'entropy' (classification); 'squared_error', 'absolute_error' (regression) + +**Overfitting prevention**: Limit `max_depth`, increase `min_samples_split/leaf`, use pruning with `ccp_alpha` + +## Ensemble Methods + +### Random Forests +- **RandomForestClassifier**: Ensemble of decision trees +- **RandomForestRegressor**: Regression variant + +**Use cases**: Robust general-purpose algorithm, reduces overfitting vs single trees, handles non-linear relationships + +**Key parameters**: +- `n_estimators`: Number of trees (higher = better but slower) +- `max_depth`: Maximum tree depth +- `max_features`: Features per split ('sqrt', 'log2', int, float) +- `bootstrap`: Whether to use bootstrap samples +- `n_jobs`: Parallel processing (-1 uses all cores) + +### Gradient Boosting +- **HistGradientBoostingClassifier/Regressor**: Histogram-based, fast for large datasets (>10k samples) +- **GradientBoostingClassifier/Regressor**: Traditional implementation, better for small datasets + +**Use cases**: High-performance predictions, winning Kaggle competitions, structured/tabular data + +**Key parameters**: +- `n_estimators`: Number of boosting stages +- `learning_rate`: Shrinks contribution of each tree +- `max_depth`: Maximum tree depth (typically 3-8) +- `subsample`: 
Fraction of samples per tree (enables stochastic gradient boosting) +- `early_stopping`: Stop when validation score stops improving + +**Performance tip**: HistGradientBoosting is orders of magnitude faster for large datasets + +### AdaBoost +- **AdaBoostClassifier/Regressor**: Adaptive boosting + +**Use cases**: Boosting weak learners, less prone to overfitting than other methods + +**Key parameters**: +- `estimator`: Base estimator (default: DecisionTreeClassifier with max_depth=1) +- `n_estimators`: Number of boosting iterations +- `learning_rate`: Weight applied to each classifier + +### Bagging +- **BaggingClassifier/Regressor**: Bootstrap aggregating with any base estimator + +**Use cases**: Reducing variance of unstable models, parallel ensemble creation + +**Key parameters**: +- `estimator`: Base estimator to fit +- `n_estimators`: Number of estimators +- `max_samples`: Samples to draw per estimator +- `bootstrap`: Whether to use replacement + +### Voting & Stacking +- **VotingClassifier/Regressor**: Combines different model types +- **StackingClassifier/Regressor**: Meta-learner trained on base predictions + +**Use cases**: Combining diverse models, leveraging different model strengths + +## Neural Networks + +- **MLPClassifier**: Multi-layer perceptron classifier +- **MLPRegressor**: Multi-layer perceptron regressor + +**Use cases**: Complex non-linear patterns, when gradient boosting is too slow, deep feature learning + +**Key parameters**: +- `hidden_layer_sizes`: Tuple of hidden layer sizes (e.g., (100, 50)) +- `activation`: 'relu', 'tanh', 'logistic' +- `solver`: 'adam', 'lbfgs', 'sgd' +- `alpha`: L2 regularization term +- `learning_rate`: Learning rate schedule +- `early_stopping`: Stop when validation score stops improving + +**Important**: Feature scaling is critical for neural networks. Always use StandardScaler or similar. 
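+
+A minimal sketch tying the scaling note to the parameters listed above (hyperparameter values are illustrative; `X_train`, `y_train`, `X_test`, `y_test` are assumed from earlier examples):
+
+```python
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.neural_network import MLPClassifier
+
+# Scaling inside the pipeline keeps the MLP well-conditioned and avoids leakage
+mlp = make_pipeline(
+    StandardScaler(),
+    MLPClassifier(hidden_layer_sizes=(100, 50), activation='relu',
+                  solver='adam', alpha=1e-4, early_stopping=True,
+                  random_state=42)
+)
+mlp.fit(X_train, y_train)
+print(f"Validation accuracy: {mlp.score(X_test, y_test):.3f}")
+```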
+ +## Nearest Neighbors + +- **KNeighborsClassifier/Regressor**: K-nearest neighbors +- **RadiusNeighborsClassifier/Regressor**: Radius-based neighbors +- **NearestCentroid**: Classification using class centroids + +**Use cases**: Simple baseline, irregular decision boundaries, when interpretability isn't critical + +**Key parameters**: +- `n_neighbors`: Number of neighbors (typically 3-11) +- `weights`: 'uniform' or 'distance' (distance-weighted voting) +- `metric`: Distance metric ('euclidean', 'manhattan', 'minkowski') +- `algorithm`: 'auto', 'ball_tree', 'kd_tree', 'brute' + +## Naive Bayes + +- **GaussianNB**: Assumes Gaussian distribution of features +- **MultinomialNB**: For discrete counts (text classification) +- **BernoulliNB**: For binary/boolean features +- **CategoricalNB**: For categorical features +- **ComplementNB**: Adapted for imbalanced datasets + +**Use cases**: Text classification, fast baseline, when features are independent, small training sets + +**Key parameters**: +- `alpha`: Smoothing parameter (Laplace/Lidstone smoothing) +- `fit_prior`: Whether to learn class prior probabilities + +## Linear/Quadratic Discriminant Analysis + +- **LinearDiscriminantAnalysis**: Linear decision boundary with dimensionality reduction +- **QuadraticDiscriminantAnalysis**: Quadratic decision boundary + +**Use cases**: When classes have Gaussian distributions, dimensionality reduction, when covariance assumptions hold + +## Gaussian Processes + +- **GaussianProcessClassifier**: Probabilistic classification +- **GaussianProcessRegressor**: Probabilistic regression with uncertainty estimates + +**Use cases**: When uncertainty quantification is important, small datasets, smooth function approximation + +**Key parameters**: +- `kernel`: Covariance function (RBF, Matern, RationalQuadratic, etc.) +- `alpha`: Noise level + +**Limitation**: Doesn't scale well to large datasets (O(n³) complexity) + +## Stochastic Gradient Descent + +- **SGDClassifier**: Linear classifiers with SGD +- **SGDRegressor**: Linear regressors with SGD + +**Use cases**: Very large datasets (>100k samples), online learning, when data doesn't fit in memory + +**Key parameters**: +- `loss`: Loss function ('hinge', 'log_loss', 'squared_error', etc.) 
+- `penalty`: Regularization ('l2', 'l1', 'elasticnet') +- `alpha`: Regularization strength +- `learning_rate`: Learning rate schedule + +## Semi-Supervised Learning + +- **SelfTrainingClassifier**: Self-training with any base classifier +- **LabelPropagation**: Label propagation through graph +- **LabelSpreading**: Label spreading (modified label propagation) + +**Use cases**: When labeled data is scarce but unlabeled data is abundant + +## Feature Selection + +- **VarianceThreshold**: Remove low-variance features +- **SelectKBest**: Select K highest scoring features +- **SelectPercentile**: Select top percentile of features +- **RFE**: Recursive feature elimination +- **RFECV**: RFE with cross-validation +- **SelectFromModel**: Select features based on importance +- **SequentialFeatureSelector**: Forward/backward feature selection + +**Use cases**: Reducing dimensionality, removing irrelevant features, improving interpretability, reducing overfitting + +## Probability Calibration + +- **CalibratedClassifierCV**: Calibrate classifier probabilities + +**Use cases**: When probability estimates are important (not just class predictions), especially with SVM and Naive Bayes + +**Methods**: +- `sigmoid`: Platt scaling +- `isotonic`: Isotonic regression (more flexible, needs more data) + +## Multi-Output Methods + +- **MultiOutputClassifier**: Fit one classifier per target +- **MultiOutputRegressor**: Fit one regressor per target +- **ClassifierChain**: Models dependencies between targets +- **RegressorChain**: Regression variant + +**Use cases**: Predicting multiple related targets simultaneously + +## Specialized Regression + +- **IsotonicRegression**: Monotonic regression +- **QuantileRegressor**: Quantile regression for prediction intervals + +## Algorithm Selection Guidelines + +**Start with**: +1. **Logistic Regression** (classification) or **LinearRegression/Ridge** (regression) as baseline +2. **RandomForestClassifier/Regressor** for general non-linear problems +3. **HistGradientBoostingClassifier/Regressor** when best performance is needed + +**Consider dataset size**: +- Small (<1k samples): SVM, Gaussian Processes, any algorithm +- Medium (1k-100k): Random Forests, Gradient Boosting, Neural Networks +- Large (>100k): SGD, HistGradientBoosting, LinearSVC + +**Consider interpretability needs**: +- High interpretability: Linear models, Decision Trees, Naive Bayes +- Medium: Random Forests (feature importance), Rule extraction +- Low (black box acceptable): Gradient Boosting, Neural Networks, SVM with RBF kernel + +**Consider training time**: +- Fast: Linear models, Naive Bayes, Decision Trees +- Medium: Random Forests (parallelizable), SVM (small data) +- Slow: Gradient Boosting, Neural Networks, SVM (large data), Gaussian Processes diff --git a/scientific-packages/scikit-learn/references/unsupervised_learning.md b/scientific-packages/scikit-learn/references/unsupervised_learning.md new file mode 100644 index 0000000..b379c48 --- /dev/null +++ b/scientific-packages/scikit-learn/references/unsupervised_learning.md @@ -0,0 +1,728 @@ +# Unsupervised Learning in scikit-learn + +## Overview +Unsupervised learning discovers patterns in data without labeled targets. Main tasks include clustering (grouping similar samples), dimensionality reduction (reducing feature count), and anomaly detection (finding outliers). + +## Clustering Algorithms + +### K-Means + +Groups data into k clusters by minimizing within-cluster variance. + +**Algorithm**: +1. 
Initialize k centroids (k-means++ initialization recommended) +2. Assign each point to nearest centroid +3. Update centroids to mean of assigned points +4. Repeat until convergence + +```python +from sklearn.cluster import KMeans + +kmeans = KMeans( + n_clusters=3, + init='k-means++', # Smart initialization + n_init=10, # Number of times to run with different seeds + max_iter=300, + random_state=42 +) +labels = kmeans.fit_predict(X) +centroids = kmeans.cluster_centers_ +``` + +**Use cases**: +- Customer segmentation +- Image compression +- Data preprocessing (clustering as features) + +**Strengths**: +- Fast and scalable +- Simple to understand +- Works well with spherical clusters + +**Limitations**: +- Assumes spherical clusters of similar size +- Sensitive to initialization (mitigated by k-means++) +- Must specify k beforehand +- Sensitive to outliers + +**Choosing k**: Use elbow method, silhouette score, or domain knowledge + +**Variants**: +- **MiniBatchKMeans**: Faster for large datasets, uses mini-batches +- **KMeans with n_init='auto'**: Adaptive number of initializations + +### DBSCAN + +Density-Based Spatial Clustering of Applications with Noise. Identifies clusters as dense regions separated by sparse areas. + +```python +from sklearn.cluster import DBSCAN + +dbscan = DBSCAN( + eps=0.5, # Maximum distance between neighbors + min_samples=5, # Minimum points to form dense region + metric='euclidean' +) +labels = dbscan.fit_predict(X) +# -1 indicates noise/outliers +``` + +**Use cases**: +- Arbitrary cluster shapes +- Outlier detection +- When cluster count is unknown +- Geographic/spatial data + +**Strengths**: +- Discovers arbitrary-shaped clusters +- Automatically detects outliers +- Doesn't require specifying number of clusters +- Robust to outliers + +**Limitations**: +- Struggles with varying densities +- Sensitive to eps and min_samples parameters +- Not deterministic (border points may vary) + +**Parameter tuning**: +- `eps`: Plot k-distance graph, look for elbow +- `min_samples`: Rule of thumb: 2 * dimensions + +### HDBSCAN + +Hierarchical DBSCAN that handles variable cluster densities. + +```python +from sklearn.cluster import HDBSCAN + +hdbscan = HDBSCAN( + min_cluster_size=5, + min_samples=None, # Defaults to min_cluster_size + metric='euclidean' +) +labels = hdbscan.fit_predict(X) +``` + +**Advantages over DBSCAN**: +- Handles variable density clusters +- More robust parameter selection +- Provides cluster membership probabilities +- Hierarchical structure + +**Use cases**: When DBSCAN struggles with varying densities + +### Hierarchical Clustering + +Builds nested cluster hierarchies using agglomerative (bottom-up) approach. 
+ +```python +from sklearn.cluster import AgglomerativeClustering + +agg_clust = AgglomerativeClustering( + n_clusters=3, + linkage='ward', # 'ward', 'complete', 'average', 'single' + metric='euclidean' +) +labels = agg_clust.fit_predict(X) + +# Visualize with dendrogram +from scipy.cluster.hierarchy import dendrogram, linkage as scipy_linkage +import matplotlib.pyplot as plt + +linkage_matrix = scipy_linkage(X, method='ward') +dendrogram(linkage_matrix) +plt.show() +``` + +**Linkage methods**: +- `ward`: Minimizes variance (only with Euclidean) - **most common** +- `complete`: Maximum distance between clusters +- `average`: Average distance between clusters +- `single`: Minimum distance between clusters + +**Use cases**: +- When hierarchical structure is meaningful +- Taxonomy/phylogenetic trees +- When visualization is important (dendrograms) + +**Strengths**: +- No need to specify k initially (cut dendrogram at desired level) +- Produces hierarchy of clusters +- Deterministic + +**Limitations**: +- Computationally expensive (O(n²) to O(n³)) +- Not suitable for large datasets +- Cannot undo previous merges + +### Spectral Clustering + +Performs dimensionality reduction using affinity matrix before clustering. + +```python +from sklearn.cluster import SpectralClustering + +spectral = SpectralClustering( + n_clusters=3, + affinity='rbf', # 'rbf', 'nearest_neighbors', 'precomputed' + gamma=1.0, + n_neighbors=10, + random_state=42 +) +labels = spectral.fit_predict(X) +``` + +**Use cases**: +- Non-convex clusters +- Image segmentation +- Graph clustering +- When similarity matrix is available + +**Strengths**: +- Handles non-convex clusters +- Works with similarity matrices +- Often better than k-means for complex shapes + +**Limitations**: +- Computationally expensive +- Requires specifying number of clusters +- Memory intensive + +### Mean Shift + +Discovers clusters through iterative centroid updates based on density. + +```python +from sklearn.cluster import MeanShift, estimate_bandwidth + +# Estimate bandwidth +bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500) + +mean_shift = MeanShift(bandwidth=bandwidth) +labels = mean_shift.fit_predict(X) +cluster_centers = mean_shift.cluster_centers_ +``` + +**Use cases**: +- When cluster count is unknown +- Computer vision applications +- Object tracking + +**Strengths**: +- Automatically determines number of clusters +- Handles arbitrary shapes +- No assumptions about cluster shape + +**Limitations**: +- Computationally expensive +- Very sensitive to bandwidth parameter +- Doesn't scale well + +### Affinity Propagation + +Uses message-passing between samples to identify exemplars. + +```python +from sklearn.cluster import AffinityPropagation + +affinity_prop = AffinityPropagation( + damping=0.5, # Damping factor (0.5-1.0) + preference=None, # Self-preference (controls number of clusters) + random_state=42 +) +labels = affinity_prop.fit_predict(X) +exemplars = affinity_prop.cluster_centers_indices_ +``` + +**Use cases**: +- When number of clusters is unknown +- When exemplars (representative samples) are needed + +**Strengths**: +- Automatically determines number of clusters +- Identifies exemplar samples +- No initialization required + +**Limitations**: +- Very slow: O(n²t) where t is iterations +- Not suitable for large datasets +- Memory intensive + +### Gaussian Mixture Models (GMM) + +Probabilistic model assuming data comes from mixture of Gaussian distributions. 
+ +```python +from sklearn.mixture import GaussianMixture + +gmm = GaussianMixture( + n_components=3, + covariance_type='full', # 'full', 'tied', 'diag', 'spherical' + random_state=42 +) +labels = gmm.fit_predict(X) +probabilities = gmm.predict_proba(X) # Soft clustering +``` + +**Covariance types**: +- `full`: Each component has its own covariance matrix +- `tied`: All components share same covariance +- `diag`: Diagonal covariance (independent features) +- `spherical`: Spherical covariance (isotropic) + +**Use cases**: +- When soft clustering is needed (probabilities) +- When clusters have different shapes/sizes +- Generative modeling +- Density estimation + +**Strengths**: +- Provides probabilities (soft clustering) +- Can handle elliptical clusters +- Generative model (can sample new data) +- Model selection with BIC/AIC + +**Limitations**: +- Assumes Gaussian distributions +- Sensitive to initialization +- Can converge to local optima + +**Model selection**: +```python +from sklearn.mixture import GaussianMixture +import numpy as np + +n_components_range = range(2, 10) +bic_scores = [] + +for n in n_components_range: + gmm = GaussianMixture(n_components=n, random_state=42) + gmm.fit(X) + bic_scores.append(gmm.bic(X)) + +optimal_n = n_components_range[np.argmin(bic_scores)] +``` + +### BIRCH + +Builds Clustering Feature Tree for memory-efficient processing of large datasets. + +```python +from sklearn.cluster import Birch + +birch = Birch( + n_clusters=3, + threshold=0.5, + branching_factor=50 +) +labels = birch.fit_predict(X) +``` + +**Use cases**: +- Very large datasets +- Streaming data +- Memory constraints + +**Strengths**: +- Memory efficient +- Single pass over data +- Incremental learning + +## Dimensionality Reduction + +### Principal Component Analysis (PCA) + +Finds orthogonal components that explain maximum variance. + +```python +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt + +# Specify number of components +pca = PCA(n_components=2, random_state=42) +X_transformed = pca.fit_transform(X) + +print("Explained variance ratio:", pca.explained_variance_ratio_) +print("Total variance explained:", pca.explained_variance_ratio_.sum()) + +# Or specify variance to retain +pca = PCA(n_components=0.95) # Keep 95% of variance +X_transformed = pca.fit_transform(X) +print(f"Components needed: {pca.n_components_}") + +# Visualize explained variance +plt.plot(np.cumsum(pca.explained_variance_ratio_)) +plt.xlabel('Number of components') +plt.ylabel('Cumulative explained variance') +plt.show() +``` + +**Use cases**: +- Visualization (reduce to 2-3 dimensions) +- Remove multicollinearity +- Noise reduction +- Speed up training +- Feature extraction + +**Strengths**: +- Fast and efficient +- Reduces multicollinearity +- Works well for linear relationships +- Interpretable components + +**Limitations**: +- Only linear transformations +- Sensitive to scaling (always standardize first!) +- Components may be hard to interpret + +**Variants**: +- **IncrementalPCA**: For datasets that don't fit in memory +- **KernelPCA**: Non-linear dimensionality reduction +- **SparsePCA**: Sparse loadings for interpretability + +### t-SNE + +t-Distributed Stochastic Neighbor Embedding for visualization. 
+ +```python +from sklearn.manifold import TSNE + +tsne = TSNE( + n_components=2, + perplexity=30, # Balance local vs global structure (5-50) + learning_rate='auto', + n_iter=1000, + random_state=42 +) +X_embedded = tsne.fit_transform(X) + +# Visualize +plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y) +plt.show() +``` + +**Use cases**: +- Visualization only (do not use for preprocessing!) +- Exploring high-dimensional data +- Finding clusters visually + +**Important notes**: +- **Only for visualization**, not for preprocessing +- Each run produces different results (use random_state for reproducibility) +- Slow for large datasets +- Cannot transform new data (no transform() method) + +**Parameter tuning**: +- `perplexity`: 5-50, larger for larger datasets +- Lower perplexity = focus on local structure +- Higher perplexity = focus on global structure + +### UMAP + +Uniform Manifold Approximation and Projection (requires umap-learn package). + +**Advantages over t-SNE**: +- Preserves global structure better +- Faster +- Can transform new data +- Can be used for preprocessing (not just visualization) + +### Truncated SVD (LSA) + +Similar to PCA but works with sparse matrices (e.g., TF-IDF). + +```python +from sklearn.decomposition import TruncatedSVD + +svd = TruncatedSVD(n_components=100, random_state=42) +X_reduced = svd.fit_transform(X_sparse) +``` + +**Use cases**: +- Text data (after TF-IDF) +- Sparse matrices +- Latent Semantic Analysis (LSA) + +### Non-negative Matrix Factorization (NMF) + +Factorizes data into non-negative components. + +```python +from sklearn.decomposition import NMF + +nmf = NMF(n_components=10, init='nndsvd', random_state=42) +W = nmf.fit_transform(X) # Document-topic matrix +H = nmf.components_ # Topic-word matrix +``` + +**Use cases**: +- Topic modeling +- Audio source separation +- Image processing +- When non-negativity is important (e.g., counts) + +**Strengths**: +- Interpretable components (additive, non-negative) +- Sparse representations + +### Independent Component Analysis (ICA) + +Separates multivariate signal into independent components. + +```python +from sklearn.decomposition import FastICA + +ica = FastICA(n_components=10, random_state=42) +X_independent = ica.fit_transform(X) +``` + +**Use cases**: +- Blind source separation +- Signal processing +- Feature extraction when independence is expected + +### Factor Analysis + +Models observed variables as linear combinations of latent factors plus noise. + +```python +from sklearn.decomposition import FactorAnalysis + +fa = FactorAnalysis(n_components=5, random_state=42) +X_factors = fa.fit_transform(X) +``` + +**Use cases**: +- When noise is heteroscedastic +- Latent variable modeling +- Psychology/social science research + +**Difference from PCA**: Models noise explicitly, assumes features have independent noise + +## Anomaly Detection + +### One-Class SVM + +Learns boundary around normal data. + +```python +from sklearn.svm import OneClassSVM + +oc_svm = OneClassSVM( + nu=0.1, # Proportion of outliers expected + kernel='rbf', + gamma='auto' +) +oc_svm.fit(X_train) +predictions = oc_svm.predict(X_test) # 1 for inliers, -1 for outliers +``` + +**Use cases**: +- Novelty detection +- When only normal data is available for training + +### Isolation Forest + +Isolates outliers using random forests. 
+ +```python +from sklearn.ensemble import IsolationForest + +iso_forest = IsolationForest( + contamination=0.1, # Expected proportion of outliers + random_state=42 +) +predictions = iso_forest.fit_predict(X) # 1 for inliers, -1 for outliers +scores = iso_forest.score_samples(X) # Anomaly scores +``` + +**Use cases**: +- General anomaly detection +- Works well with high-dimensional data +- Fast and scalable + +**Strengths**: +- Fast +- Effective in high dimensions +- Low memory requirements + +### Local Outlier Factor (LOF) + +Detects outliers based on local density deviation. + +```python +from sklearn.neighbors import LocalOutlierFactor + +lof = LocalOutlierFactor( + n_neighbors=20, + contamination=0.1 +) +predictions = lof.fit_predict(X) # 1 for inliers, -1 for outliers +scores = lof.negative_outlier_factor_ # Anomaly scores (negative) +``` + +**Use cases**: +- Finding local outliers +- When global methods fail + +## Clustering Evaluation + +### With Ground Truth Labels + +When true labels are available (for validation): + +**Adjusted Rand Index (ARI)**: +```python +from sklearn.metrics import adjusted_rand_score +ari = adjusted_rand_score(y_true, y_pred) +# Range: [-1, 1], 1 = perfect, 0 = random +``` + +**Normalized Mutual Information (NMI)**: +```python +from sklearn.metrics import normalized_mutual_info_score +nmi = normalized_mutual_info_score(y_true, y_pred) +# Range: [0, 1], 1 = perfect +``` + +**V-Measure**: +```python +from sklearn.metrics import v_measure_score +v = v_measure_score(y_true, y_pred) +# Range: [0, 1], harmonic mean of homogeneity and completeness +``` + +### Without Ground Truth Labels + +When true labels are unavailable (unsupervised evaluation): + +**Silhouette Score**: +Measures how similar objects are to their own cluster vs other clusters. 
+ +```python +from sklearn.metrics import silhouette_score, silhouette_samples +import matplotlib.pyplot as plt + +score = silhouette_score(X, labels) +# Range: [-1, 1], higher is better +# >0.7: Strong structure +# 0.5-0.7: Reasonable structure +# 0.25-0.5: Weak structure +# <0.25: No substantial structure + +# Per-sample scores for detailed analysis +sample_scores = silhouette_samples(X, labels) + +# Visualize silhouette plot +for i in range(n_clusters): + cluster_scores = sample_scores[labels == i] + cluster_scores.sort() + plt.barh(range(len(cluster_scores)), cluster_scores) +plt.axvline(x=score, color='red', linestyle='--') +plt.show() +``` + +**Davies-Bouldin Index**: +```python +from sklearn.metrics import davies_bouldin_score +db = davies_bouldin_score(X, labels) +# Lower is better, 0 = perfect +``` + +**Calinski-Harabasz Index** (Variance Ratio Criterion): +```python +from sklearn.metrics import calinski_harabasz_score +ch = calinski_harabasz_score(X, labels) +# Higher is better +``` + +**Inertia** (K-Means specific): +```python +inertia = kmeans.inertia_ +# Sum of squared distances to nearest cluster center +# Use for elbow method +``` + +### Elbow Method (K-Means) + +```python +from sklearn.cluster import KMeans +import matplotlib.pyplot as plt + +inertias = [] +K_range = range(2, 11) + +for k in K_range: + kmeans = KMeans(n_clusters=k, random_state=42) + kmeans.fit(X) + inertias.append(kmeans.inertia_) + +plt.plot(K_range, inertias, 'bo-') +plt.xlabel('Number of clusters (k)') +plt.ylabel('Inertia') +plt.title('Elbow Method') +plt.show() +# Look for "elbow" where inertia starts decreasing more slowly +``` + +## Best Practices + +### Clustering Algorithm Selection + +**Use K-Means when**: +- Clusters are spherical and similar size +- Speed is important +- Data is not too high-dimensional + +**Use DBSCAN when**: +- Arbitrary cluster shapes +- Number of clusters unknown +- Outlier detection needed + +**Use Hierarchical when**: +- Hierarchy is meaningful +- Small to medium datasets +- Visualization is important + +**Use GMM when**: +- Soft clustering needed +- Clusters have different shapes/sizes +- Probabilistic interpretation needed + +**Use Spectral Clustering when**: +- Non-convex clusters +- Have similarity matrix +- Moderate dataset size + +### Preprocessing for Clustering + +1. **Always scale features**: Use StandardScaler or MinMaxScaler +2. **Handle outliers**: Remove or use robust algorithms (DBSCAN, HDBSCAN) +3. **Reduce dimensionality if needed**: PCA for speed, careful with interpretation +4. **Check for categorical variables**: Encode appropriately or use specialized algorithms + +### Dimensionality Reduction Guidelines + +**For preprocessing/feature extraction**: +- PCA (linear relationships) +- TruncatedSVD (sparse data) +- NMF (non-negative data) + +**For visualization only**: +- t-SNE (preserves local structure) +- UMAP (preserves both local and global structure) + +**Always**: +- Standardize features before PCA +- Use appropriate n_components (elbow plot, explained variance) +- Don't use t-SNE for anything except visualization + +### Common Pitfalls + +1. **Not scaling data**: Most algorithms sensitive to scale +2. **Using t-SNE for preprocessing**: Only for visualization! +3. **Overfitting cluster count**: Too many clusters = overfitting noise +4. **Ignoring outliers**: Can severely affect centroid-based methods +5. **Wrong metric**: Euclidean assumes all features equally important +6. 
**Not validating results**: Always check with multiple metrics and domain knowledge +7. **PCA without standardization**: Components dominated by high-variance features diff --git a/scientific-packages/scikit-learn/scripts/classification_pipeline.py b/scientific-packages/scikit-learn/scripts/classification_pipeline.py new file mode 100644 index 0000000..749fd6d --- /dev/null +++ b/scientific-packages/scikit-learn/scripts/classification_pipeline.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +""" +Complete classification pipeline with preprocessing, training, evaluation, and hyperparameter tuning. +Demonstrates best practices for scikit-learn workflows. +""" + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV +from sklearn.preprocessing import StandardScaler, OneHotEncoder +from sklearn.impute import SimpleImputer +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score +import joblib + + +def create_preprocessing_pipeline(numeric_features, categorical_features): + """ + Create preprocessing pipeline for mixed data types. + + Args: + numeric_features: List of numeric column names + categorical_features: List of categorical column names + + Returns: + ColumnTransformer with appropriate preprocessing for each data type + """ + numeric_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='median')), + ('scaler', StandardScaler()) + ]) + + categorical_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=True)) + ]) + + preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, numeric_features), + ('cat', categorical_transformer, categorical_features) + ]) + + return preprocessor + + +def create_full_pipeline(preprocessor, classifier=None): + """ + Create complete ML pipeline with preprocessing and classification. + + Args: + preprocessor: Preprocessing ColumnTransformer + classifier: Classifier instance (default: RandomForestClassifier) + + Returns: + Complete Pipeline + """ + if classifier is None: + classifier = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1) + + pipeline = Pipeline(steps=[ + ('preprocessor', preprocessor), + ('classifier', classifier) + ]) + + return pipeline + + +def evaluate_model(pipeline, X_train, y_train, X_test, y_test, cv=5): + """ + Evaluate model using cross-validation and test set. 
+ + Args: + pipeline: Trained pipeline + X_train, y_train: Training data + X_test, y_test: Test data + cv: Number of cross-validation folds + + Returns: + Dictionary with evaluation results + """ + # Cross-validation on training set + cv_scores = cross_val_score(pipeline, X_train, y_train, cv=cv, scoring='accuracy') + + # Test set evaluation + y_pred = pipeline.predict(X_test) + test_score = pipeline.score(X_test, y_test) + + # Get probabilities if available + try: + y_proba = pipeline.predict_proba(X_test) + if len(np.unique(y_test)) == 2: + # Binary classification + auc = roc_auc_score(y_test, y_proba[:, 1]) + else: + # Multiclass + auc = roc_auc_score(y_test, y_proba, multi_class='ovr') + except: + auc = None + + results = { + 'cv_mean': cv_scores.mean(), + 'cv_std': cv_scores.std(), + 'test_score': test_score, + 'auc': auc, + 'classification_report': classification_report(y_test, y_pred), + 'confusion_matrix': confusion_matrix(y_test, y_pred) + } + + return results + + +def tune_hyperparameters(pipeline, X_train, y_train, param_grid, cv=5): + """ + Perform hyperparameter tuning using GridSearchCV. + + Args: + pipeline: Pipeline to tune + X_train, y_train: Training data + param_grid: Dictionary of parameters to search + cv: Number of cross-validation folds + + Returns: + GridSearchCV object with best model + """ + grid_search = GridSearchCV( + pipeline, + param_grid, + cv=cv, + scoring='f1_weighted', + n_jobs=-1, + verbose=1 + ) + + grid_search.fit(X_train, y_train) + + print(f"Best parameters: {grid_search.best_params_}") + print(f"Best CV score: {grid_search.best_score_:.3f}") + + return grid_search + + +def main(): + """ + Example usage of the classification pipeline. + """ + # Load your data here + # X, y = load_data() + + # Example with synthetic data + from sklearn.datasets import make_classification + X, y = make_classification( + n_samples=1000, + n_features=20, + n_informative=15, + n_redundant=5, + random_state=42 + ) + + # Convert to DataFrame for demonstration + feature_names = [f'feature_{i}' for i in range(X.shape[1])] + X = pd.DataFrame(X, columns=feature_names) + + # Split features into numeric and categorical (all numeric in this example) + numeric_features = feature_names + categorical_features = [] + + # Split data (use stratify for imbalanced classes) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42, stratify=y + ) + + # Create preprocessing pipeline + preprocessor = create_preprocessing_pipeline(numeric_features, categorical_features) + + # Create full pipeline + pipeline = create_full_pipeline(preprocessor) + + # Train model + print("Training model...") + pipeline.fit(X_train, y_train) + + # Evaluate model + print("\nEvaluating model...") + results = evaluate_model(pipeline, X_train, y_train, X_test, y_test) + + print(f"CV Accuracy: {results['cv_mean']:.3f} (+/- {results['cv_std']:.3f})") + print(f"Test Accuracy: {results['test_score']:.3f}") + if results['auc']: + print(f"ROC-AUC: {results['auc']:.3f}") + print("\nClassification Report:") + print(results['classification_report']) + + # Hyperparameter tuning (optional) + print("\nTuning hyperparameters...") + param_grid = { + 'classifier__n_estimators': [100, 200], + 'classifier__max_depth': [10, 20, None], + 'classifier__min_samples_split': [2, 5] + } + + grid_search = tune_hyperparameters(pipeline, X_train, y_train, param_grid) + + # Evaluate best model + print("\nEvaluating tuned model...") + best_pipeline = grid_search.best_estimator_ + y_pred = 
best_pipeline.predict(X_test) + print(classification_report(y_test, y_pred)) + + # Save model + print("\nSaving model...") + joblib.dump(best_pipeline, 'best_model.pkl') + print("Model saved as 'best_model.pkl'") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/scikit-learn/scripts/clustering_analysis.py b/scientific-packages/scikit-learn/scripts/clustering_analysis.py new file mode 100644 index 0000000..c8625f8 --- /dev/null +++ b/scientific-packages/scikit-learn/scripts/clustering_analysis.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +""" +Clustering analysis script with multiple algorithms and evaluation. +Demonstrates k-means, DBSCAN, and hierarchical clustering with visualization. +""" + +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler +from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering +from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import seaborn as sns + + +def scale_data(X): + """ + Scale features using StandardScaler. + ALWAYS scale data before clustering! + + Args: + X: Feature matrix + + Returns: + Scaled feature matrix and fitted scaler + """ + scaler = StandardScaler() + X_scaled = scaler.fit_transform(X) + return X_scaled, scaler + + +def find_optimal_k(X_scaled, k_range=range(2, 11)): + """ + Find optimal number of clusters using elbow method and silhouette score. + + Args: + X_scaled: Scaled feature matrix + k_range: Range of k values to try + + Returns: + Dictionary with inertias and silhouette scores + """ + inertias = [] + silhouette_scores = [] + + for k in k_range: + kmeans = KMeans(n_clusters=k, random_state=42, n_init=10) + labels = kmeans.fit_predict(X_scaled) + inertias.append(kmeans.inertia_) + silhouette_scores.append(silhouette_score(X_scaled, labels)) + + return { + 'k_values': list(k_range), + 'inertias': inertias, + 'silhouette_scores': silhouette_scores + } + + +def plot_elbow_silhouette(results): + """ + Plot elbow method and silhouette scores. + + Args: + results: Dictionary from find_optimal_k + """ + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + + # Elbow plot + ax1.plot(results['k_values'], results['inertias'], 'bo-') + ax1.set_xlabel('Number of clusters (k)') + ax1.set_ylabel('Inertia') + ax1.set_title('Elbow Method') + ax1.grid(True, alpha=0.3) + + # Silhouette plot + ax2.plot(results['k_values'], results['silhouette_scores'], 'ro-') + ax2.set_xlabel('Number of clusters (k)') + ax2.set_ylabel('Silhouette Score') + ax2.set_title('Silhouette Score vs k') + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig('elbow_silhouette.png', dpi=300, bbox_inches='tight') + print("Saved elbow and silhouette plots to 'elbow_silhouette.png'") + plt.close() + + +def evaluate_clustering(X_scaled, labels, algorithm_name): + """ + Evaluate clustering using multiple metrics. 
+ + Args: + X_scaled: Scaled feature matrix + labels: Cluster labels + algorithm_name: Name of clustering algorithm + + Returns: + Dictionary with evaluation metrics + """ + # Filter out noise points for DBSCAN (-1 labels) + mask = labels != -1 + X_filtered = X_scaled[mask] + labels_filtered = labels[mask] + + n_clusters = len(set(labels_filtered)) + n_noise = list(labels).count(-1) + + results = { + 'algorithm': algorithm_name, + 'n_clusters': n_clusters, + 'n_noise': n_noise + } + + # Calculate metrics if we have valid clusters + if n_clusters > 1: + results['silhouette'] = silhouette_score(X_filtered, labels_filtered) + results['davies_bouldin'] = davies_bouldin_score(X_filtered, labels_filtered) + results['calinski_harabasz'] = calinski_harabasz_score(X_filtered, labels_filtered) + else: + results['silhouette'] = None + results['davies_bouldin'] = None + results['calinski_harabasz'] = None + + return results + + +def perform_kmeans(X_scaled, n_clusters=3): + """ + Perform k-means clustering. + + Args: + X_scaled: Scaled feature matrix + n_clusters: Number of clusters + + Returns: + Fitted KMeans model and labels + """ + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) + labels = kmeans.fit_predict(X_scaled) + return kmeans, labels + + +def perform_dbscan(X_scaled, eps=0.5, min_samples=5): + """ + Perform DBSCAN clustering. + + Args: + X_scaled: Scaled feature matrix + eps: Maximum distance between neighbors + min_samples: Minimum points to form dense region + + Returns: + Fitted DBSCAN model and labels + """ + dbscan = DBSCAN(eps=eps, min_samples=min_samples) + labels = dbscan.fit_predict(X_scaled) + return dbscan, labels + + +def perform_hierarchical(X_scaled, n_clusters=3, linkage='ward'): + """ + Perform hierarchical clustering. + + Args: + X_scaled: Scaled feature matrix + n_clusters: Number of clusters + linkage: Linkage criterion ('ward', 'complete', 'average', 'single') + + Returns: + Fitted AgglomerativeClustering model and labels + """ + hierarchical = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage) + labels = hierarchical.fit_predict(X_scaled) + return hierarchical, labels + + +def visualize_clusters_2d(X_scaled, labels, algorithm_name, method='pca'): + """ + Visualize clusters in 2D using PCA or t-SNE. 
+ + Args: + X_scaled: Scaled feature matrix + labels: Cluster labels + algorithm_name: Name of algorithm for title + method: 'pca' or 'tsne' + """ + # Reduce to 2D + if method == 'pca': + pca = PCA(n_components=2, random_state=42) + X_2d = pca.fit_transform(X_scaled) + variance = pca.explained_variance_ratio_ + xlabel = f'PC1 ({variance[0]:.1%} variance)' + ylabel = f'PC2 ({variance[1]:.1%} variance)' + else: + from sklearn.manifold import TSNE + # Use PCA first to speed up t-SNE + pca = PCA(n_components=min(50, X_scaled.shape[1]), random_state=42) + X_pca = pca.fit_transform(X_scaled) + tsne = TSNE(n_components=2, random_state=42, perplexity=30) + X_2d = tsne.fit_transform(X_pca) + xlabel = 't-SNE 1' + ylabel = 't-SNE 2' + + # Plot + plt.figure(figsize=(10, 8)) + scatter = plt.scatter(X_2d[:, 0], X_2d[:, 1], c=labels, cmap='viridis', alpha=0.6, s=50) + plt.colorbar(scatter, label='Cluster') + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(f'{algorithm_name} Clustering ({method.upper()})') + plt.grid(True, alpha=0.3) + + filename = f'{algorithm_name.lower().replace(" ", "_")}_{method}.png' + plt.savefig(filename, dpi=300, bbox_inches='tight') + print(f"Saved visualization to '{filename}'") + plt.close() + + +def main(): + """ + Example clustering analysis workflow. + """ + # Load your data here + # X = load_data() + + # Example with synthetic data + from sklearn.datasets import make_blobs + X, y_true = make_blobs( + n_samples=500, + n_features=10, + centers=4, + cluster_std=1.0, + random_state=42 + ) + + print(f"Dataset shape: {X.shape}") + + # Scale data (ALWAYS scale for clustering!) + print("\nScaling data...") + X_scaled, scaler = scale_data(X) + + # Find optimal k + print("\nFinding optimal number of clusters...") + results = find_optimal_k(X_scaled) + plot_elbow_silhouette(results) + + # Based on elbow/silhouette, choose optimal k + optimal_k = 4 # Adjust based on plots + + # Perform k-means + print(f"\nPerforming k-means with k={optimal_k}...") + kmeans, kmeans_labels = perform_kmeans(X_scaled, n_clusters=optimal_k) + kmeans_results = evaluate_clustering(X_scaled, kmeans_labels, 'K-Means') + + # Perform DBSCAN + print("\nPerforming DBSCAN...") + dbscan, dbscan_labels = perform_dbscan(X_scaled, eps=0.5, min_samples=5) + dbscan_results = evaluate_clustering(X_scaled, dbscan_labels, 'DBSCAN') + + # Perform hierarchical clustering + print("\nPerforming hierarchical clustering...") + hierarchical, hier_labels = perform_hierarchical(X_scaled, n_clusters=optimal_k) + hier_results = evaluate_clustering(X_scaled, hier_labels, 'Hierarchical') + + # Print results + print("\n" + "="*60) + print("CLUSTERING RESULTS") + print("="*60) + + for results in [kmeans_results, dbscan_results, hier_results]: + print(f"\n{results['algorithm']}:") + print(f" Clusters: {results['n_clusters']}") + if results['n_noise'] > 0: + print(f" Noise points: {results['n_noise']}") + if results['silhouette']: + print(f" Silhouette Score: {results['silhouette']:.3f}") + print(f" Davies-Bouldin Index: {results['davies_bouldin']:.3f} (lower is better)") + print(f" Calinski-Harabasz Index: {results['calinski_harabasz']:.1f} (higher is better)") + + # Visualize clusters + print("\nCreating visualizations...") + visualize_clusters_2d(X_scaled, kmeans_labels, 'K-Means', method='pca') + visualize_clusters_2d(X_scaled, dbscan_labels, 'DBSCAN', method='pca') + visualize_clusters_2d(X_scaled, hier_labels, 'Hierarchical', method='pca') + + print("\nClustering analysis complete!") + + +if __name__ == "__main__": + main() diff 
--git a/scientific-packages/seaborn/SKILL.md b/scientific-packages/seaborn/SKILL.md new file mode 100644 index 0000000..75795a9 --- /dev/null +++ b/scientific-packages/seaborn/SKILL.md @@ -0,0 +1,667 @@ +--- +name: seaborn +description: Comprehensive toolkit for creating statistical data visualizations with seaborn, a Python library built on matplotlib. Use this skill when creating plots for exploratory data analysis, statistical relationships, distributions, categorical comparisons, regression analysis, heatmaps, or multi-panel figures. Applies to tasks involving scatter plots, line plots, histograms, KDE plots, box plots, violin plots, bar plots, pair plots, joint plots, and faceted visualizations. +--- + +# Seaborn Statistical Visualization + +## Overview + +Seaborn is a Python visualization library providing a high-level, dataset-oriented interface for creating publication-quality statistical graphics. Built on matplotlib, seaborn emphasizes declarative syntax that allows focus on data relationships rather than visual implementation details. The library excels at multivariate analysis, automatic statistical estimation, and creating complex multi-panel figures with minimal code. + +## Design Philosophy + +Seaborn follows these core principles: + +1. **Dataset-oriented**: Work directly with DataFrames and named variables rather than abstract coordinates +2. **Semantic mapping**: Automatically translate data values into visual properties (colors, sizes, styles) +3. **Statistical awareness**: Built-in aggregation, error estimation, and confidence intervals +4. **Aesthetic defaults**: Publication-ready themes and color palettes out of the box +5. **Matplotlib integration**: Full compatibility with matplotlib customization when needed + +## Quick Start + +```python +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd + +# Load example dataset +df = sns.load_dataset('tips') + +# Create a simple visualization +sns.scatterplot(data=df, x='total_bill', y='tip', hue='day') +plt.show() +``` + +## Core Plotting Interfaces + +### Function Interface (Traditional) + +The function interface provides specialized plotting functions organized by visualization type. Each category has **axes-level** functions (plot to single axes) and **figure-level** functions (manage entire figure with faceting). + +**When to use:** +- Quick exploratory analysis +- Single-purpose visualizations +- When you need a specific plot type + +### Objects Interface (Modern) + +The `seaborn.objects` interface provides a declarative, composable API similar to ggplot2. Build visualizations by chaining methods to specify data mappings, marks, transformations, and scales. 
+ +**When to use:** +- Complex layered visualizations +- When you need fine-grained control over transformations +- Building custom plot types +- Programmatic plot generation + +```python +from seaborn import objects as so + +# Declarative syntax +( + so.Plot(data=df, x='total_bill', y='tip') + .add(so.Dot(), color='day') + .add(so.Line(), so.PolyFit()) +) +``` + +## Plotting Functions by Category + +### Relational Plots (Relationships Between Variables) + +**Use for:** Exploring how two or more variables relate to each other + +- `scatterplot()` - Display individual observations as points +- `lineplot()` - Show trends and changes (automatically aggregates and computes CI) +- `relplot()` - Figure-level interface with automatic faceting + +**Key parameters:** +- `x`, `y` - Primary variables +- `hue` - Color encoding for additional categorical/continuous variable +- `size` - Point/line size encoding +- `style` - Marker/line style encoding +- `col`, `row` - Facet into multiple subplots (figure-level only) + +```python +# Scatter with multiple semantic mappings +sns.scatterplot(data=df, x='total_bill', y='tip', + hue='time', size='size', style='sex') + +# Line plot with confidence intervals +sns.lineplot(data=timeseries, x='date', y='value', hue='category') + +# Faceted relational plot +sns.relplot(data=df, x='total_bill', y='tip', + col='time', row='sex', hue='smoker', kind='scatter') +``` + +### Distribution Plots (Single and Bivariate Distributions) + +**Use for:** Understanding data spread, shape, and probability density + +- `histplot()` - Bar-based frequency distributions with flexible binning +- `kdeplot()` - Smooth density estimates using Gaussian kernels +- `ecdfplot()` - Empirical cumulative distribution (no parameters to tune) +- `rugplot()` - Individual observation tick marks +- `displot()` - Figure-level interface for univariate and bivariate distributions +- `jointplot()` - Bivariate plot with marginal distributions +- `pairplot()` - Matrix of pairwise relationships across dataset + +**Key parameters:** +- `x`, `y` - Variables (y optional for univariate) +- `hue` - Separate distributions by category +- `stat` - Normalization: "count", "frequency", "probability", "density" +- `bins` / `binwidth` - Histogram binning control +- `bw_adjust` - KDE bandwidth multiplier (higher = smoother) +- `fill` - Fill area under curve +- `multiple` - How to handle hue: "layer", "stack", "dodge", "fill" + +```python +# Histogram with density normalization +sns.histplot(data=df, x='total_bill', hue='time', + stat='density', multiple='stack') + +# Bivariate KDE with contours +sns.kdeplot(data=df, x='total_bill', y='tip', + fill=True, levels=5, thresh=0.1) + +# Joint plot with marginals +sns.jointplot(data=df, x='total_bill', y='tip', + kind='scatter', hue='time') + +# Pairwise relationships +sns.pairplot(data=df, hue='species', corner=True) +``` + +### Categorical Plots (Comparisons Across Categories) + +**Use for:** Comparing distributions or statistics across discrete categories + +**Categorical scatterplots:** +- `stripplot()` - Points with jitter to show all observations +- `swarmplot()` - Non-overlapping points (beeswarm algorithm) + +**Distribution comparisons:** +- `boxplot()` - Quartiles and outliers +- `violinplot()` - KDE + quartile information +- `boxenplot()` - Enhanced boxplot for larger datasets + +**Statistical estimates:** +- `barplot()` - Mean/aggregate with confidence intervals +- `pointplot()` - Point estimates with connecting lines +- `countplot()` - Count of observations per 
category + +**Figure-level:** +- `catplot()` - Faceted categorical plots (set `kind` parameter) + +**Key parameters:** +- `x`, `y` - Variables (one typically categorical) +- `hue` - Additional categorical grouping +- `order`, `hue_order` - Control category ordering +- `dodge` - Separate hue levels side-by-side +- `orient` - "v" (vertical) or "h" (horizontal) +- `kind` - Plot type for catplot: "strip", "swarm", "box", "violin", "bar", "point" + +```python +# Swarm plot showing all points +sns.swarmplot(data=df, x='day', y='total_bill', hue='sex') + +# Violin plot with split for comparison +sns.violinplot(data=df, x='day', y='total_bill', + hue='sex', split=True) + +# Bar plot with error bars +sns.barplot(data=df, x='day', y='total_bill', + hue='sex', estimator='mean', errorbar='ci') + +# Faceted categorical plot +sns.catplot(data=df, x='day', y='total_bill', + col='time', kind='box') +``` + +### Regression Plots (Linear Relationships) + +**Use for:** Visualizing linear regressions and residuals + +- `regplot()` - Axes-level regression plot with scatter + fit line +- `lmplot()` - Figure-level with faceting support +- `residplot()` - Residual plot for assessing model fit + +**Key parameters:** +- `x`, `y` - Variables to regress +- `order` - Polynomial regression order +- `logistic` - Fit logistic regression +- `robust` - Use robust regression (less sensitive to outliers) +- `ci` - Confidence interval width (default 95) +- `scatter_kws`, `line_kws` - Customize scatter and line properties + +```python +# Simple linear regression +sns.regplot(data=df, x='total_bill', y='tip') + +# Polynomial regression with faceting +sns.lmplot(data=df, x='total_bill', y='tip', + col='time', order=2, ci=95) + +# Check residuals +sns.residplot(data=df, x='total_bill', y='tip') +``` + +### Matrix Plots (Rectangular Data) + +**Use for:** Visualizing matrices, correlations, and grid-structured data + +- `heatmap()` - Color-encoded matrix with annotations +- `clustermap()` - Hierarchically-clustered heatmap + +**Key parameters:** +- `data` - 2D rectangular dataset (DataFrame or array) +- `annot` - Display values in cells +- `fmt` - Format string for annotations (e.g., ".2f") +- `cmap` - Colormap name +- `center` - Value at colormap center (for diverging colormaps) +- `vmin`, `vmax` - Color scale limits +- `square` - Force square cells +- `linewidths` - Gap between cells + +```python +# Correlation heatmap +corr = df.corr() +sns.heatmap(corr, annot=True, fmt='.2f', + cmap='coolwarm', center=0, square=True) + +# Clustered heatmap +sns.clustermap(data, cmap='viridis', + standard_scale=1, figsize=(10, 10)) +``` + +## Multi-Plot Grids + +Seaborn provides grid objects for creating complex multi-panel figures: + +### FacetGrid + +Create subplots based on categorical variables. Most useful when called through figure-level functions (`relplot`, `displot`, `catplot`), but can be used directly for custom plots. + +```python +g = sns.FacetGrid(df, col='time', row='sex', hue='smoker') +g.map(sns.scatterplot, 'total_bill', 'tip') +g.add_legend() +``` + +### PairGrid + +Show pairwise relationships between all variables in a dataset. + +```python +g = sns.PairGrid(df, hue='species') +g.map_upper(sns.scatterplot) +g.map_lower(sns.kdeplot) +g.map_diag(sns.histplot) +g.add_legend() +``` + +### JointGrid + +Combine bivariate plot with marginal distributions. 
+ +```python +g = sns.JointGrid(data=df, x='total_bill', y='tip') +g.plot_joint(sns.scatterplot) +g.plot_marginals(sns.histplot) +``` + +## Figure-Level vs Axes-Level Functions + +Understanding this distinction is crucial for effective seaborn usage: + +### Axes-Level Functions +- Plot to a single matplotlib `Axes` object +- Integrate easily into complex matplotlib figures +- Accept `ax=` parameter for precise placement +- Return `Axes` object +- Examples: `scatterplot`, `histplot`, `boxplot`, `regplot`, `heatmap` + +**When to use:** +- Building custom multi-plot layouts +- Combining different plot types +- Need matplotlib-level control +- Integrating with existing matplotlib code + +```python +fig, axes = plt.subplots(2, 2, figsize=(10, 10)) +sns.scatterplot(data=df, x='x', y='y', ax=axes[0, 0]) +sns.histplot(data=df, x='x', ax=axes[0, 1]) +sns.boxplot(data=df, x='cat', y='y', ax=axes[1, 0]) +sns.kdeplot(data=df, x='x', y='y', ax=axes[1, 1]) +``` + +### Figure-Level Functions +- Manage entire figure including all subplots +- Built-in faceting via `col` and `row` parameters +- Return `FacetGrid`, `JointGrid`, or `PairGrid` objects +- Use `height` and `aspect` for sizing (per subplot) +- Cannot be placed in existing figure +- Examples: `relplot`, `displot`, `catplot`, `lmplot`, `jointplot`, `pairplot` + +**When to use:** +- Faceted visualizations (small multiples) +- Quick exploratory analysis +- Consistent multi-panel layouts +- Don't need to combine with other plot types + +```python +# Automatic faceting +sns.relplot(data=df, x='x', y='y', col='category', row='group', + hue='type', height=3, aspect=1.2) +``` + +## Data Structure Requirements + +### Long-Form Data (Preferred) + +Each variable is a column, each observation is a row. This "tidy" format provides maximum flexibility: + +```python +# Long-form structure + subject condition measurement +0 1 control 10.5 +1 1 treatment 12.3 +2 2 control 9.8 +3 2 treatment 13.1 +``` + +**Advantages:** +- Works with all seaborn functions +- Easy to remap variables to visual properties +- Supports arbitrary complexity +- Natural for DataFrame operations + +### Wide-Form Data + +Variables are spread across columns. 
Useful for simple rectangular data: + +```python +# Wide-form structure + control treatment +0 10.5 12.3 +1 9.8 13.1 +``` + +**Use cases:** +- Simple time series +- Correlation matrices +- Heatmaps +- Quick plots of array data + +**Converting wide to long:** +```python +df_long = df.melt(var_name='condition', value_name='measurement') +``` + +## Color Palettes + +Seaborn provides carefully designed color palettes for different data types: + +### Qualitative Palettes (Categorical Data) + +Distinguish categories through hue variation: +- `"deep"` - Default, vivid colors +- `"muted"` - Softer, less saturated +- `"pastel"` - Light, desaturated +- `"bright"` - Highly saturated +- `"dark"` - Dark values +- `"colorblind"` - Safe for color vision deficiency + +```python +sns.set_palette("colorblind") +sns.color_palette("Set2") +``` + +### Sequential Palettes (Ordered Data) + +Show progression from low to high values: +- `"rocket"`, `"mako"` - Wide luminance range (good for heatmaps) +- `"flare"`, `"crest"` - Restricted luminance (good for points/lines) +- `"viridis"`, `"magma"`, `"plasma"` - Matplotlib perceptually uniform + +```python +sns.heatmap(data, cmap='rocket') +sns.kdeplot(data=df, x='x', y='y', cmap='mako', fill=True) +``` + +### Diverging Palettes (Centered Data) + +Emphasize deviations from a midpoint: +- `"vlag"` - Blue to red +- `"icefire"` - Blue to orange +- `"coolwarm"` - Cool to warm +- `"Spectral"` - Rainbow diverging + +```python +sns.heatmap(correlation_matrix, cmap='vlag', center=0) +``` + +### Custom Palettes + +```python +# Create custom palette +custom = sns.color_palette("husl", 8) + +# Light to dark gradient +palette = sns.light_palette("seagreen", as_cmap=True) + +# Diverging palette from hues +palette = sns.diverging_palette(250, 10, as_cmap=True) +``` + +## Theming and Aesthetics + +### Set Theme + +`set_theme()` controls overall appearance: + +```python +# Set complete theme +sns.set_theme(style='whitegrid', palette='pastel', font='sans-serif') + +# Reset to defaults +sns.set_theme() +``` + +### Styles + +Control background and grid appearance: +- `"darkgrid"` - Gray background with white grid (default) +- `"whitegrid"` - White background with gray grid +- `"dark"` - Gray background, no grid +- `"white"` - White background, no grid +- `"ticks"` - White background with axis ticks + +```python +sns.set_style("whitegrid") + +# Remove spines +sns.despine(left=False, bottom=False, offset=10, trim=True) + +# Temporary style +with sns.axes_style("white"): + sns.scatterplot(data=df, x='x', y='y') +``` + +### Contexts + +Scale elements for different use cases: +- `"paper"` - Smallest (default) +- `"notebook"` - Slightly larger +- `"talk"` - Presentation slides +- `"poster"` - Large format + +```python +sns.set_context("talk", font_scale=1.2) + +# Temporary context +with sns.plotting_context("poster"): + sns.barplot(data=df, x='category', y='value') +``` + +## Best Practices + +### 1. Data Preparation + +Always use well-structured DataFrames with meaningful column names: + +```python +# Good: Named columns in DataFrame +df = pd.DataFrame({'bill': bills, 'tip': tips, 'day': days}) +sns.scatterplot(data=df, x='bill', y='tip', hue='day') + +# Avoid: Unnamed arrays +sns.scatterplot(x=x_array, y=y_array) # Loses axis labels +``` + +### 2. 
Choose the Right Plot Type + +**Continuous x, continuous y:** `scatterplot`, `lineplot`, `kdeplot`, `regplot` +**Continuous x, categorical y:** `violinplot`, `boxplot`, `stripplot`, `swarmplot` +**One continuous variable:** `histplot`, `kdeplot`, `ecdfplot` +**Correlations/matrices:** `heatmap`, `clustermap` +**Pairwise relationships:** `pairplot`, `jointplot` + +### 3. Use Figure-Level Functions for Faceting + +```python +# Instead of manual subplot creation +sns.relplot(data=df, x='x', y='y', col='category', col_wrap=3) + +# Not: Creating subplots manually for simple faceting +``` + +### 4. Leverage Semantic Mappings + +Use `hue`, `size`, and `style` to encode additional dimensions: + +```python +sns.scatterplot(data=df, x='x', y='y', + hue='category', # Color by category + size='importance', # Size by continuous variable + style='type') # Marker style by type +``` + +### 5. Control Statistical Estimation + +Many functions compute statistics automatically. Understand and customize: + +```python +# Lineplot computes mean and 95% CI by default +sns.lineplot(data=df, x='time', y='value', + errorbar='sd') # Use standard deviation instead + +# Barplot computes mean by default +sns.barplot(data=df, x='category', y='value', + estimator='median', # Use median instead + errorbar=('ci', 95)) # Bootstrapped CI +``` + +### 6. Combine with Matplotlib + +Seaborn integrates seamlessly with matplotlib for fine-tuning: + +```python +ax = sns.scatterplot(data=df, x='x', y='y') +ax.set(xlabel='Custom X Label', ylabel='Custom Y Label', + title='Custom Title') +ax.axhline(y=0, color='r', linestyle='--') +plt.tight_layout() +``` + +### 7. Save High-Quality Figures + +```python +fig = sns.relplot(data=df, x='x', y='y', col='group') +fig.savefig('figure.png', dpi=300, bbox_inches='tight') +fig.savefig('figure.pdf') # Vector format for publications +``` + +## Common Patterns + +### Exploratory Data Analysis + +```python +# Quick overview of all relationships +sns.pairplot(data=df, hue='target', corner=True) + +# Distribution exploration +sns.displot(data=df, x='variable', hue='group', + kind='kde', fill=True, col='category') + +# Correlation analysis +corr = df.corr() +sns.heatmap(corr, annot=True, cmap='coolwarm', center=0) +``` + +### Publication-Quality Figures + +```python +sns.set_theme(style='ticks', context='paper', font_scale=1.1) + +g = sns.catplot(data=df, x='treatment', y='response', + col='cell_line', kind='box', height=3, aspect=1.2) +g.set_axis_labels('Treatment Condition', 'Response (μM)') +g.set_titles('{col_name}') +sns.despine(trim=True) + +g.savefig('figure.pdf', dpi=300, bbox_inches='tight') +``` + +### Complex Multi-Panel Figures + +```python +# Using matplotlib subplots with seaborn +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + +sns.scatterplot(data=df, x='x1', y='y', hue='group', ax=axes[0, 0]) +sns.histplot(data=df, x='x1', hue='group', ax=axes[0, 1]) +sns.violinplot(data=df, x='group', y='y', ax=axes[1, 0]) +sns.heatmap(df.pivot_table(values='y', index='x1', columns='x2'), + ax=axes[1, 1], cmap='viridis') + +plt.tight_layout() +``` + +### Time Series with Confidence Bands + +```python +# Lineplot automatically aggregates and shows CI +sns.lineplot(data=timeseries, x='date', y='measurement', + hue='sensor', style='location', errorbar='sd') + +# For more control +g = sns.relplot(data=timeseries, x='date', y='measurement', + col='location', hue='sensor', kind='line', + height=4, aspect=1.5, errorbar=('ci', 95)) +g.set_axis_labels('Date', 'Measurement (units)') +``` + +## 
Troubleshooting + +### Issue: Legend Outside Plot Area + +Figure-level functions place legends outside by default. To move inside: + +```python +g = sns.relplot(data=df, x='x', y='y', hue='category') +g._legend.set_bbox_to_anchor((0.9, 0.5)) # Adjust position +``` + +### Issue: Overlapping Labels + +```python +plt.xticks(rotation=45, ha='right') +plt.tight_layout() +``` + +### Issue: Figure Too Small + +For figure-level functions: +```python +sns.relplot(data=df, x='x', y='y', height=6, aspect=1.5) +``` + +For axes-level functions: +```python +fig, ax = plt.subplots(figsize=(10, 6)) +sns.scatterplot(data=df, x='x', y='y', ax=ax) +``` + +### Issue: Colors Not Distinct Enough + +```python +# Use a different palette +sns.set_palette("bright") + +# Or specify number of colors +palette = sns.color_palette("husl", n_colors=len(df['category'].unique())) +sns.scatterplot(data=df, x='x', y='y', hue='category', palette=palette) +``` + +### Issue: KDE Too Smooth or Jagged + +```python +# Adjust bandwidth +sns.kdeplot(data=df, x='x', bw_adjust=0.5) # Less smooth +sns.kdeplot(data=df, x='x', bw_adjust=2) # More smooth +``` + +## Resources + +This skill includes reference materials for deeper exploration: + +### references/ + +- `function_reference.md` - Comprehensive listing of all seaborn functions with parameters and examples +- `objects_interface.md` - Detailed guide to the modern seaborn.objects API +- `examples.md` - Common use cases and code patterns for different analysis scenarios + +Load reference files as needed for detailed function signatures, advanced parameters, or specific examples. diff --git a/scientific-packages/seaborn/references/examples.md b/scientific-packages/seaborn/references/examples.md new file mode 100644 index 0000000..cd7a0d4 --- /dev/null +++ b/scientific-packages/seaborn/references/examples.md @@ -0,0 +1,822 @@ +# Seaborn Common Use Cases and Examples + +This document provides practical examples for common data visualization scenarios using seaborn. 
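Unless a snippet shows its own imports, the examples below assume a shared setup along these lines (a minimal sketch; several later snippets reference `np` and rely on the NumPy import here):

```python
# Shared setup assumed by the examples in this document
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()  # seaborn's default theme; individual examples may override it
```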
+ +## Exploratory Data Analysis + +### Quick Dataset Overview + +```python +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd + +# Load data +df = pd.read_csv('data.csv') + +# Pairwise relationships for all numeric variables +sns.pairplot(df, hue='target_variable', corner=True, diag_kind='kde') +plt.suptitle('Dataset Overview', y=1.01) +plt.savefig('overview.png', dpi=300, bbox_inches='tight') +``` + +### Distribution Exploration + +```python +# Multiple distributions across categories +g = sns.displot( + data=df, + x='measurement', + hue='condition', + col='timepoint', + kind='kde', + fill=True, + height=3, + aspect=1.5, + col_wrap=3, + common_norm=False +) +g.set_axis_labels('Measurement Value', 'Density') +g.set_titles('{col_name}') +``` + +### Correlation Analysis + +```python +# Compute correlation matrix +corr = df.select_dtypes(include='number').corr() + +# Create mask for upper triangle +mask = np.triu(np.ones_like(corr, dtype=bool)) + +# Plot heatmap +fig, ax = plt.subplots(figsize=(10, 8)) +sns.heatmap( + corr, + mask=mask, + annot=True, + fmt='.2f', + cmap='coolwarm', + center=0, + square=True, + linewidths=1, + cbar_kws={'shrink': 0.8} +) +plt.title('Correlation Matrix') +plt.tight_layout() +``` + +## Scientific Publications + +### Multi-Panel Figure with Different Plot Types + +```python +# Set publication style +sns.set_theme(style='ticks', context='paper', font_scale=1.1) +sns.set_palette('colorblind') + +# Create figure with custom layout +fig = plt.figure(figsize=(12, 8)) +gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3) + +# Panel A: Time series +ax1 = fig.add_subplot(gs[0, :2]) +sns.lineplot( + data=timeseries_df, + x='time', + y='expression', + hue='gene', + style='treatment', + markers=True, + dashes=False, + ax=ax1 +) +ax1.set_title('A. Gene Expression Over Time', loc='left', fontweight='bold') +ax1.set_xlabel('Time (hours)') +ax1.set_ylabel('Expression Level (AU)') + +# Panel B: Distribution comparison +ax2 = fig.add_subplot(gs[0, 2]) +sns.violinplot( + data=expression_df, + x='treatment', + y='expression', + inner='box', + ax=ax2 +) +ax2.set_title('B. Expression Distribution', loc='left', fontweight='bold') +ax2.set_xlabel('Treatment') +ax2.set_ylabel('') + +# Panel C: Correlation +ax3 = fig.add_subplot(gs[1, 0]) +sns.scatterplot( + data=correlation_df, + x='gene1', + y='gene2', + hue='cell_type', + alpha=0.6, + ax=ax3 +) +sns.regplot( + data=correlation_df, + x='gene1', + y='gene2', + scatter=False, + color='black', + ax=ax3 +) +ax3.set_title('C. Gene Correlation', loc='left', fontweight='bold') +ax3.set_xlabel('Gene 1 Expression') +ax3.set_ylabel('Gene 2 Expression') + +# Panel D: Heatmap +ax4 = fig.add_subplot(gs[1, 1:]) +sns.heatmap( + sample_matrix, + cmap='RdBu_r', + center=0, + annot=True, + fmt='.1f', + cbar_kws={'label': 'Log2 Fold Change'}, + ax=ax4 +) +ax4.set_title('D. 
Treatment Effects', loc='left', fontweight='bold') +ax4.set_xlabel('Sample') +ax4.set_ylabel('Gene') + +# Clean up +sns.despine() +plt.savefig('figure.pdf', dpi=300, bbox_inches='tight') +plt.savefig('figure.png', dpi=300, bbox_inches='tight') +``` + +### Box Plot with Significance Annotations + +```python +import numpy as np +from scipy import stats + +# Create plot +fig, ax = plt.subplots(figsize=(8, 6)) +sns.boxplot( + data=df, + x='treatment', + y='response', + order=['Control', 'Low', 'Medium', 'High'], + palette='Set2', + ax=ax +) + +# Add individual points +sns.stripplot( + data=df, + x='treatment', + y='response', + order=['Control', 'Low', 'Medium', 'High'], + color='black', + alpha=0.3, + size=3, + ax=ax +) + +# Add significance bars +def add_significance_bar(ax, x1, x2, y, h, text): + ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], 'k-', lw=1.5) + ax.text((x1+x2)/2, y+h, text, ha='center', va='bottom') + +y_max = df['response'].max() +add_significance_bar(ax, 0, 3, y_max + 1, 0.5, '***') +add_significance_bar(ax, 0, 1, y_max + 3, 0.5, 'ns') + +ax.set_ylabel('Response (μM)') +ax.set_xlabel('Treatment Condition') +ax.set_title('Treatment Response Analysis') +sns.despine() +``` + +## Time Series Analysis + +### Multiple Time Series with Confidence Bands + +```python +# Plot with automatic aggregation +fig, ax = plt.subplots(figsize=(10, 6)) +sns.lineplot( + data=timeseries_df, + x='timestamp', + y='value', + hue='sensor', + style='location', + markers=True, + dashes=False, + errorbar=('ci', 95), + ax=ax +) + +# Customize +ax.set_xlabel('Date') +ax.set_ylabel('Measurement (units)') +ax.set_title('Sensor Measurements Over Time') +ax.legend(title='Sensor & Location', bbox_to_anchor=(1.05, 1), loc='upper left') + +# Format x-axis for dates +import matplotlib.dates as mdates +ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) +ax.xaxis.set_major_locator(mdates.DayLocator(interval=7)) +plt.xticks(rotation=45, ha='right') + +plt.tight_layout() +``` + +### Faceted Time Series + +```python +# Create faceted time series +g = sns.relplot( + data=long_timeseries, + x='date', + y='measurement', + hue='device', + col='location', + row='metric', + kind='line', + height=3, + aspect=2, + errorbar='sd', + facet_kws={'sharex': True, 'sharey': False} +) + +# Customize facet titles +g.set_titles('{row_name} - {col_name}') +g.set_axis_labels('Date', 'Value') + +# Rotate x-axis labels +for ax in g.axes.flat: + ax.tick_params(axis='x', rotation=45) + +g.tight_layout() +``` + +## Categorical Comparisons + +### Nested Categorical Variables + +```python +# Create figure +fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + +# Left panel: Grouped bar plot +sns.barplot( + data=df, + x='category', + y='value', + hue='subcategory', + errorbar=('ci', 95), + capsize=0.1, + ax=axes[0] +) +axes[0].set_title('Mean Values with 95% CI') +axes[0].set_ylabel('Value (units)') +axes[0].legend(title='Subcategory') + +# Right panel: Strip + violin plot +sns.violinplot( + data=df, + x='category', + y='value', + hue='subcategory', + inner=None, + alpha=0.3, + ax=axes[1] +) +sns.stripplot( + data=df, + x='category', + y='value', + hue='subcategory', + dodge=True, + size=3, + alpha=0.6, + ax=axes[1] +) +axes[1].set_title('Distribution of Individual Values') +axes[1].set_ylabel('') +axes[1].get_legend().remove() + +plt.tight_layout() +``` + +### Point Plot for Trends + +```python +# Show how values change across categories +sns.pointplot( + data=df, + x='timepoint', + y='score', + hue='treatment', + markers=['o', 's', '^'], 
+ linestyles=['-', '--', '-.'], + dodge=0.3, + capsize=0.1, + errorbar=('ci', 95) +) + +plt.xlabel('Timepoint') +plt.ylabel('Performance Score') +plt.title('Treatment Effects Over Time') +plt.legend(title='Treatment', bbox_to_anchor=(1.05, 1), loc='upper left') +sns.despine() +plt.tight_layout() +``` + +## Regression and Relationships + +### Linear Regression with Facets + +```python +# Fit separate regressions for each category +g = sns.lmplot( + data=df, + x='predictor', + y='response', + hue='treatment', + col='cell_line', + height=4, + aspect=1.2, + scatter_kws={'alpha': 0.5, 's': 50}, + ci=95, + palette='Set2' +) + +g.set_axis_labels('Predictor Variable', 'Response Variable') +g.set_titles('{col_name}') +g.tight_layout() +``` + +### Polynomial Regression + +```python +fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + +for idx, order in enumerate([1, 2, 3]): + sns.regplot( + data=df, + x='x', + y='y', + order=order, + scatter_kws={'alpha': 0.5}, + line_kws={'color': 'red'}, + ci=95, + ax=axes[idx] + ) + axes[idx].set_title(f'Order {order} Polynomial Fit') + axes[idx].set_xlabel('X Variable') + axes[idx].set_ylabel('Y Variable') + +plt.tight_layout() +``` + +### Residual Analysis + +```python +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + +# Main regression +sns.regplot(data=df, x='x', y='y', ax=axes[0, 0]) +axes[0, 0].set_title('Regression Fit') + +# Residuals vs fitted +sns.residplot(data=df, x='x', y='y', lowess=True, + scatter_kws={'alpha': 0.5}, + line_kws={'color': 'red', 'lw': 2}, + ax=axes[0, 1]) +axes[0, 1].set_title('Residuals vs Fitted') +axes[0, 1].axhline(0, ls='--', color='gray') + +# Q-Q plot (using scipy) +from scipy import stats as sp_stats +residuals = df['y'] - np.poly1d(np.polyfit(df['x'], df['y'], 1))(df['x']) +sp_stats.probplot(residuals, dist="norm", plot=axes[1, 0]) +axes[1, 0].set_title('Q-Q Plot') + +# Histogram of residuals +sns.histplot(residuals, kde=True, ax=axes[1, 1]) +axes[1, 1].set_title('Residual Distribution') +axes[1, 1].set_xlabel('Residuals') + +plt.tight_layout() +``` + +## Bivariate and Joint Distributions + +### Joint Plot with Multiple Representations + +```python +# Scatter with marginals +g = sns.jointplot( + data=df, + x='var1', + y='var2', + hue='category', + kind='scatter', + height=8, + ratio=4, + space=0.1, + joint_kws={'alpha': 0.5, 's': 50}, + marginal_kws={'kde': True, 'bins': 30} +) + +# Add reference lines +g.ax_joint.axline((0, 0), slope=1, color='r', ls='--', alpha=0.5, label='y=x') +g.ax_joint.legend() + +g.set_axis_labels('Variable 1', 'Variable 2', fontsize=12) +``` + +### KDE Contour Plot + +```python +fig, ax = plt.subplots(figsize=(8, 8)) + +# Bivariate KDE with filled contours +sns.kdeplot( + data=df, + x='x', + y='y', + fill=True, + levels=10, + cmap='viridis', + thresh=0.05, + ax=ax +) + +# Overlay scatter +sns.scatterplot( + data=df, + x='x', + y='y', + color='white', + edgecolor='black', + s=50, + alpha=0.6, + ax=ax +) + +ax.set_xlabel('X Variable') +ax.set_ylabel('Y Variable') +ax.set_title('Bivariate Distribution') +``` + +### Hexbin with Marginals + +```python +# For large datasets +g = sns.jointplot( + data=large_df, + x='x', + y='y', + kind='hex', + height=8, + ratio=5, + space=0.1, + joint_kws={'gridsize': 30, 'cmap': 'viridis'}, + marginal_kws={'bins': 50, 'color': 'skyblue'} +) + +g.set_axis_labels('X Variable', 'Y Variable') +``` + +## Matrix and Heatmap Visualizations + +### Hierarchical Clustering Heatmap + +```python +# Prepare data (samples x features) +data_matrix = 
df.set_index('sample_id')[feature_columns] + +# Create color annotations +row_colors = df.set_index('sample_id')['condition'].map({ + 'control': '#1f77b4', + 'treatment': '#ff7f0e' +}) + +col_colors = pd.Series(['#2ca02c' if 'gene' in col else '#d62728' + for col in data_matrix.columns]) + +# Plot +g = sns.clustermap( + data_matrix, + method='ward', + metric='euclidean', + z_score=0, # Normalize rows + cmap='RdBu_r', + center=0, + row_colors=row_colors, + col_colors=col_colors, + figsize=(12, 10), + dendrogram_ratio=(0.1, 0.1), + cbar_pos=(0.02, 0.8, 0.03, 0.15), + linewidths=0.5 +) + +g.ax_heatmap.set_xlabel('Features') +g.ax_heatmap.set_ylabel('Samples') +plt.savefig('clustermap.png', dpi=300, bbox_inches='tight') +``` + +### Annotated Heatmap with Custom Colorbar + +```python +# Pivot data for heatmap +pivot_data = df.pivot(index='row_var', columns='col_var', values='value') + +# Create heatmap +fig, ax = plt.subplots(figsize=(10, 8)) +sns.heatmap( + pivot_data, + annot=True, + fmt='.1f', + cmap='RdYlGn', + center=pivot_data.mean().mean(), + vmin=pivot_data.min().min(), + vmax=pivot_data.max().max(), + linewidths=0.5, + linecolor='gray', + cbar_kws={ + 'label': 'Value (units)', + 'orientation': 'vertical', + 'shrink': 0.8, + 'aspect': 20 + }, + ax=ax +) + +ax.set_title('Variable Relationships', fontsize=14, pad=20) +ax.set_xlabel('Column Variable', fontsize=12) +ax.set_ylabel('Row Variable', fontsize=12) + +plt.xticks(rotation=45, ha='right') +plt.yticks(rotation=0) +plt.tight_layout() +``` + +## Statistical Comparisons + +### Before/After Comparison + +```python +# Reshape data for paired comparison +df_paired = df.melt( + id_vars='subject', + value_vars=['before', 'after'], + var_name='timepoint', + value_name='measurement' +) + +fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + +# Left: Individual trajectories +for subject in df_paired['subject'].unique(): + subject_data = df_paired[df_paired['subject'] == subject] + axes[0].plot(subject_data['timepoint'], subject_data['measurement'], + 'o-', alpha=0.3, color='gray') + +sns.pointplot( + data=df_paired, + x='timepoint', + y='measurement', + color='red', + markers='D', + scale=1.5, + errorbar=('ci', 95), + capsize=0.2, + ax=axes[0] +) +axes[0].set_title('Individual Changes') +axes[0].set_ylabel('Measurement') + +# Right: Distribution comparison +sns.violinplot( + data=df_paired, + x='timepoint', + y='measurement', + inner='box', + ax=axes[1] +) +sns.swarmplot( + data=df_paired, + x='timepoint', + y='measurement', + color='black', + alpha=0.5, + size=3, + ax=axes[1] +) +axes[1].set_title('Distribution Comparison') +axes[1].set_ylabel('') + +plt.tight_layout() +``` + +### Dose-Response Curve + +```python +# Create dose-response plot +fig, ax = plt.subplots(figsize=(8, 6)) + +# Plot individual points +sns.stripplot( + data=dose_df, + x='dose', + y='response', + order=sorted(dose_df['dose'].unique()), + color='gray', + alpha=0.3, + jitter=0.2, + ax=ax +) + +# Overlay mean with CI +sns.pointplot( + data=dose_df, + x='dose', + y='response', + order=sorted(dose_df['dose'].unique()), + color='blue', + markers='o', + scale=1.2, + errorbar=('ci', 95), + capsize=0.1, + ax=ax +) + +# Fit sigmoid curve +from scipy.optimize import curve_fit + +def sigmoid(x, bottom, top, ec50, hill): + return bottom + (top - bottom) / (1 + (ec50 / x) ** hill) + +doses_numeric = dose_df['dose'].astype(float) +params, _ = curve_fit(sigmoid, doses_numeric, dose_df['response']) + +x_smooth = np.logspace(np.log10(doses_numeric.min()), + np.log10(doses_numeric.max()), 
100) +y_smooth = sigmoid(x_smooth, *params) + +ax.plot(range(len(sorted(dose_df['dose'].unique()))), + sigmoid(sorted(doses_numeric.unique()), *params), + 'r-', linewidth=2, label='Sigmoid Fit') + +ax.set_xlabel('Dose') +ax.set_ylabel('Response') +ax.set_title('Dose-Response Analysis') +ax.legend() +sns.despine() +``` + +## Custom Styling + +### Custom Color Palette from Hex Codes + +```python +# Define custom palette +custom_palette = ['#E64B35', '#4DBBD5', '#00A087', '#3C5488', '#F39B7F'] +sns.set_palette(custom_palette) + +# Or use for specific plot +sns.scatterplot( + data=df, + x='x', + y='y', + hue='category', + palette=custom_palette +) +``` + +### Publication-Ready Theme + +```python +# Set comprehensive theme +sns.set_theme( + context='paper', + style='ticks', + palette='colorblind', + font='Arial', + font_scale=1.1, + rc={ + 'figure.dpi': 300, + 'savefig.dpi': 300, + 'savefig.format': 'pdf', + 'axes.linewidth': 1.0, + 'axes.labelweight': 'bold', + 'xtick.major.width': 1.0, + 'ytick.major.width': 1.0, + 'xtick.direction': 'out', + 'ytick.direction': 'out', + 'legend.frameon': False, + 'pdf.fonttype': 42, # True Type fonts for PDFs + } +) +``` + +### Diverging Colormap Centered on Zero + +```python +# For data with meaningful zero point (e.g., log fold change) +from matplotlib.colors import TwoSlopeNorm + +# Find data range +vmin, vmax = df['value'].min(), df['value'].max() +vcenter = 0 + +# Create norm +norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax) + +# Plot +sns.heatmap( + pivot_data, + cmap='RdBu_r', + norm=norm, + center=0, + annot=True, + fmt='.2f' +) +``` + +## Large Datasets + +### Downsampling Strategy + +```python +# For very large datasets, sample intelligently +def smart_sample(df, target_size=10000, category_col=None): + if len(df) <= target_size: + return df + + if category_col: + # Stratified sampling + return df.groupby(category_col, group_keys=False).apply( + lambda x: x.sample(min(len(x), target_size // df[category_col].nunique())) + ) + else: + # Simple random sampling + return df.sample(target_size) + +# Use sampled data for visualization +df_sampled = smart_sample(large_df, target_size=5000, category_col='category') + +sns.scatterplot(data=df_sampled, x='x', y='y', hue='category', alpha=0.5) +``` + +### Hexbin for Dense Scatter Plots + +```python +# For millions of points +fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + +# Regular scatter (slow) +axes[0].scatter(df['x'], df['y'], alpha=0.1, s=1) +axes[0].set_title('Scatter (all points)') + +# Hexbin (fast) +hb = axes[1].hexbin(df['x'], df['y'], gridsize=50, cmap='viridis', mincnt=1) +axes[1].set_title('Hexbin Aggregation') +plt.colorbar(hb, ax=axes[1], label='Count') + +plt.tight_layout() +``` + +## Interactive Elements for Notebooks + +### Adjustable Parameters + +```python +from ipywidgets import interact, FloatSlider + +@interact(bandwidth=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0)) +def plot_kde(bandwidth): + plt.figure(figsize=(10, 6)) + sns.kdeplot(data=df, x='value', hue='category', + bw_adjust=bandwidth, fill=True) + plt.title(f'KDE with bandwidth adjustment = {bandwidth}') + plt.show() +``` + +### Dynamic Filtering + +```python +from ipywidgets import interact, SelectMultiple + +categories = df['category'].unique().tolist() + +@interact(selected=SelectMultiple(options=categories, value=[categories[0]])) +def filtered_plot(selected): + filtered_df = df[df['category'].isin(selected)] + + fig, ax = plt.subplots(figsize=(10, 6)) + sns.violinplot(data=filtered_df, x='category', 
y='value', ax=ax) + ax.set_title(f'Showing {len(selected)} categories') + plt.show() +``` diff --git a/scientific-packages/seaborn/references/function_reference.md b/scientific-packages/seaborn/references/function_reference.md new file mode 100644 index 0000000..1393918 --- /dev/null +++ b/scientific-packages/seaborn/references/function_reference.md @@ -0,0 +1,770 @@ +# Seaborn Function Reference + +This document provides a comprehensive reference for all major seaborn functions, organized by category. + +## Relational Plots + +### scatterplot() + +**Purpose:** Create a scatter plot with points representing individual observations. + +**Key Parameters:** +- `data` - DataFrame, array, or dict of arrays +- `x, y` - Variables for x and y axes +- `hue` - Grouping variable for color encoding +- `size` - Grouping variable for size encoding +- `style` - Grouping variable for marker style +- `palette` - Color palette name or list +- `hue_order` - Order for categorical hue levels +- `hue_norm` - Normalization for numeric hue (tuple or Normalize object) +- `sizes` - Size range for size encoding (tuple or dict) +- `size_order` - Order for categorical size levels +- `size_norm` - Normalization for numeric size +- `markers` - Marker style(s) (string, list, or dict) +- `style_order` - Order for categorical style levels +- `legend` - How to draw legend: "auto", "brief", "full", or False +- `ax` - Matplotlib axes to plot on + +**Example:** +```python +sns.scatterplot(data=df, x='height', y='weight', + hue='gender', size='age', style='smoker', + palette='Set2', sizes=(20, 200)) +``` + +### lineplot() + +**Purpose:** Draw a line plot with automatic aggregation and confidence intervals for repeated measures. + +**Key Parameters:** +- `data` - DataFrame, array, or dict of arrays +- `x, y` - Variables for x and y axes +- `hue` - Grouping variable for color encoding +- `size` - Grouping variable for line width +- `style` - Grouping variable for line style (dashes) +- `units` - Grouping variable for sampling units (no aggregation within units) +- `estimator` - Function for aggregating across observations (default: mean) +- `errorbar` - Method for error bars: "sd", "se", "pi", ("ci", level), ("pi", level), or None +- `n_boot` - Number of bootstrap iterations for CI computation +- `seed` - Random seed for reproducible bootstrapping +- `sort` - Sort data before plotting +- `err_style` - "band" or "bars" for error representation +- `err_kws` - Additional parameters for error representation +- `markers` - Marker style(s) for emphasizing data points +- `dashes` - Dash style(s) for lines +- `legend` - How to draw legend +- `ax` - Matplotlib axes to plot on + +**Example:** +```python +sns.lineplot(data=timeseries, x='time', y='signal', + hue='condition', style='subject', + errorbar=('ci', 95), markers=True) +``` + +### relplot() + +**Purpose:** Figure-level interface for drawing relational plots (scatter or line) onto a FacetGrid. 
+ +**Key Parameters:** +All parameters from `scatterplot()` and `lineplot()`, plus: +- `kind` - "scatter" or "line" +- `col` - Categorical variable for column facets +- `row` - Categorical variable for row facets +- `col_wrap` - Wrap columns after this many columns +- `col_order` - Order for column facet levels +- `row_order` - Order for row facet levels +- `height` - Height of each facet in inches +- `aspect` - Aspect ratio (width = height * aspect) +- `facet_kws` - Additional parameters for FacetGrid + +**Example:** +```python +sns.relplot(data=df, x='time', y='measurement', + hue='treatment', style='batch', + col='cell_line', row='timepoint', + kind='line', height=3, aspect=1.5) +``` + +## Distribution Plots + +### histplot() + +**Purpose:** Plot univariate or bivariate histograms with flexible binning. + +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (y optional for bivariate) +- `hue` - Grouping variable +- `weights` - Variable for weighting observations +- `stat` - Aggregate statistic: "count", "frequency", "probability", "percent", "density" +- `bins` - Number of bins, bin edges, or method ("auto", "fd", "doane", "scott", "stone", "rice", "sturges", "sqrt") +- `binwidth` - Width of bins (overrides bins) +- `binrange` - Range for binning (tuple) +- `discrete` - Treat x as discrete (centers bars on values) +- `cumulative` - Compute cumulative distribution +- `common_bins` - Use same bins for all hue levels +- `common_norm` - Normalize across hue levels +- `multiple` - How to handle hue: "layer", "dodge", "stack", "fill" +- `element` - Visual element: "bars", "step", "poly" +- `fill` - Fill bars/elements +- `shrink` - Scale bar width (for multiple="dodge") +- `kde` - Overlay KDE estimate +- `kde_kws` - Parameters for KDE +- `line_kws` - Parameters for step/poly elements +- `thresh` - Minimum count threshold for bins +- `pthresh` - Minimum probability threshold +- `pmax` - Maximum probability for color scaling +- `log_scale` - Log scale for axis (bool or base) +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.histplot(data=df, x='measurement', hue='condition', + stat='density', bins=30, kde=True, + multiple='layer', alpha=0.5) +``` + +### kdeplot() + +**Purpose:** Plot univariate or bivariate kernel density estimates. 
+ +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (y optional for bivariate) +- `hue` - Grouping variable +- `weights` - Variable for weighting observations +- `palette` - Color palette +- `hue_order` - Order for hue levels +- `hue_norm` - Normalization for numeric hue +- `multiple` - How to handle hue: "layer", "stack", "fill" +- `common_norm` - Normalize across hue levels +- `common_grid` - Use same grid for all hue levels +- `cumulative` - Compute cumulative distribution +- `bw_method` - Method for bandwidth: "scott", "silverman", or scalar +- `bw_adjust` - Bandwidth multiplier (higher = smoother) +- `log_scale` - Log scale for axis +- `levels` - Number or values for contour levels (bivariate) +- `thresh` - Minimum density threshold for contours +- `gridsize` - Grid resolution +- `cut` - Extension beyond data extremes (in bandwidth units) +- `clip` - Data range for curve (tuple) +- `fill` - Fill area under curve/contours +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +# Univariate +sns.kdeplot(data=df, x='measurement', hue='condition', + fill=True, common_norm=False, bw_adjust=1.5) + +# Bivariate +sns.kdeplot(data=df, x='var1', y='var2', + fill=True, levels=10, thresh=0.05) +``` + +### ecdfplot() + +**Purpose:** Plot empirical cumulative distribution functions. + +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (specify one) +- `hue` - Grouping variable +- `weights` - Variable for weighting observations +- `stat` - "proportion" or "count" +- `complementary` - Plot complementary CDF (1 - ECDF) +- `palette` - Color palette +- `hue_order` - Order for hue levels +- `hue_norm` - Normalization for numeric hue +- `log_scale` - Log scale for axis +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.ecdfplot(data=df, x='response_time', hue='treatment', + stat='proportion', complementary=False) +``` + +### rugplot() + +**Purpose:** Plot tick marks showing individual observations along an axis. + +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variable (specify one) +- `hue` - Grouping variable +- `height` - Height of ticks (proportion of axis) +- `expand_margins` - Add margin space for rug +- `palette` - Color palette +- `hue_order` - Order for hue levels +- `hue_norm` - Normalization for numeric hue +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.rugplot(data=df, x='value', hue='category', height=0.05) +``` + +### displot() + +**Purpose:** Figure-level interface for distribution plots onto a FacetGrid. + +**Key Parameters:** +All parameters from `histplot()`, `kdeplot()`, and `ecdfplot()`, plus: +- `kind` - "hist", "kde", "ecdf" +- `rug` - Add rug plot on marginal axes +- `rug_kws` - Parameters for rug plot +- `col` - Categorical variable for column facets +- `row` - Categorical variable for row facets +- `col_wrap` - Wrap columns +- `col_order` - Order for column facets +- `row_order` - Order for row facets +- `height` - Height of each facet +- `aspect` - Aspect ratio +- `facet_kws` - Additional parameters for FacetGrid + +**Example:** +```python +sns.displot(data=df, x='measurement', hue='treatment', + col='timepoint', kind='kde', fill=True, + height=3, aspect=1.5, rug=True) +``` + +### jointplot() + +**Purpose:** Draw a bivariate plot with marginal univariate plots. 
+ +**Key Parameters:** +- `data` - DataFrame +- `x, y` - Variables for x and y axes +- `hue` - Grouping variable +- `kind` - "scatter", "kde", "hist", "hex", "reg", "resid" +- `height` - Size of the figure (square) +- `ratio` - Ratio of joint to marginal axes +- `space` - Space between joint and marginal axes +- `dropna` - Drop missing values +- `xlim, ylim` - Axis limits (tuples) +- `marginal_ticks` - Show ticks on marginal axes +- `joint_kws` - Parameters for joint plot +- `marginal_kws` - Parameters for marginal plots +- `hue_order` - Order for hue levels +- `palette` - Color palette + +**Example:** +```python +sns.jointplot(data=df, x='var1', y='var2', hue='group', + kind='scatter', height=6, ratio=4, + joint_kws={'alpha': 0.5}) +``` + +### pairplot() + +**Purpose:** Plot pairwise relationships in a dataset. + +**Key Parameters:** +- `data` - DataFrame +- `hue` - Grouping variable for color encoding +- `hue_order` - Order for hue levels +- `palette` - Color palette +- `vars` - Variables to plot (default: all numeric) +- `x_vars, y_vars` - Variables for x and y axes (non-square grid) +- `kind` - "scatter", "kde", "hist", "reg" +- `diag_kind` - "auto", "hist", "kde", None +- `markers` - Marker style(s) +- `height` - Height of each facet +- `aspect` - Aspect ratio +- `corner` - Plot only lower triangle +- `dropna` - Drop missing values +- `plot_kws` - Parameters for non-diagonal plots +- `diag_kws` - Parameters for diagonal plots +- `grid_kws` - Parameters for PairGrid + +**Example:** +```python +sns.pairplot(data=df, hue='species', palette='Set2', + vars=['sepal_length', 'sepal_width', 'petal_length'], + corner=True, height=2.5) +``` + +## Categorical Plots + +### stripplot() + +**Purpose:** Draw a categorical scatterplot with jittered points. + +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (one categorical, one continuous) +- `hue` - Grouping variable +- `order` - Order for categorical levels +- `hue_order` - Order for hue levels +- `jitter` - Amount of jitter: True, float, or False +- `dodge` - Separate hue levels side-by-side +- `orient` - "v" or "h" (usually inferred) +- `color` - Single color for all elements +- `palette` - Color palette +- `size` - Marker size +- `edgecolor` - Marker edge color +- `linewidth` - Marker edge width +- `native_scale` - Use numeric scale for categorical axis +- `formatter` - Formatter for categorical axis +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.stripplot(data=df, x='day', y='total_bill', + hue='sex', dodge=True, jitter=0.2) +``` + +### swarmplot() + +**Purpose:** Draw a categorical scatterplot with non-overlapping points. + +**Key Parameters:** +Same as `stripplot()`, except: +- No `jitter` parameter +- `size` - Marker size (important for avoiding overlap) +- `warn_thresh` - Threshold for warning about too many points (default: 0.05) + +**Note:** Computationally intensive for large datasets. Use stripplot for >1000 points. + +**Example:** +```python +sns.swarmplot(data=df, x='day', y='total_bill', + hue='time', dodge=True, size=5) +``` + +### boxplot() + +**Purpose:** Draw a box plot showing quartiles and outliers. 
+ +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (one categorical, one continuous) +- `hue` - Grouping variable +- `order` - Order for categorical levels +- `hue_order` - Order for hue levels +- `orient` - "v" or "h" +- `color` - Single color for boxes +- `palette` - Color palette +- `saturation` - Color saturation intensity +- `width` - Width of boxes +- `dodge` - Separate hue levels side-by-side +- `fliersize` - Size of outlier markers +- `linewidth` - Box line width +- `whis` - IQR multiplier for whiskers (default: 1.5) +- `notch` - Draw notched boxes +- `showcaps` - Show whisker caps +- `showmeans` - Show mean value +- `meanprops` - Properties for mean marker +- `boxprops` - Properties for boxes +- `whiskerprops` - Properties for whiskers +- `capprops` - Properties for caps +- `flierprops` - Properties for outliers +- `medianprops` - Properties for median line +- `native_scale` - Use numeric scale +- `formatter` - Formatter for categorical axis +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.boxplot(data=df, x='day', y='total_bill', + hue='smoker', palette='Set3', + showmeans=True, notch=True) +``` + +### violinplot() + +**Purpose:** Draw a violin plot combining boxplot and KDE. + +**Key Parameters:** +Same as `boxplot()`, plus: +- `bw_method` - KDE bandwidth method +- `bw_adjust` - KDE bandwidth multiplier +- `cut` - KDE extension beyond extremes +- `density_norm` - "area", "count", "width" +- `inner` - "box", "quartile", "point", "stick", None +- `split` - Split violins for hue comparison +- `scale` - Scaling method: "area", "count", "width" +- `scale_hue` - Scale across hue levels +- `gridsize` - KDE grid resolution + +**Example:** +```python +sns.violinplot(data=df, x='day', y='total_bill', + hue='sex', split=True, inner='quartile', + palette='muted') +``` + +### boxenplot() + +**Purpose:** Draw enhanced box plot for larger datasets showing more quantiles. + +**Key Parameters:** +Same as `boxplot()`, plus: +- `k_depth` - "tukey", "proportion", "trustworthy", "full", or int +- `outlier_prop` - Proportion of data as outliers +- `trust_alpha` - Alpha for trustworthy depth +- `showfliers` - Show outlier points + +**Example:** +```python +sns.boxenplot(data=df, x='day', y='total_bill', + hue='time', palette='Set2') +``` + +### barplot() + +**Purpose:** Draw a bar plot with error bars showing statistical estimates. 
+ +**Key Parameters:** +- `data` - DataFrame, array, or dict +- `x, y` - Variables (one categorical, one continuous) +- `hue` - Grouping variable +- `order` - Order for categorical levels +- `hue_order` - Order for hue levels +- `estimator` - Aggregation function (default: mean) +- `errorbar` - Error representation: "sd", "se", "pi", ("ci", level), ("pi", level), or None +- `n_boot` - Bootstrap iterations +- `seed` - Random seed +- `units` - Identifier for sampling units +- `weights` - Observation weights +- `orient` - "v" or "h" +- `color` - Single bar color +- `palette` - Color palette +- `saturation` - Color saturation +- `width` - Bar width +- `dodge` - Separate hue levels side-by-side +- `errcolor` - Error bar color +- `errwidth` - Error bar line width +- `capsize` - Error bar cap width +- `native_scale` - Use numeric scale +- `formatter` - Formatter for categorical axis +- `legend` - Whether to show legend +- `ax` - Matplotlib axes + +**Example:** +```python +sns.barplot(data=df, x='day', y='total_bill', + hue='sex', estimator='median', + errorbar=('ci', 95), capsize=0.1) +``` + +### countplot() + +**Purpose:** Show counts of observations in each categorical bin. + +**Key Parameters:** +Same as `barplot()`, but: +- Only specify one of x or y (the categorical variable) +- No estimator or errorbar (shows counts) +- `stat` - "count" or "percent" + +**Example:** +```python +sns.countplot(data=df, x='day', hue='time', + palette='pastel', dodge=True) +``` + +### pointplot() + +**Purpose:** Show point estimates and confidence intervals with connecting lines. + +**Key Parameters:** +Same as `barplot()`, plus: +- `markers` - Marker style(s) +- `linestyles` - Line style(s) +- `scale` - Scale for markers +- `join` - Connect points with lines +- `capsize` - Error bar cap width + +**Example:** +```python +sns.pointplot(data=df, x='time', y='total_bill', + hue='sex', markers=['o', 's'], + linestyles=['-', '--'], capsize=0.1) +``` + +### catplot() + +**Purpose:** Figure-level interface for categorical plots onto a FacetGrid. + +**Key Parameters:** +All parameters from categorical plots, plus: +- `kind` - "strip", "swarm", "box", "violin", "boxen", "bar", "point", "count" +- `col` - Categorical variable for column facets +- `row` - Categorical variable for row facets +- `col_wrap` - Wrap columns +- `col_order` - Order for column facets +- `row_order` - Order for row facets +- `height` - Height of each facet +- `aspect` - Aspect ratio +- `sharex, sharey` - Share axes across facets +- `legend` - Whether to show legend +- `legend_out` - Place legend outside figure +- `facet_kws` - Additional FacetGrid parameters + +**Example:** +```python +sns.catplot(data=df, x='day', y='total_bill', + hue='smoker', col='time', + kind='violin', split=True, + height=4, aspect=0.8) +``` + +## Regression Plots + +### regplot() + +**Purpose:** Plot data and a linear regression fit. 
+ +**Key Parameters:** +- `data` - DataFrame +- `x, y` - Variables or data vectors +- `x_estimator` - Apply estimator to x bins +- `x_bins` - Bin x for estimator +- `x_ci` - CI for binned estimates +- `scatter` - Show scatter points +- `fit_reg` - Plot regression line +- `ci` - CI for regression estimate (int or None) +- `n_boot` - Bootstrap iterations for CI +- `units` - Identifier for sampling units +- `seed` - Random seed +- `order` - Polynomial regression order +- `logistic` - Fit logistic regression +- `lowess` - Fit lowess smoother +- `robust` - Fit robust regression +- `logx` - Log-transform x +- `x_partial, y_partial` - Partial regression (regress out variables) +- `truncate` - Limit regression line to data range +- `dropna` - Drop missing values +- `x_jitter, y_jitter` - Add jitter to data +- `label` - Label for legend +- `color` - Color for all elements +- `marker` - Marker style +- `scatter_kws` - Parameters for scatter +- `line_kws` - Parameters for regression line +- `ax` - Matplotlib axes + +**Example:** +```python +sns.regplot(data=df, x='total_bill', y='tip', + order=2, robust=True, ci=95, + scatter_kws={'alpha': 0.5}) +``` + +### lmplot() + +**Purpose:** Figure-level interface for regression plots onto a FacetGrid. + +**Key Parameters:** +All parameters from `regplot()`, plus: +- `hue` - Grouping variable +- `col` - Column facets +- `row` - Row facets +- `palette` - Color palette +- `col_wrap` - Wrap columns +- `height` - Facet height +- `aspect` - Aspect ratio +- `markers` - Marker style(s) +- `sharex, sharey` - Share axes +- `hue_order` - Order for hue levels +- `col_order` - Order for column facets +- `row_order` - Order for row facets +- `legend` - Whether to show legend +- `legend_out` - Place legend outside +- `facet_kws` - FacetGrid parameters + +**Example:** +```python +sns.lmplot(data=df, x='total_bill', y='tip', + hue='smoker', col='time', row='sex', + height=3, aspect=1.2, ci=None) +``` + +### residplot() + +**Purpose:** Plot residuals of a regression. + +**Key Parameters:** +Same as `regplot()`, but: +- Always plots residuals (y - predicted) vs x +- Adds horizontal line at y=0 +- `lowess` - Fit lowess smoother to residuals + +**Example:** +```python +sns.residplot(data=df, x='x', y='y', lowess=True, + scatter_kws={'alpha': 0.5}) +``` + +## Matrix Plots + +### heatmap() + +**Purpose:** Plot rectangular data as a color-encoded matrix. + +**Key Parameters:** +- `data` - 2D array-like data +- `vmin, vmax` - Anchor values for colormap +- `cmap` - Colormap name or object +- `center` - Value at colormap center +- `robust` - Use robust quantiles for colormap range +- `annot` - Annotate cells: True, False, or array +- `fmt` - Format string for annotations (e.g., ".2f") +- `annot_kws` - Parameters for annotations +- `linewidths` - Width of cell borders +- `linecolor` - Color of cell borders +- `cbar` - Draw colorbar +- `cbar_kws` - Colorbar parameters +- `cbar_ax` - Axes for colorbar +- `square` - Force square cells +- `xticklabels, yticklabels` - Tick labels (True, False, int, or list) +- `mask` - Boolean array to mask cells +- `ax` - Matplotlib axes + +**Example:** +```python +# Correlation matrix +corr = df.corr() +mask = np.triu(np.ones_like(corr, dtype=bool)) +sns.heatmap(corr, mask=mask, annot=True, fmt='.2f', + cmap='coolwarm', center=0, square=True, + linewidths=1, cbar_kws={'shrink': 0.8}) +``` + +### clustermap() + +**Purpose:** Plot a hierarchically-clustered heatmap. 
+ +**Key Parameters:** +All parameters from `heatmap()`, plus: +- `pivot_kws` - Parameters for pivoting (if needed) +- `method` - Linkage method: "single", "complete", "average", "weighted", "centroid", "median", "ward" +- `metric` - Distance metric for clustering +- `standard_scale` - Standardize data: 0 (rows), 1 (columns), or None +- `z_score` - Z-score normalize data: 0 (rows), 1 (columns), or None +- `row_cluster, col_cluster` - Cluster rows/columns +- `row_linkage, col_linkage` - Precomputed linkage matrices +- `row_colors, col_colors` - Additional color annotations +- `dendrogram_ratio` - Ratio of dendrogram to heatmap +- `colors_ratio` - Ratio of color annotations to heatmap +- `cbar_pos` - Colorbar position (tuple: x, y, width, height) +- `tree_kws` - Parameters for dendrogram +- `figsize` - Figure size + +**Example:** +```python +sns.clustermap(data, method='average', metric='euclidean', + z_score=0, cmap='viridis', + row_colors=row_colors, col_colors=col_colors, + figsize=(12, 12), dendrogram_ratio=0.1) +``` + +## Multi-Plot Grids + +### FacetGrid + +**Purpose:** Multi-plot grid for plotting conditional relationships. + +**Initialization:** +```python +g = sns.FacetGrid(data, row=None, col=None, hue=None, + col_wrap=None, sharex=True, sharey=True, + height=3, aspect=1, palette=None, + row_order=None, col_order=None, hue_order=None, + hue_kws=None, dropna=False, legend_out=True, + despine=True, margin_titles=False, + xlim=None, ylim=None, subplot_kws=None, + gridspec_kws=None) +``` + +**Methods:** +- `map(func, *args, **kwargs)` - Apply function to each facet +- `map_dataframe(func, *args, **kwargs)` - Apply function with full DataFrame +- `set_axis_labels(x_var, y_var)` - Set axis labels +- `set_titles(template, **kwargs)` - Set subplot titles +- `set(kwargs)` - Set attributes on all axes +- `add_legend(legend_data, title, label_order, **kwargs)` - Add legend +- `savefig(*args, **kwargs)` - Save figure + +**Example:** +```python +g = sns.FacetGrid(df, col='time', row='sex', hue='smoker', + height=3, aspect=1.5, margin_titles=True) +g.map(sns.scatterplot, 'total_bill', 'tip', alpha=0.7) +g.add_legend() +g.set_axis_labels('Total Bill ($)', 'Tip ($)') +g.set_titles('{col_name} | {row_name}') +``` + +### PairGrid + +**Purpose:** Grid for plotting pairwise relationships in a dataset. + +**Initialization:** +```python +g = sns.PairGrid(data, hue=None, vars=None, + x_vars=None, y_vars=None, + hue_order=None, palette=None, + hue_kws=None, corner=False, + diag_sharey=True, height=2.5, + aspect=1, layout_pad=0.5, + despine=True, dropna=False) +``` + +**Methods:** +- `map(func, **kwargs)` - Apply function to all subplots +- `map_diag(func, **kwargs)` - Apply to diagonal +- `map_offdiag(func, **kwargs)` - Apply to off-diagonal +- `map_upper(func, **kwargs)` - Apply to upper triangle +- `map_lower(func, **kwargs)` - Apply to lower triangle +- `add_legend(legend_data, **kwargs)` - Add legend +- `savefig(*args, **kwargs)` - Save figure + +**Example:** +```python +g = sns.PairGrid(df, hue='species', vars=['a', 'b', 'c', 'd'], + corner=True, height=2.5) +g.map_upper(sns.scatterplot, alpha=0.5) +g.map_lower(sns.kdeplot) +g.map_diag(sns.histplot, kde=True) +g.add_legend() +``` + +### JointGrid + +**Purpose:** Grid for bivariate plot with marginal univariate plots. 
+ +**Initialization:** +```python +g = sns.JointGrid(data=None, x=None, y=None, hue=None, + height=6, ratio=5, space=0.2, + dropna=False, xlim=None, ylim=None, + marginal_ticks=False, hue_order=None, + palette=None) +``` + +**Methods:** +- `plot(joint_func, marginal_func, **kwargs)` - Plot both joint and marginals +- `plot_joint(func, **kwargs)` - Plot joint distribution +- `plot_marginals(func, **kwargs)` - Plot marginal distributions +- `refline(x, y, **kwargs)` - Add reference line +- `set_axis_labels(xlabel, ylabel, **kwargs)` - Set axis labels +- `savefig(*args, **kwargs)` - Save figure + +**Example:** +```python +g = sns.JointGrid(data=df, x='x', y='y', hue='group', + height=6, ratio=5, space=0.2) +g.plot_joint(sns.scatterplot, alpha=0.5) +g.plot_marginals(sns.histplot, kde=True) +g.set_axis_labels('Variable X', 'Variable Y') +``` diff --git a/scientific-packages/seaborn/references/objects_interface.md b/scientific-packages/seaborn/references/objects_interface.md new file mode 100644 index 0000000..3cd1be5 --- /dev/null +++ b/scientific-packages/seaborn/references/objects_interface.md @@ -0,0 +1,964 @@ +# Seaborn Objects Interface + +The `seaborn.objects` interface provides a modern, declarative API for building visualizations through composition. This guide covers the complete objects interface introduced in seaborn 0.12+. + +## Core Concept + +The objects interface separates **what you want to show** (data and mappings) from **how to show it** (marks, stats, and moves). Build plots by: + +1. Creating a `Plot` object with data and aesthetic mappings +2. Adding layers with `.add()` combining marks and statistical transformations +3. Customizing with `.scale()`, `.label()`, `.limit()`, `.theme()`, etc. +4. Rendering with `.show()` or `.save()` + +## Basic Usage + +```python +from seaborn import objects as so +import pandas as pd + +# Create plot with data and mappings +p = so.Plot(data=df, x='x_var', y='y_var') + +# Add mark (visual representation) +p = p.add(so.Dot()) + +# Display (automatic in Jupyter) +p.show() +``` + +## Plot Class + +The `Plot` class is the foundation of the objects interface. + +### Initialization + +```python +so.Plot(data=None, x=None, y=None, color=None, alpha=None, + fill=None, fillalpha=None, fillcolor=None, marker=None, + pointsize=None, stroke=None, text=None, **variables) +``` + +**Parameters:** +- `data` - DataFrame or dict of data vectors +- `x, y` - Variables for position +- `color` - Variable for color encoding +- `alpha` - Variable for transparency +- `marker` - Variable for marker shape +- `pointsize` - Variable for point size +- `stroke` - Variable for line width +- `text` - Variable for text labels +- `**variables` - Additional mappings using property names + +**Examples:** +```python +# Basic mapping +so.Plot(df, x='total_bill', y='tip') + +# Multiple mappings +so.Plot(df, x='total_bill', y='tip', color='day', pointsize='size') + +# All variables in Plot +p = so.Plot(df, x='x', y='y', color='cat') +p.add(so.Dot()) # Uses all mappings + +# Some variables in add() +p = so.Plot(df, x='x', y='y') +p.add(so.Dot(), color='cat') # Only this layer uses color +``` + +### Methods + +#### add() + +Add a layer to the plot with mark and optional stat/move. 
+ +```python +Plot.add(mark, *transforms, orient=None, legend=True, data=None, + **variables) +``` + +**Parameters:** +- `mark` - Mark object defining visual representation +- `*transforms` - Stat and/or Move objects for data transformation +- `orient` - "x", "y", or "v"/"h" for orientation +- `legend` - Include in legend (True/False) +- `data` - Override data for this layer +- `**variables` - Override or add variable mappings + +**Examples:** +```python +# Simple mark +p.add(so.Dot()) + +# Mark with stat +p.add(so.Line(), so.PolyFit(order=2)) + +# Mark with multiple transforms +p.add(so.Bar(), so.Agg(), so.Dodge()) + +# Layer-specific mappings +p.add(so.Dot(), color='category') +p.add(so.Line(), so.Agg(), color='category') + +# Layer-specific data +p.add(so.Dot()) +p.add(so.Line(), data=summary_df) +``` + +#### facet() + +Create subplots from categorical variables. + +```python +Plot.facet(col=None, row=None, order=None, wrap=None) +``` + +**Parameters:** +- `col` - Variable for column facets +- `row` - Variable for row facets +- `order` - Dict with facet orders (keys: variable names) +- `wrap` - Wrap columns after this many + +**Example:** +```python +p.facet(col='time', row='sex') +p.facet(col='category', wrap=3) +p.facet(col='day', order={'day': ['Thur', 'Fri', 'Sat', 'Sun']}) +``` + +#### pair() + +Create pairwise subplots for multiple variables. + +```python +Plot.pair(x=None, y=None, wrap=None, cross=True) +``` + +**Parameters:** +- `x` - Variables for x-axis pairings +- `y` - Variables for y-axis pairings (if None, uses x) +- `wrap` - Wrap after this many columns +- `cross` - Include all x/y combinations (vs. only diagonal) + +**Example:** +```python +# Pairs of all variables +p = so.Plot(df).pair(x=['a', 'b', 'c']) +p.add(so.Dot()) + +# Rectangular grid +p = so.Plot(df).pair(x=['a', 'b'], y=['c', 'd']) +p.add(so.Dot(), alpha=0.5) +``` + +#### scale() + +Customize how data maps to visual properties. + +```python +Plot.scale(**scales) +``` + +**Parameters:** Keyword arguments with property names and Scale objects + +**Example:** +```python +p.scale( + x=so.Continuous().tick(every=5), + y=so.Continuous().label(like='{x:.1f}'), + color=so.Nominal(['#1f77b4', '#ff7f0e', '#2ca02c']), + pointsize=(5, 10) # Shorthand for range +) +``` + +#### limit() + +Set axis limits. + +```python +Plot.limit(x=None, y=None) +``` + +**Parameters:** +- `x` - Tuple of (min, max) for x-axis +- `y` - Tuple of (min, max) for y-axis + +**Example:** +```python +p.limit(x=(0, 100), y=(0, 50)) +``` + +#### label() + +Set axis labels and titles. + +```python +Plot.label(x=None, y=None, color=None, title=None, **labels) +``` + +**Parameters:** Keyword arguments with property names and label strings + +**Example:** +```python +p.label( + x='Total Bill ($)', + y='Tip Amount ($)', + color='Day of Week', + title='Restaurant Tips Analysis' +) +``` + +#### theme() + +Apply matplotlib style settings. + +```python +Plot.theme(config, **kwargs) +``` + +**Parameters:** +- `config` - Dict of rcParams or seaborn theme dict +- `**kwargs` - Individual rcParams + +**Example:** +```python +# Seaborn theme +p.theme({**sns.axes_style('whitegrid'), **sns.plotting_context('talk')}) + +# Custom rcParams +p.theme({'axes.facecolor': 'white', 'axes.grid': True}) + +# Individual parameters +p.theme(axes_facecolor='white', font_scale=1.2) +``` + +#### layout() + +Configure subplot layout. 
+ +```python +Plot.layout(size=None, extent=None, engine=None) +``` + +**Parameters:** +- `size` - (width, height) in inches +- `extent` - (left, bottom, right, top) for subplots +- `engine` - "tight", "constrained", or None + +**Example:** +```python +p.layout(size=(10, 6), engine='constrained') +``` + +#### share() + +Control axis sharing across facets. + +```python +Plot.share(x=None, y=None) +``` + +**Parameters:** +- `x` - Share x-axis: True, False, or "col"/"row" +- `y` - Share y-axis: True, False, or "col"/"row" + +**Example:** +```python +p.share(x=True, y=False) # Share x across all, independent y +p.share(x='col') # Share x within columns only +``` + +#### on() + +Plot on existing matplotlib figure or axes. + +```python +Plot.on(target) +``` + +**Parameters:** +- `target` - matplotlib Figure or Axes object + +**Example:** +```python +import matplotlib.pyplot as plt + +fig, axes = plt.subplots(2, 2, figsize=(10, 10)) +so.Plot(df, x='x', y='y').add(so.Dot()).on(axes[0, 0]) +so.Plot(df, x='x', y='z').add(so.Line()).on(axes[0, 1]) +``` + +#### show() + +Render and display the plot. + +```python +Plot.show(**kwargs) +``` + +**Parameters:** Passed to `matplotlib.pyplot.show()` + +#### save() + +Save the plot to file. + +```python +Plot.save(filename, **kwargs) +``` + +**Parameters:** +- `filename` - Output filename +- `**kwargs` - Passed to `matplotlib.figure.Figure.savefig()` + +**Example:** +```python +p.save('plot.png', dpi=300, bbox_inches='tight') +p.save('plot.pdf') +``` + +## Mark Objects + +Marks define how data is visually represented. + +### Dot + +Points/markers for individual observations. + +```python +so.Dot(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Fill color +- `alpha` - Transparency +- `fillcolor` - Alternate color property +- `fillalpha` - Alternate alpha property +- `edgecolor` - Edge color +- `edgealpha` - Edge transparency +- `edgewidth` - Edge line width +- `marker` - Marker style +- `pointsize` - Marker size +- `stroke` - Edge width + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Dot(color='blue', pointsize=10)) +so.Plot(df, x='x', y='y', color='cat').add(so.Dot(alpha=0.5)) +``` + +### Line + +Lines connecting observations. + +```python +so.Line(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Line color +- `alpha` - Transparency +- `linewidth` - Line width +- `linestyle` - Line style ("-", "--", "-.", ":") +- `marker` - Marker at data points +- `pointsize` - Marker size +- `edgecolor` - Marker edge color +- `edgewidth` - Marker edge width + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Line()) +so.Plot(df, x='x', y='y', color='cat').add(so.Line(linewidth=2)) +``` + +### Path + +Like Line but connects points in data order (not sorted by x). + +```python +so.Path(artist_kws=None, **kwargs) +``` + +Properties same as `Line`. + +**Example:** +```python +# For trajectories, loops, etc. +so.Plot(trajectory_df, x='x', y='y').add(so.Path()) +``` + +### Bar + +Rectangular bars. + +```python +so.Bar(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Fill color +- `alpha` - Transparency +- `edgecolor` - Edge color +- `edgealpha` - Edge transparency +- `edgewidth` - Edge line width +- `width` - Bar width (data units) + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Bar()) +so.Plot(df, x='x', y='y').add(so.Bar(color='#1f77b4', width=0.5)) +``` + +### Bars + +Multiple bars (for aggregated data with error bars). 
+ +```python +so.Bars(artist_kws=None, **kwargs) +``` + +Properties same as `Bar`. Used with `Agg()` or `Est()` stats. + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Bars(), so.Agg()) +``` + +### Area + +Filled area between line and baseline. + +```python +so.Area(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Fill color +- `alpha` - Transparency +- `edgecolor` - Edge color +- `edgealpha` - Edge transparency +- `edgewidth` - Edge line width +- `baseline` - Baseline value (default: 0) + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Area(alpha=0.3)) +so.Plot(df, x='x', y='y', color='cat').add(so.Area()) +``` + +### Band + +Filled band between two lines (for ranges/intervals). + +```python +so.Band(artist_kws=None, **kwargs) +``` + +Properties same as `Area`. Requires `ymin` and `ymax` mappings or used with `Est()` stat. + +**Example:** +```python +so.Plot(df, x='x', ymin='lower', ymax='upper').add(so.Band()) +so.Plot(df, x='x', y='y').add(so.Band(), so.Est()) +``` + +### Range + +Line with markers at endpoints (for ranges). + +```python +so.Range(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Line and marker color +- `alpha` - Transparency +- `linewidth` - Line width +- `marker` - Marker style at endpoints +- `pointsize` - Marker size +- `edgewidth` - Marker edge width + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Range(), so.Est()) +``` + +### Dash + +Short horizontal/vertical lines (for distribution marks). + +```python +so.Dash(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Line color +- `alpha` - Transparency +- `linewidth` - Line width +- `width` - Dash length (data units) + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Dash()) +``` + +### Text + +Text labels at data points. + +```python +so.Text(artist_kws=None, **kwargs) +``` + +**Properties:** +- `color` - Text color +- `alpha` - Transparency +- `fontsize` - Font size +- `halign` - Horizontal alignment: "left", "center", "right" +- `valign` - Vertical alignment: "bottom", "center", "top" +- `offset` - (x, y) offset from point + +Requires `text` mapping. + +**Example:** +```python +so.Plot(df, x='x', y='y', text='label').add(so.Text()) +so.Plot(df, x='x', y='y', text='value').add(so.Text(fontsize=10, offset=(0, 5))) +``` + +## Stat Objects + +Stats transform data before rendering. Compose with marks in `.add()`. + +### Agg + +Aggregate observations by group. + +```python +so.Agg(func='mean') +``` + +**Parameters:** +- `func` - Aggregation function: "mean", "median", "sum", "min", "max", "count", or callable + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Bar(), so.Agg('mean')) +so.Plot(df, x='x', y='y', color='group').add(so.Line(), so.Agg('median')) +``` + +### Est + +Estimate central tendency with error intervals. 
+ +```python +so.Est(func='mean', errorbar=('ci', 95), n_boot=1000, seed=None) +``` + +**Parameters:** +- `func` - Estimator: "mean", "median", "sum", or callable +- `errorbar` - Error representation: + - `("ci", level)` - Confidence interval via bootstrap + - `("pi", level)` - Percentile interval + - `("se", scale)` - Standard error scaled by factor + - `"sd"` - Standard deviation +- `n_boot` - Bootstrap iterations +- `seed` - Random seed + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Bar(), so.Est()) +so.Plot(df, x='x', y='y').add(so.Line(), so.Est(errorbar='sd')) +so.Plot(df, x='x', y='y').add(so.Line(), so.Est(errorbar=('ci', 95))) +so.Plot(df, x='x', y='y').add(so.Band(), so.Est()) +``` + +### Hist + +Bin observations and count/aggregate. + +```python +so.Hist(stat='count', bins='auto', binwidth=None, binrange=None, + common_norm=True, common_bins=True, cumulative=False) +``` + +**Parameters:** +- `stat` - "count", "density", "probability", "percent", "frequency" +- `bins` - Number of bins, bin method, or edges +- `binwidth` - Width of bins +- `binrange` - (min, max) range for binning +- `common_norm` - Normalize across groups together +- `common_bins` - Use same bins for all groups +- `cumulative` - Cumulative histogram + +**Example:** +```python +so.Plot(df, x='value').add(so.Bars(), so.Hist()) +so.Plot(df, x='value').add(so.Bars(), so.Hist(bins=20, stat='density')) +so.Plot(df, x='value', color='group').add(so.Area(), so.Hist(cumulative=True)) +``` + +### KDE + +Kernel density estimate. + +```python +so.KDE(bw_method='scott', bw_adjust=1, gridsize=200, + cut=3, cumulative=False) +``` + +**Parameters:** +- `bw_method` - Bandwidth method: "scott", "silverman", or scalar +- `bw_adjust` - Bandwidth multiplier +- `gridsize` - Resolution of density curve +- `cut` - Extension beyond data range (in bandwidth units) +- `cumulative` - Cumulative density + +**Example:** +```python +so.Plot(df, x='value').add(so.Line(), so.KDE()) +so.Plot(df, x='value', color='group').add(so.Area(alpha=0.5), so.KDE()) +so.Plot(df, x='x', y='y').add(so.Line(), so.KDE(bw_adjust=0.5)) +``` + +### Count + +Count observations per group. + +```python +so.Count() +``` + +**Example:** +```python +so.Plot(df, x='category').add(so.Bar(), so.Count()) +``` + +### PolyFit + +Polynomial regression fit. + +```python +so.PolyFit(order=1) +``` + +**Parameters:** +- `order` - Polynomial order (1 = linear, 2 = quadratic, etc.) + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Dot()) +so.Plot(df, x='x', y='y').add(so.Line(), so.PolyFit(order=2)) +``` + +### Perc + +Compute percentiles. + +```python +so.Perc(k=5, method='linear') +``` + +**Parameters:** +- `k` - Number of percentile intervals +- `method` - Interpolation method + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Band(), so.Perc()) +``` + +## Move Objects + +Moves adjust positions to resolve overlaps or create specific layouts. + +### Dodge + +Shift positions side-by-side. + +```python +so.Dodge(empty='keep', gap=0) +``` + +**Parameters:** +- `empty` - How to handle empty groups: "keep", "drop", "fill" +- `gap` - Gap between dodged elements (proportion) + +**Example:** +```python +so.Plot(df, x='category', y='value', color='group').add(so.Bar(), so.Dodge()) +so.Plot(df, x='cat', y='val', color='hue').add(so.Dot(), so.Dodge(gap=0.1)) +``` + +### Stack + +Stack marks vertically. 
+ +```python +so.Stack() +``` + +**Example:** +```python +so.Plot(df, x='x', y='y', color='category').add(so.Bar(), so.Stack()) +so.Plot(df, x='x', y='y', color='group').add(so.Area(), so.Stack()) +``` + +### Jitter + +Add random noise to positions. + +```python +so.Jitter(width=None, height=None, seed=None) +``` + +**Parameters:** +- `width` - Jitter in x direction (data units or proportion) +- `height` - Jitter in y direction +- `seed` - Random seed + +**Example:** +```python +so.Plot(df, x='category', y='value').add(so.Dot(), so.Jitter()) +so.Plot(df, x='cat', y='val').add(so.Dot(), so.Jitter(width=0.2)) +``` + +### Shift + +Shift positions by constant amount. + +```python +so.Shift(x=0, y=0) +``` + +**Parameters:** +- `x` - Shift in x direction (data units) +- `y` - Shift in y direction + +**Example:** +```python +so.Plot(df, x='x', y='y').add(so.Dot(), so.Shift(x=1)) +``` + +### Norm + +Normalize values. + +```python +so.Norm(func='max', where=None, by=None, percent=False) +``` + +**Parameters:** +- `func` - Normalization: "max", "sum", "area", or callable +- `where` - Apply to which axis: "x", "y", or None +- `by` - Grouping variables for separate normalization +- `percent` - Show as percentage + +**Example:** +```python +so.Plot(df, x='x', y='y', color='group').add(so.Area(), so.Norm()) +``` + +## Scale Objects + +Scales control how data values map to visual properties. + +### Continuous + +For numeric data. + +```python +so.Continuous(values=None, norm=None, trans=None) +``` + +**Methods:** +- `.tick(at=None, every=None, between=None, minor=None)` - Configure ticks +- `.label(like=None, base=None, unit=None)` - Format labels + +**Parameters:** +- `values` - Explicit value range (min, max) +- `norm` - Normalization function +- `trans` - Transformation: "log", "sqrt", "symlog", "logit", "pow10", or callable + +**Example:** +```python +p.scale( + x=so.Continuous().tick(every=10), + y=so.Continuous(trans='log').tick(at=[1, 10, 100]), + color=so.Continuous(values=(0, 1)), + pointsize=(5, 20) # Shorthand for Continuous range +) +``` + +### Nominal + +For categorical data. + +```python +so.Nominal(values=None, order=None) +``` + +**Parameters:** +- `values` - Explicit values (e.g., colors, markers) +- `order` - Category order + +**Example:** +```python +p.scale( + color=so.Nominal(['#1f77b4', '#ff7f0e', '#2ca02c']), + marker=so.Nominal(['o', 's', '^']), + x=so.Nominal(order=['Low', 'Medium', 'High']) +) +``` + +### Temporal + +For datetime data. 
+ +```python +so.Temporal(values=None, trans=None) +``` + +**Methods:** +- `.tick(every=None, between=None)` - Configure ticks +- `.label(concise=False)` - Format labels + +**Example:** +```python +p.scale(x=so.Temporal().tick(every=('month', 1)).label(concise=True)) +``` + +## Complete Examples + +### Layered Plot with Statistics + +```python +( + so.Plot(df, x='total_bill', y='tip', color='time') + .add(so.Dot(), alpha=0.5) + .add(so.Line(), so.PolyFit(order=2)) + .scale(color=so.Nominal(['#1f77b4', '#ff7f0e'])) + .label(x='Total Bill ($)', y='Tip ($)', title='Tips Analysis') + .theme({**sns.axes_style('whitegrid')}) +) +``` + +### Faceted Distribution + +```python +( + so.Plot(df, x='measurement', color='treatment') + .facet(col='timepoint', wrap=3) + .add(so.Area(alpha=0.5), so.KDE()) + .add(so.Dot(), so.Jitter(width=0.1), y=0) + .scale(x=so.Continuous().tick(every=5)) + .label(x='Measurement (units)', title='Treatment Effects Over Time') + .share(x=True, y=False) +) +``` + +### Grouped Bar Chart + +```python +( + so.Plot(df, x='category', y='value', color='group') + .add(so.Bar(), so.Agg('mean'), so.Dodge()) + .add(so.Range(), so.Est(errorbar='se'), so.Dodge()) + .scale(color=so.Nominal(order=['A', 'B', 'C'])) + .label(y='Mean Value', title='Comparison by Category and Group') +) +``` + +### Complex Multi-Layer + +```python +( + so.Plot(df, x='date', y='value') + .add(so.Dot(color='gray', pointsize=3), alpha=0.3) + .add(so.Line(color='blue', linewidth=2), so.Agg('mean')) + .add(so.Band(color='blue', alpha=0.2), so.Est(errorbar=('ci', 95))) + .facet(col='sensor', row='location') + .scale( + x=so.Temporal().label(concise=True), + y=so.Continuous().tick(every=10) + ) + .label( + x='Date', + y='Measurement', + title='Sensor Measurements by Location' + ) + .layout(size=(12, 8), engine='constrained') +) +``` + +## Migration from Function Interface + +### Scatter Plot + +**Function interface:** +```python +sns.scatterplot(data=df, x='x', y='y', hue='category', size='value') +``` + +**Objects interface:** +```python +so.Plot(df, x='x', y='y', color='category', pointsize='value').add(so.Dot()) +``` + +### Line Plot with CI + +**Function interface:** +```python +sns.lineplot(data=df, x='time', y='measurement', hue='group', errorbar='ci') +``` + +**Objects interface:** +```python +( + so.Plot(df, x='time', y='measurement', color='group') + .add(so.Line(), so.Est()) +) +``` + +### Histogram + +**Function interface:** +```python +sns.histplot(data=df, x='value', hue='category', stat='density', kde=True) +``` + +**Objects interface:** +```python +( + so.Plot(df, x='value', color='category') + .add(so.Bars(), so.Hist(stat='density')) + .add(so.Line(), so.KDE()) +) +``` + +### Bar Plot with Error Bars + +**Function interface:** +```python +sns.barplot(data=df, x='category', y='value', hue='group', errorbar='ci') +``` + +**Objects interface:** +```python +( + so.Plot(df, x='category', y='value', color='group') + .add(so.Bar(), so.Agg(), so.Dodge()) + .add(so.Range(), so.Est(), so.Dodge()) +) +``` + +## Tips and Best Practices + +1. **Method chaining**: Each method returns a new Plot object, enabling fluent chaining +2. **Layer composition**: Combine multiple `.add()` calls to overlay different marks +3. **Transform order**: In `.add(mark, stat, move)`, stat applies first, then move +4. **Variable priority**: Layer-specific mappings override Plot-level mappings +5. **Scale shortcuts**: Use tuples for simple ranges: `color=(min, max)` vs full Scale object +6. 
**Jupyter rendering**: Plots render automatically when returned; use `.show()` otherwise +7. **Saving**: Use `.save()` rather than `plt.savefig()` for proper handling +8. **Matplotlib access**: Use `.on(ax)` to integrate with matplotlib figures diff --git a/scientific-packages/torch_geometric/SKILL.md b/scientific-packages/torch_geometric/SKILL.md new file mode 100644 index 0000000..49c8cc3 --- /dev/null +++ b/scientific-packages/torch_geometric/SKILL.md @@ -0,0 +1,670 @@ +--- +name: torch-geometric +description: PyTorch Geometric (PyG) skill for building and training Graph Neural Networks (GNNs) on structured data including graphs, 3D meshes, and point clouds. Use this skill when working with graph-based machine learning tasks such as node classification, graph classification, link prediction, or geometric deep learning on irregular structures. Applies to molecular property prediction, social network analysis, citation networks, 3D vision, and any domain involving relational or geometric data. +--- + +# PyTorch Geometric (PyG) + +## Overview + +PyTorch Geometric is a library built on PyTorch that enables development and training of Graph Neural Networks (GNNs) for applications involving structured data. It provides comprehensive tools for deep learning on graphs and other irregular structures (geometric deep learning), including mini-batch processing, multi-GPU support, and extensive benchmark datasets. + +## When to Use This Skill + +Use this skill when working with: +- **Graph-based machine learning**: Node classification, graph classification, link prediction +- **Molecular property prediction**: Drug discovery, chemical property prediction +- **Social network analysis**: Community detection, influence prediction +- **Citation networks**: Paper classification, recommendation systems +- **3D geometric data**: Point clouds, meshes, molecular structures +- **Heterogeneous graphs**: Multi-type nodes and edges (e.g., knowledge graphs) +- **Large-scale graph learning**: Neighbor sampling, distributed training + +## Quick Start + +### Installation + +```bash +pip install torch_geometric +``` + +For additional dependencies (sparse operations, clustering): +```bash +pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html +``` + +### Basic Graph Creation + +```python +import torch +from torch_geometric.data import Data + +# Create a simple graph with 3 nodes +edge_index = torch.tensor([[0, 1, 1, 2], # source nodes + [1, 0, 2, 1]], dtype=torch.long) # target nodes +x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # node features + +data = Data(x=x, edge_index=edge_index) +print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}") +``` + +### Loading a Benchmark Dataset + +```python +from torch_geometric.datasets import Planetoid + +# Load Cora citation network +dataset = Planetoid(root='/tmp/Cora', name='Cora') +data = dataset[0] # Get the first (and only) graph + +print(f"Dataset: {dataset}") +print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}") +print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}") +``` + +## Core Concepts + +### Data Structure + +PyG represents graphs using the `torch_geometric.data.Data` class with these key attributes: + +- **`data.x`**: Node feature matrix `[num_nodes, num_node_features]` +- **`data.edge_index`**: Graph connectivity in COO format `[2, num_edges]` +- **`data.edge_attr`**: Edge feature matrix `[num_edges, num_edge_features]` (optional) +- 
**`data.y`**: Target labels for nodes or graphs +- **`data.pos`**: Node spatial positions `[num_nodes, num_dimensions]` (optional) +- **Custom attributes**: Can add any attribute (e.g., `data.train_mask`, `data.batch`) + +**Important**: These attributes are not mandatory—extend Data objects with custom attributes as needed. + +### Edge Index Format + +Edges are stored in COO (coordinate) format as a `[2, num_edges]` tensor: +- First row: source node indices +- Second row: target node indices + +```python +# Edge list: (0→1), (1→0), (1→2), (2→1) +edge_index = torch.tensor([[0, 1, 1, 2], + [1, 0, 2, 1]], dtype=torch.long) +``` + +### Mini-Batch Processing + +PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph: + +- Adjacency matrices are stacked diagonally +- Node features are concatenated along the node dimension +- A `batch` vector maps each node to its source graph +- No padding needed—computationally efficient + +```python +from torch_geometric.loader import DataLoader + +loader = DataLoader(dataset, batch_size=32, shuffle=True) +for batch in loader: + print(f"Batch size: {batch.num_graphs}") + print(f"Total nodes: {batch.num_nodes}") + # batch.batch maps nodes to graphs +``` + +## Building Graph Neural Networks + +### Message Passing Paradigm + +GNNs in PyG follow a neighborhood aggregation scheme: +1. Transform node features +2. Propagate messages along edges +3. Aggregate messages from neighbors +4. Update node representations + +### Using Pre-Built Layers + +PyG provides 40+ convolutional layers. Common ones include: + +**GCNConv** (Graph Convolutional Network): +```python +from torch_geometric.nn import GCNConv +import torch.nn.functional as F + +class GCN(torch.nn.Module): + def __init__(self, num_features, num_classes): + super().__init__() + self.conv1 = GCNConv(num_features, 16) + self.conv2 = GCNConv(16, num_classes) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, training=self.training) + x = self.conv2(x, edge_index) + return F.log_softmax(x, dim=1) +``` + +**GATConv** (Graph Attention Network): +```python +from torch_geometric.nn import GATConv + +class GAT(torch.nn.Module): + def __init__(self, num_features, num_classes): + super().__init__() + self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6) + self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + x = F.dropout(x, p=0.6, training=self.training) + x = F.elu(self.conv1(x, edge_index)) + x = F.dropout(x, p=0.6, training=self.training) + x = self.conv2(x, edge_index) + return F.log_softmax(x, dim=1) +``` + +**GraphSAGE**: +```python +from torch_geometric.nn import SAGEConv + +class GraphSAGE(torch.nn.Module): + def __init__(self, num_features, num_classes): + super().__init__() + self.conv1 = SAGEConv(num_features, 64) + self.conv2 = SAGEConv(64, num_classes) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, training=self.training) + x = self.conv2(x, edge_index) + return F.log_softmax(x, dim=1) +``` + +### Custom Message Passing Layers + +For custom layers, inherit from `MessagePassing`: + +```python +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import add_self_loops, degree + +class CustomConv(MessagePassing): + def __init__(self, 
in_channels, out_channels): + super().__init__(aggr='add') # "add", "mean", or "max" + self.lin = torch.nn.Linear(in_channels, out_channels) + + def forward(self, x, edge_index): + # Add self-loops to adjacency matrix + edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) + + # Transform node features + x = self.lin(x) + + # Compute normalization + row, col = edge_index + deg = degree(col, x.size(0), dtype=x.dtype) + deg_inv_sqrt = deg.pow(-0.5) + norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + # Propagate messages + return self.propagate(edge_index, x=x, norm=norm) + + def message(self, x_j, norm): + # x_j: features of source nodes + return norm.view(-1, 1) * x_j +``` + +Key methods: +- **`forward()`**: Main entry point +- **`message()`**: Constructs messages from source to target nodes +- **`aggregate()`**: Aggregates messages (usually don't override—set `aggr` parameter) +- **`update()`**: Updates node embeddings after aggregation + +**Variable naming convention**: Appending `_i` or `_j` to tensor names automatically maps them to target or source nodes. + +## Working with Datasets + +### Loading Built-in Datasets + +PyG provides extensive benchmark datasets: + +```python +# Citation networks (node classification) +from torch_geometric.datasets import Planetoid +dataset = Planetoid(root='/tmp/Cora', name='Cora') # or 'CiteSeer', 'PubMed' + +# Graph classification +from torch_geometric.datasets import TUDataset +dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') + +# Molecular datasets +from torch_geometric.datasets import QM9 +dataset = QM9(root='/tmp/QM9') + +# Large-scale datasets +from torch_geometric.datasets import Reddit +dataset = Reddit(root='/tmp/Reddit') +``` + +Check `references/datasets_reference.md` for a comprehensive list. + +### Creating Custom Datasets + +For datasets that fit in memory, inherit from `InMemoryDataset`: + +```python +from torch_geometric.data import InMemoryDataset, Data +import torch + +class MyOwnDataset(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None): + super().__init__(root, transform, pre_transform) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return ['my_data.csv'] # Files needed in raw_dir + + @property + def processed_file_names(self): + return ['data.pt'] # Files in processed_dir + + def download(self): + # Download raw data to self.raw_dir + pass + + def process(self): + # Read data, create Data objects + data_list = [] + + # Example: Create a simple graph + edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long) + x = torch.randn(2, 16) + y = torch.tensor([0], dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, y=y) + data_list.append(data) + + # Apply pre_filter and pre_transform + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + # Save processed data + self.save(data_list, self.processed_paths[0]) +``` + +For large datasets that don't fit in memory, inherit from `Dataset` and implement `len()` and `get(idx)`. 
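+
+A minimal sketch of that on-disk variant is shown below. It mirrors the `InMemoryDataset` example above, but saves one `Data` object per file and loads graphs lazily in `get()`; the class name, file names, and `num_graphs` argument are illustrative, not part of the PyG API.
+
+```python
+import os.path as osp
+
+import torch
+from torch_geometric.data import Data, Dataset
+
+class MyLargeDataset(Dataset):
+    def __init__(self, root, num_graphs=1000, transform=None, pre_transform=None):
+        self.num_graphs = num_graphs
+        super().__init__(root, transform, pre_transform)
+
+    @property
+    def raw_file_names(self):
+        return ['raw_graphs.bin']  # hypothetical raw input placed in raw_dir
+
+    @property
+    def processed_file_names(self):
+        return [f'data_{i}.pt' for i in range(self.num_graphs)]
+
+    def download(self):
+        pass  # assume raw data is copied into raw_dir manually
+
+    def process(self):
+        # Convert each raw graph into a Data object and save it to its own file
+        for i in range(self.num_graphs):
+            data = Data(x=torch.randn(10, 16),
+                        edge_index=torch.randint(0, 10, (2, 20)))
+            if self.pre_transform is not None:
+                data = self.pre_transform(data)
+            torch.save(data, osp.join(self.processed_dir, f'data_{i}.pt'))
+
+    def len(self):
+        return self.num_graphs
+
+    def get(self, idx):
+        # Load a single graph from disk on demand
+        return torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
+```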
+
+### Loading Graphs from CSV
+
+```python
+import pandas as pd
+import torch
+from torch_geometric.data import Data
+
+# Load nodes
+nodes_df = pd.read_csv('nodes.csv')
+x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
+
+# Load edges
+edges_df = pd.read_csv('edges.csv')
+edge_index = torch.tensor([edges_df['source'].values,
+                           edges_df['target'].values], dtype=torch.long)
+
+data = Data(x=x, edge_index=edge_index)
+```
+
+## Training Workflows
+
+### Node Classification (Single Graph)
+
+```python
+import torch
+import torch.nn.functional as F
+from torch_geometric.datasets import Planetoid
+
+# Load dataset
+dataset = Planetoid(root='/tmp/Cora', name='Cora')
+data = dataset[0]
+
+# Create model
+model = GCN(dataset.num_features, dataset.num_classes)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+
+# Training
+model.train()
+for epoch in range(200):
+    optimizer.zero_grad()
+    out = model(data)
+    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
+    loss.backward()
+    optimizer.step()
+
+    if epoch % 10 == 0:
+        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
+
+# Evaluation
+model.eval()
+pred = model(data).argmax(dim=1)
+correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
+acc = int(correct) / int(data.test_mask.sum())
+print(f'Test Accuracy: {acc:.4f}')
+```
+
+### Graph Classification (Multiple Graphs)
+
+```python
+from torch_geometric.datasets import TUDataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import global_mean_pool
+
+class GraphClassifier(torch.nn.Module):
+    def __init__(self, num_features, num_classes):
+        super().__init__()
+        self.conv1 = GCNConv(num_features, 64)
+        self.conv2 = GCNConv(64, 64)
+        self.lin = torch.nn.Linear(64, num_classes)
+
+    def forward(self, data):
+        x, edge_index, batch = data.x, data.edge_index, data.batch
+
+        x = self.conv1(x, edge_index)
+        x = F.relu(x)
+        x = self.conv2(x, edge_index)
+        x = F.relu(x)
+
+        # Global pooling (aggregate node features to graph-level)
+        x = global_mean_pool(x, batch)
+
+        x = self.lin(x)
+        return F.log_softmax(x, dim=1)
+
+# Load dataset
+dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
+loader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+model = GraphClassifier(dataset.num_features, dataset.num_classes)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+# Training
+model.train()
+for epoch in range(100):
+    total_loss = 0
+    for batch in loader:
+        optimizer.zero_grad()
+        out = model(batch)
+        loss = F.nll_loss(out, batch.y)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+
+    if epoch % 10 == 0:
+        print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
+```
+
+### Large-Scale Graphs with Neighbor Sampling
+
+For large graphs, use `NeighborLoader` to sample subgraphs:
+
+```python
+from torch_geometric.loader import NeighborLoader
+
+# Create a neighbor sampler
+train_loader = NeighborLoader(
+    data,
+    num_neighbors=[25, 10],  # Sample 25 neighbors for 1st hop, 10 for 2nd hop
+    batch_size=128,
+    input_nodes=data.train_mask,
+)
+
+# Training
+model.train()
+for batch in train_loader:
+    optimizer.zero_grad()
+    out = model(batch)
+    # Only compute loss on seed nodes (first batch_size nodes)
+    loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
+    loss.backward()
+    optimizer.step()
+```
+
+**Important**:
+- Output subgraphs are directed
+- Node indices are relabeled (0 to batch.num_nodes - 1)
+- Only use seed node predictions for loss computation
+-
Sampling beyond 2-3 hops is generally not feasible + +## Advanced Features + +### Heterogeneous Graphs + +For graphs with multiple node and edge types, use `HeteroData`: + +```python +from torch_geometric.data import HeteroData + +data = HeteroData() + +# Add node features for different types +data['paper'].x = torch.randn(100, 128) # 100 papers with 128 features +data['author'].x = torch.randn(200, 64) # 200 authors with 64 features + +# Add edges for different types (source_type, edge_type, target_type) +data['author', 'writes', 'paper'].edge_index = torch.randint(0, 200, (2, 500)) +data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300)) + +print(data) +``` + +Convert homogeneous models to heterogeneous: + +```python +from torch_geometric.nn import to_hetero + +# Define homogeneous model +model = GNN(...) + +# Convert to heterogeneous +model = to_hetero(model, data.metadata(), aggr='sum') + +# Use as normal +out = model(data.x_dict, data.edge_index_dict) +``` + +Or use `HeteroConv` for custom edge-type-specific operations: + +```python +from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv + +class HeteroGNN(torch.nn.Module): + def __init__(self, metadata): + super().__init__() + self.conv1 = HeteroConv({ + ('paper', 'cites', 'paper'): GCNConv(-1, 64), + ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64), + }, aggr='sum') + + self.conv2 = HeteroConv({ + ('paper', 'cites', 'paper'): GCNConv(64, 32), + ('author', 'writes', 'paper'): SAGEConv((64, 64), 32), + }, aggr='sum') + + def forward(self, x_dict, edge_index_dict): + x_dict = self.conv1(x_dict, edge_index_dict) + x_dict = {key: F.relu(x) for key, x in x_dict.items()} + x_dict = self.conv2(x_dict, edge_index_dict) + return x_dict +``` + +### Transforms + +Apply transforms to modify graph structure or features: + +```python +from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose + +# Single transform +transform = NormalizeFeatures() +dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform) + +# Compose multiple transforms +transform = Compose([ + AddSelfLoops(), + NormalizeFeatures(), +]) +dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform) +``` + +Common transforms: +- **Structure**: `ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph` +- **Features**: `NormalizeFeatures`, `NormalizeScale`, `Center` +- **Sampling**: `RandomNodeSplit`, `RandomLinkSplit` +- **Positional Encoding**: `AddLaplacianEigenvectorPE`, `AddRandomWalkPE` + +See `references/transforms_reference.md` for the full list. 
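+
+Transforms are ordinary callables, so they can also be applied eagerly to a single `Data` object outside of a dataset. A small sketch with a made-up toy graph:
+
+```python
+import torch
+from torch_geometric.data import Data
+from torch_geometric.transforms import (Compose, NormalizeFeatures,
+                                        RandomNodeSplit, ToUndirected)
+
+# Toy graph; shapes and labels are arbitrary
+data = Data(
+    x=torch.rand(20, 8),
+    edge_index=torch.randint(0, 20, (2, 60)),
+    y=torch.randint(0, 3, (20,)),
+)
+
+transform = Compose([
+    ToUndirected(),                               # symmetrize edge_index
+    NormalizeFeatures(),                          # row-normalize x
+    RandomNodeSplit(num_val=0.15, num_test=0.2),  # adds train/val/test masks
+])
+data = transform(data)
+print(data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum())
+```
+
+When the same pipeline is passed to a dataset via `transform=`, it runs on every access instead, which is where cheap or random transforms usually belong.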
+
+### Model Explainability
+
+PyG provides explainability tools to understand model predictions:
+
+```python
+from torch_geometric.explain import Explainer, GNNExplainer
+
+# Create explainer
+explainer = Explainer(
+    model=model,
+    algorithm=GNNExplainer(epochs=200),
+    explanation_type='model',  # or 'phenomenon'
+    node_mask_type='attributes',
+    edge_mask_type='object',
+    model_config=dict(
+        mode='multiclass_classification',
+        task_level='node',
+        return_type='log_probs',
+    ),
+)
+
+# Generate explanation for a specific node
+node_idx = 10
+explanation = explainer(data.x, data.edge_index, index=node_idx)
+
+# Inspect the most important edges and features
+print(f'Node {node_idx} explanation:')
+print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
+print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
+```
+
+### Pooling Operations
+
+For hierarchical graph representations:
+
+```python
+from torch_geometric.nn import TopKPooling, global_mean_pool
+
+class HierarchicalGNN(torch.nn.Module):
+    def __init__(self, num_features, num_classes):
+        super().__init__()
+        self.conv1 = GCNConv(num_features, 64)
+        self.pool1 = TopKPooling(64, ratio=0.8)
+        self.conv2 = GCNConv(64, 64)
+        self.pool2 = TopKPooling(64, ratio=0.8)
+        self.lin = torch.nn.Linear(64, num_classes)
+
+    def forward(self, data):
+        x, edge_index, batch = data.x, data.edge_index, data.batch
+
+        x = F.relu(self.conv1(x, edge_index))
+        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
+
+        x = F.relu(self.conv2(x, edge_index))
+        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
+
+        x = global_mean_pool(x, batch)
+        x = self.lin(x)
+        return F.log_softmax(x, dim=1)
+```
+
+## Common Patterns and Best Practices
+
+### Check Graph Properties
+
+```python
+# Undirected check
+from torch_geometric.utils import is_undirected
+print(f"Is undirected: {is_undirected(data.edge_index)}")
+
+# Connected components (via NetworkX)
+import networkx as nx
+from torch_geometric.utils import to_networkx
+G = to_networkx(data, to_undirected=True)
+print(f"Connected components: {nx.number_connected_components(G)}")
+
+# Contains self-loops
+from torch_geometric.utils import contains_self_loops
+print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
+```
+
+### GPU Training
+
+```python
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
+data = data.to(device)
+
+# For DataLoader
+for batch in loader:
+    batch = batch.to(device)
+    # Train...
+```
+
+### Save and Load Models
+
+```python
+# Save
+torch.save(model.state_dict(), 'model.pth')
+
+# Load
+model = GCN(num_features, num_classes)
+model.load_state_dict(torch.load('model.pth'))
+model.eval()
+```
+
+### Layer Capabilities
+
+When choosing layers, consider these capabilities:
+- **SparseTensor**: Supports efficient sparse matrix operations
+- **edge_weight**: Handles one-dimensional edge weights
+- **edge_attr**: Processes multi-dimensional edge features
+- **Bipartite**: Works with bipartite graphs (different source/target dimensions)
+- **Lazy**: Enables initialization without specifying input dimensions
+
+See the GNN cheatsheet at `references/layer_capabilities.md`.
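+
+As an illustration of the *Lazy* capability, most lazy-enabled layers accept `-1` for `in_channels` and infer the real size on the first forward pass. A small sketch with made-up shapes:
+
+```python
+import torch
+from torch_geometric.nn import GCNConv, Linear
+
+class LazyGNN(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__()
+        self.conv = GCNConv(-1, hidden_channels)  # input size inferred lazily
+        self.lin = Linear(-1, out_channels)       # PyG's lazily initialized Linear
+
+    def forward(self, x, edge_index):
+        x = self.conv(x, edge_index).relu()
+        return self.lin(x)
+
+model = LazyGNN(hidden_channels=64, out_channels=7)
+x = torch.randn(10, 32)                    # any feature width works here
+edge_index = torch.randint(0, 10, (2, 40))
+out = model(x, edge_index)                 # parameter shapes resolved on this call
+print(out.shape)  # torch.Size([10, 7])
+```
+
+This is handy when feature dimensionality differs across datasets or, in heterogeneous graphs, across node types.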
+ +## Resources + +### Bundled References + +This skill includes detailed reference documentation: + +- **`references/layers_reference.md`**: Complete listing of all 40+ GNN layers with descriptions and capabilities +- **`references/datasets_reference.md`**: Comprehensive dataset catalog organized by category +- **`references/transforms_reference.md`**: All available transforms and their use cases +- **`references/api_patterns.md`**: Common API patterns and coding examples + +### Scripts + +Utility scripts are provided in `scripts/`: + +- **`scripts/visualize_graph.py`**: Visualize graph structure using networkx and matplotlib +- **`scripts/create_gnn_template.py`**: Generate boilerplate code for common GNN architectures +- **`scripts/benchmark_model.py`**: Benchmark model performance on standard datasets + +Execute scripts directly or read them for implementation patterns. + +### Official Resources + +- **Documentation**: https://pytorch-geometric.readthedocs.io/ +- **GitHub**: https://github.com/pyg-team/pytorch_geometric +- **Tutorials**: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html +- **Examples**: https://github.com/pyg-team/pytorch_geometric/tree/master/examples diff --git a/scientific-packages/torch_geometric/references/datasets_reference.md b/scientific-packages/torch_geometric/references/datasets_reference.md new file mode 100644 index 0000000..344cc9b --- /dev/null +++ b/scientific-packages/torch_geometric/references/datasets_reference.md @@ -0,0 +1,574 @@ +# PyTorch Geometric Datasets Reference + +This document provides a comprehensive catalog of all datasets available in `torch_geometric.datasets`. + +## Citation Networks + +### Planetoid +**Usage**: Node classification, semi-supervised learning +**Networks**: Cora, CiteSeer, PubMed +**Description**: Citation networks where nodes are papers and edges are citations +- **Cora**: 2,708 nodes, 5,429 edges, 7 classes, 1,433 features +- **CiteSeer**: 3,327 nodes, 4,732 edges, 6 classes, 3,703 features +- **PubMed**: 19,717 nodes, 44,338 edges, 3 classes, 500 features + +```python +from torch_geometric.datasets import Planetoid +dataset = Planetoid(root='/tmp/Cora', name='Cora') +``` + +### Coauthor +**Usage**: Node classification on collaboration networks +**Networks**: CS, Physics +**Description**: Co-authorship networks from Microsoft Academic Graph +- **CS**: 18,333 nodes, 81,894 edges, 15 classes (computer science) +- **Physics**: 34,493 nodes, 247,962 edges, 5 classes (physics) + +```python +from torch_geometric.datasets import Coauthor +dataset = Coauthor(root='/tmp/CS', name='CS') +``` + +### Amazon +**Usage**: Node classification on product networks +**Networks**: Computers, Photo +**Description**: Amazon co-purchase networks where nodes are products +- **Computers**: 13,752 nodes, 245,861 edges, 10 classes +- **Photo**: 7,650 nodes, 119,081 edges, 8 classes + +```python +from torch_geometric.datasets import Amazon +dataset = Amazon(root='/tmp/Computers', name='Computers') +``` + +### CitationFull +**Usage**: Citation network analysis +**Networks**: Cora, Cora_ML, DBLP, PubMed +**Description**: Full citation networks without sampling + +```python +from torch_geometric.datasets import CitationFull +dataset = CitationFull(root='/tmp/Cora', name='Cora') +``` + +## Graph Classification + +### TUDataset +**Usage**: Graph classification, graph kernel benchmarks +**Description**: Collection of 120+ graph classification datasets +- **MUTAG**: 188 graphs, 2 classes (molecular compounds) +- 
**PROTEINS**: 1,113 graphs, 2 classes (protein structures) +- **ENZYMES**: 600 graphs, 6 classes (protein enzymes) +- **IMDB-BINARY**: 1,000 graphs, 2 classes (social networks) +- **REDDIT-BINARY**: 2,000 graphs, 2 classes (discussion threads) +- **COLLAB**: 5,000 graphs, 3 classes (scientific collaborations) +- **NCI1**: 4,110 graphs, 2 classes (chemical compounds) +- **DD**: 1,178 graphs, 2 classes (protein structures) + +```python +from torch_geometric.datasets import TUDataset +dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') +``` + +### MoleculeNet +**Usage**: Molecular property prediction +**Datasets**: Over 10 molecular benchmark datasets +**Description**: Comprehensive molecular machine learning benchmarks +- **ESOL**: Aqueous solubility (regression) +- **FreeSolv**: Hydration free energy (regression) +- **Lipophilicity**: Octanol/water distribution (regression) +- **BACE**: Binding results (classification) +- **BBBP**: Blood-brain barrier penetration (classification) +- **HIV**: HIV inhibition (classification) +- **Tox21**: Toxicity prediction (multi-task classification) +- **ToxCast**: Toxicology forecasting (multi-task classification) +- **SIDER**: Side effects (multi-task classification) +- **ClinTox**: Clinical trial toxicity (multi-task classification) + +```python +from torch_geometric.datasets import MoleculeNet +dataset = MoleculeNet(root='/tmp/ESOL', name='ESOL') +``` + +## Molecular and Chemical Datasets + +### QM7b +**Usage**: Molecular property prediction (quantum mechanics) +**Description**: 7,211 molecules with up to 7 heavy atoms +- Properties: Atomization energies, electronic properties + +```python +from torch_geometric.datasets import QM7b +dataset = QM7b(root='/tmp/QM7b') +``` + +### QM9 +**Usage**: Molecular property prediction (quantum mechanics) +**Description**: ~130,000 molecules with up to 9 heavy atoms (C, O, N, F) +- Properties: 19 quantum chemical properties including HOMO, LUMO, gap, energy + +```python +from torch_geometric.datasets import QM9 +dataset = QM9(root='/tmp/QM9') +``` + +### ZINC +**Usage**: Molecular generation, property prediction +**Description**: ~250,000 drug-like molecular graphs +- Properties: Constrained solubility, molecular weight + +```python +from torch_geometric.datasets import ZINC +dataset = ZINC(root='/tmp/ZINC', subset=True) +``` + +### AQSOL +**Usage**: Aqueous solubility prediction +**Description**: ~10,000 molecules with solubility measurements + +```python +from torch_geometric.datasets import AQSOL +dataset = AQSOL(root='/tmp/AQSOL') +``` + +### MD17 +**Usage**: Molecular dynamics, force field learning +**Description**: Molecular dynamics trajectories for small molecules +- Molecules: Benzene, Uracil, Naphthalene, Aspirin, Salicylic acid, etc. 
+ +```python +from torch_geometric.datasets import MD17 +dataset = MD17(root='/tmp/MD17', name='benzene') +``` + +### PCQM4Mv2 +**Usage**: Large-scale molecular property prediction +**Description**: 3.8M molecules from PubChem for quantum chemistry +- Part of OGB Large-Scale Challenge + +```python +from torch_geometric.datasets import PCQM4Mv2 +dataset = PCQM4Mv2(root='/tmp/PCQM4Mv2') +``` + +## Social Networks + +### Reddit +**Usage**: Large-scale node classification +**Description**: Reddit posts from September 2014 +- 232,965 nodes, 11,606,919 edges, 41 classes +- Features: TF-IDF of post content + +```python +from torch_geometric.datasets import Reddit +dataset = Reddit(root='/tmp/Reddit') +``` + +### Reddit2 +**Usage**: Large-scale node classification +**Description**: Updated Reddit dataset with more posts + +```python +from torch_geometric.datasets import Reddit2 +dataset = Reddit2(root='/tmp/Reddit2') +``` + +### Twitch +**Usage**: Node classification, social network analysis +**Networks**: DE, EN, ES, FR, PT, RU +**Description**: Twitch user networks by language + +```python +from torch_geometric.datasets import Twitch +dataset = Twitch(root='/tmp/Twitch', name='DE') +``` + +### Facebook +**Usage**: Social network analysis, node classification +**Description**: Facebook page-page networks + +```python +from torch_geometric.datasets import FacebookPagePage +dataset = FacebookPagePage(root='/tmp/Facebook') +``` + +### GitHub +**Usage**: Social network analysis +**Description**: GitHub developer networks + +```python +from torch_geometric.datasets import GitHub +dataset = GitHub(root='/tmp/GitHub') +``` + +## Knowledge Graphs + +### Entities +**Usage**: Link prediction, knowledge graph embeddings +**Datasets**: AIFB, MUTAG, BGS, AM +**Description**: RDF knowledge graphs with typed relations + +```python +from torch_geometric.datasets import Entities +dataset = Entities(root='/tmp/AIFB', name='AIFB') +``` + +### WordNet18 +**Usage**: Link prediction on semantic networks +**Description**: Subset of WordNet with 18 relations +- 40,943 entities, 151,442 triplets + +```python +from torch_geometric.datasets import WordNet18 +dataset = WordNet18(root='/tmp/WordNet18') +``` + +### WordNet18RR +**Usage**: Link prediction (no inverse relations) +**Description**: Refined version without inverse relations + +```python +from torch_geometric.datasets import WordNet18RR +dataset = WordNet18RR(root='/tmp/WordNet18RR') +``` + +### FB15k-237 +**Usage**: Link prediction on Freebase +**Description**: Subset of Freebase with 237 relations +- 14,541 entities, 310,116 triplets + +```python +from torch_geometric.datasets import FB15k_237 +dataset = FB15k_237(root='/tmp/FB15k') +``` + +## Heterogeneous Graphs + +### OGB_MAG +**Usage**: Heterogeneous graph learning, node classification +**Description**: Microsoft Academic Graph with multiple node/edge types +- Node types: paper, author, institution, field of study +- 1M+ nodes, 21M+ edges + +```python +from torch_geometric.datasets import OGB_MAG +dataset = OGB_MAG(root='/tmp/OGB_MAG') +``` + +### MovieLens +**Usage**: Recommendation systems, link prediction +**Versions**: 100K, 1M, 10M, 20M +**Description**: User-movie rating networks +- Node types: user, movie +- Edge types: rates + +```python +from torch_geometric.datasets import MovieLens +dataset = MovieLens(root='/tmp/MovieLens', model_name='100k') +``` + +### IMDB +**Usage**: Heterogeneous graph learning +**Description**: IMDB movie network +- Node types: movie, actor, director + +```python +from 
torch_geometric.datasets import IMDB +dataset = IMDB(root='/tmp/IMDB') +``` + +### DBLP +**Usage**: Heterogeneous graph learning, node classification +**Description**: DBLP bibliography network +- Node types: author, paper, term, conference + +```python +from torch_geometric.datasets import DBLP +dataset = DBLP(root='/tmp/DBLP') +``` + +### LastFM +**Usage**: Heterogeneous recommendation +**Description**: LastFM music network +- Node types: user, artist, tag + +```python +from torch_geometric.datasets import LastFM +dataset = LastFM(root='/tmp/LastFM') +``` + +## Temporal Graphs + +### BitcoinOTC +**Usage**: Temporal link prediction, trust networks +**Description**: Bitcoin OTC trust network over time + +```python +from torch_geometric.datasets import BitcoinOTC +dataset = BitcoinOTC(root='/tmp/BitcoinOTC') +``` + +### ICEWS18 +**Usage**: Temporal knowledge graph completion +**Description**: Integrated Crisis Early Warning System events + +```python +from torch_geometric.datasets import ICEWS18 +dataset = ICEWS18(root='/tmp/ICEWS18') +``` + +### GDELT +**Usage**: Temporal event forecasting +**Description**: Global Database of Events, Language, and Tone + +```python +from torch_geometric.datasets import GDELT +dataset = GDELT(root='/tmp/GDELT') +``` + +### JODIEDataset +**Usage**: Dynamic graph learning +**Datasets**: Reddit, Wikipedia, MOOC, LastFM +**Description**: Temporal interaction networks + +```python +from torch_geometric.datasets import JODIEDataset +dataset = JODIEDataset(root='/tmp/JODIE', name='Reddit') +``` + +## 3D Meshes and Point Clouds + +### ShapeNet +**Usage**: 3D shape classification and segmentation +**Description**: Large-scale 3D CAD model dataset +- 16,881 models across 16 categories +- Part-level segmentation labels + +```python +from torch_geometric.datasets import ShapeNet +dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane']) +``` + +### ModelNet +**Usage**: 3D shape classification +**Versions**: ModelNet10, ModelNet40 +**Description**: CAD models for 3D object classification +- ModelNet10: 4,899 models, 10 categories +- ModelNet40: 12,311 models, 40 categories + +```python +from torch_geometric.datasets import ModelNet +dataset = ModelNet(root='/tmp/ModelNet', name='10') +``` + +### FAUST +**Usage**: 3D shape matching, correspondence +**Description**: Human body scans for shape analysis +- 100 meshes of 10 people in 10 poses + +```python +from torch_geometric.datasets import FAUST +dataset = FAUST(root='/tmp/FAUST') +``` + +### CoMA +**Usage**: 3D mesh deformation +**Description**: Facial expression meshes +- 20,466 3D face scans with expressions + +```python +from torch_geometric.datasets import CoMA +dataset = CoMA(root='/tmp/CoMA') +``` + +### S3DIS +**Usage**: 3D semantic segmentation +**Description**: Stanford Large-Scale 3D Indoor Spaces +- 6 areas, 271 rooms, point cloud data + +```python +from torch_geometric.datasets import S3DIS +dataset = S3DIS(root='/tmp/S3DIS', test_area=6) +``` + +## Image and Vision Datasets + +### MNISTSuperpixels +**Usage**: Graph-based image classification +**Description**: MNIST images as superpixel graphs +- 70,000 graphs (60k train, 10k test) + +```python +from torch_geometric.datasets import MNISTSuperpixels +dataset = MNISTSuperpixels(root='/tmp/MNIST') +``` + +### Flickr +**Usage**: Image description, node classification +**Description**: Flickr image network +- 89,250 nodes, 899,756 edges + +```python +from torch_geometric.datasets import Flickr +dataset = Flickr(root='/tmp/Flickr') +``` + +### PPI 
+**Usage**: Protein-protein interaction prediction +**Description**: Multi-graph protein interaction networks +- 24 graphs, 2,373 nodes total + +```python +from torch_geometric.datasets import PPI +dataset = PPI(root='/tmp/PPI', split='train') +``` + +## Small Classic Graphs + +### KarateClub +**Usage**: Community detection, visualization +**Description**: Zachary's karate club network +- 34 nodes, 78 edges, 2 communities + +```python +from torch_geometric.datasets import KarateClub +dataset = KarateClub() +``` + +## Open Graph Benchmark (OGB) + +PyG integrates seamlessly with OGB datasets: + +### Node Property Prediction +- **ogbn-products**: Amazon product network (2.4M nodes) +- **ogbn-proteins**: Protein association network (132K nodes) +- **ogbn-arxiv**: Citation network (169K nodes) +- **ogbn-papers100M**: Large citation network (111M nodes) +- **ogbn-mag**: Heterogeneous academic graph + +### Link Property Prediction +- **ogbl-ppa**: Protein association networks +- **ogbl-collab**: Collaboration networks +- **ogbl-ddi**: Drug-drug interaction network +- **ogbl-citation2**: Citation network +- **ogbl-wikikg2**: Wikidata knowledge graph + +### Graph Property Prediction +- **ogbg-molhiv**: Molecular HIV activity prediction +- **ogbg-molpcba**: Molecular bioassays (multi-task) +- **ogbg-ppa**: Protein function prediction +- **ogbg-code2**: Code abstract syntax trees + +```python +from torch_geometric.datasets import OGB_MAG, OGB_PPA +# or +from ogb.nodeproppred import PygNodePropPredDataset +dataset = PygNodePropPredDataset(name='ogbn-arxiv') +``` + +## Synthetic Datasets + +### FakeDataset +**Usage**: Testing, debugging +**Description**: Generates random graph data + +```python +from torch_geometric.datasets import FakeDataset +dataset = FakeDataset(num_graphs=100, avg_num_nodes=50) +``` + +### StochasticBlockModelDataset +**Usage**: Community detection benchmarks +**Description**: Graphs generated from stochastic block models + +```python +from torch_geometric.datasets import StochasticBlockModelDataset +dataset = StochasticBlockModelDataset(root='/tmp/SBM', num_graphs=1000) +``` + +### ExplainerDataset +**Usage**: Testing explainability methods +**Description**: Synthetic graphs with known explanation ground truth + +```python +from torch_geometric.datasets import ExplainerDataset +dataset = ExplainerDataset(num_graphs=1000) +``` + +## Materials Science + +### QM8 +**Usage**: Molecular property prediction +**Description**: Electronic properties of small molecules + +```python +from torch_geometric.datasets import QM8 +dataset = QM8(root='/tmp/QM8') +``` + +## Biological Networks + +### PPI (Protein-Protein Interaction) +Already listed above under Image and Vision Datasets + +### STRING +**Usage**: Protein interaction networks +**Description**: Known and predicted protein-protein interactions + +```python +# Available through external sources or custom loading +``` + +## Usage Tips + +1. **Start with small datasets**: Use Cora, KarateClub, or ENZYMES for prototyping +2. **Citation networks**: Planetoid datasets are perfect for node classification +3. **Graph classification**: TUDataset provides diverse benchmarks +4. **Molecular**: QM9, ZINC, MoleculeNet for chemistry applications +5. **Large-scale**: Use Reddit, OGB datasets with NeighborLoader +6. **Heterogeneous**: OGB_MAG, MovieLens, IMDB for multi-type graphs +7. **Temporal**: JODIE, ICEWS for dynamic graph learning +8. 
**3D**: ShapeNet, ModelNet, S3DIS for geometric learning
+
+## Common Patterns
+
+### Loading with Transforms
+```python
+from torch_geometric.datasets import Planetoid
+from torch_geometric.transforms import NormalizeFeatures
+
+dataset = Planetoid(root='/tmp/Cora', name='Cora',
+                    transform=NormalizeFeatures())
+```
+
+### Train/Val/Test Splits
+```python
+# For datasets with pre-defined splits, use the boolean masks to index
+# node-level tensors (or data.subgraph(mask) for an induced subgraph)
+data = dataset[0]
+train_labels = data.y[data.train_mask]
+val_labels = data.y[data.val_mask]
+test_labels = data.y[data.test_mask]
+
+# For graph classification, slice the dataset itself
+from torch_geometric.loader import DataLoader
+train_dataset = dataset[:int(len(dataset) * 0.8)]
+test_dataset = dataset[int(len(dataset) * 0.8):]
+train_loader = DataLoader(train_dataset, batch_size=32)
+```
+
+### Custom Data Loading
+```python
+from torch_geometric.data import Data, Dataset
+
+class MyCustomDataset(Dataset):
+    def __init__(self, root, transform=None):
+        super().__init__(root, transform)
+        # Build or load self.data_list here
+
+    def len(self):
+        return len(self.data_list)
+
+    def get(self, idx):
+        # Load and return a single Data object
+        return self.data_list[idx]
+```
diff --git a/scientific-packages/torch_geometric/references/layers_reference.md b/scientific-packages/torch_geometric/references/layers_reference.md
new file mode 100644
index 0000000..e465894
--- /dev/null
+++ b/scientific-packages/torch_geometric/references/layers_reference.md
@@ -0,0 +1,485 @@
+# PyTorch Geometric Neural Network Layers Reference
+
+This document provides a comprehensive reference of all neural network layers available in `torch_geometric.nn`.
+
+## Layer Capability Flags
+
+When selecting layers, consider these capability flags:
+
+- **SparseTensor**: Supports `torch_sparse.SparseTensor` format for efficient sparse operations
+- **edge_weight**: Handles one-dimensional edge weight data
+- **edge_attr**: Processes multi-dimensional edge feature information
+- **Bipartite**: Works with bipartite graphs (different source/target node dimensions)
+- **Static**: Operates on static graphs with batched node features
+- **Lazy**: Enables initialization without specifying input channel dimensions
+
+## Convolutional Layers
+
+### Standard Graph Convolutions
+
+**GCNConv** - Graph Convolutional Network layer
+- Implements spectral graph convolution with symmetric normalization
+- Supports: SparseTensor, edge_weight, Bipartite, Lazy
+- Use for: Citation networks, social networks, general graph learning
+- Example: `GCNConv(in_channels, out_channels, improved=False, cached=True)`
+
+**SAGEConv** - GraphSAGE layer
+- Inductive learning via neighborhood sampling and aggregation
+- Supports: SparseTensor, Bipartite, Lazy
+- Use for: Large graphs, inductive learning, heterogeneous features
+- Example: `SAGEConv(in_channels, out_channels, aggr='mean')`
+
+**GATConv** - Graph Attention Network layer
+- Multi-head attention mechanism for adaptive neighbor weighting
+- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy
+- Use for: Tasks requiring variable neighbor importance
+- Example: `GATConv(in_channels, out_channels, heads=8, dropout=0.6)`
+
+**GraphConv** - Simple graph convolution (Morris et al.)
+- Basic message passing with optional edge weights +- Supports: SparseTensor, edge_weight, Bipartite, Lazy +- Use for: Baseline models, simple graph structures +- Example: `GraphConv(in_channels, out_channels, aggr='add')` + +**GINConv** - Graph Isomorphism Network layer +- Maximally powerful GNN for graph isomorphism testing +- Supports: Bipartite +- Use for: Graph classification, molecular property prediction +- Example: `GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))` + +**TransformerConv** - Graph Transformer layer +- Combines graph structure with transformer attention +- Supports: SparseTensor, Bipartite, Lazy +- Use for: Long-range dependencies, complex graphs +- Example: `TransformerConv(in_channels, out_channels, heads=8, beta=True)` + +**ChebConv** - Chebyshev spectral graph convolution +- Uses Chebyshev polynomials for efficient spectral filtering +- Supports: SparseTensor, edge_weight, Bipartite, Lazy +- Use for: Spectral graph learning, efficient convolutions +- Example: `ChebConv(in_channels, out_channels, K=3)` + +**SGConv** - Simplified Graph Convolution +- Pre-computes fixed number of propagation steps +- Supports: SparseTensor, edge_weight, Bipartite, Lazy +- Use for: Fast training, shallow models +- Example: `SGConv(in_channels, out_channels, K=2)` + +**APPNP** - Approximate Personalized Propagation of Neural Predictions +- Separates feature transformation from propagation +- Supports: SparseTensor, edge_weight, Lazy +- Use for: Deep propagation without oversmoothing +- Example: `APPNP(K=10, alpha=0.1)` + +**ARMAConv** - ARMA graph convolution +- Uses ARMA filters for graph filtering +- Supports: SparseTensor, edge_weight, Bipartite, Lazy +- Use for: Advanced spectral methods +- Example: `ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)` + +**GATv2Conv** - Improved Graph Attention Network +- Fixes static attention computation issue in GAT +- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy +- Use for: Better attention learning than original GAT +- Example: `GATv2Conv(in_channels, out_channels, heads=8)` + +**SuperGATConv** - Self-supervised Graph Attention +- Adds self-supervised attention mechanism +- Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy +- Use for: Self-supervised learning, limited labels +- Example: `SuperGATConv(in_channels, out_channels, heads=8)` + +**GMMConv** - Gaussian Mixture Model Convolution +- Uses Gaussian kernels in pseudo-coordinate space +- Supports: Bipartite +- Use for: Point clouds, spatial data +- Example: `GMMConv(in_channels, out_channels, dim=3, kernel_size=5)` + +**SplineConv** - Spline-based convolution +- B-spline basis functions for spatial filtering +- Supports: Bipartite +- Use for: Irregular grids, continuous spaces +- Example: `SplineConv(in_channels, out_channels, dim=2, kernel_size=5)` + +**NNConv** - Neural Network Convolution +- Edge features processed by neural networks +- Supports: edge_attr, Bipartite +- Use for: Rich edge features, molecular graphs +- Example: `NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')` + +**CGConv** - Crystal Graph Convolution +- Designed for crystalline materials +- Supports: Bipartite +- Use for: Materials science, crystal structures +- Example: `CGConv(in_channels, dim=3, batch_norm=True)` + +**EdgeConv** - Edge Convolution (Dynamic Graph CNN) +- Dynamically computes edges based on feature space +- Supports: Static +- Use for: Point clouds, dynamic graphs +- Example: `EdgeConv(nn=edge_nn, aggr='max')` + +**PointNetConv** - 
PointNet++ convolution +- Local and global feature learning for point clouds +- Use for: 3D point cloud processing +- Example: `PointNetConv(local_nn, global_nn)` + +**ResGatedGraphConv** - Residual Gated Graph Convolution +- Gating mechanism with residual connections +- Supports: edge_attr, Bipartite, Lazy +- Use for: Deep GNNs, complex features +- Example: `ResGatedGraphConv(in_channels, out_channels)` + +**GENConv** - Generalized Graph Convolution +- Generalizes multiple GNN variants +- Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy +- Use for: Flexible architecture exploration +- Example: `GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)` + +**FiLMConv** - Feature-wise Linear Modulation +- Conditions on global features +- Supports: Bipartite, Lazy +- Use for: Conditional generation, multi-task learning +- Example: `FiLMConv(in_channels, out_channels, num_relations=5)` + +**PANConv** - Path Attention Network +- Attention over multi-hop paths +- Supports: SparseTensor, Lazy +- Use for: Complex connectivity patterns +- Example: `PANConv(in_channels, out_channels, filter_size=3)` + +**ClusterGCNConv** - Cluster-GCN convolution +- Efficient training via graph clustering +- Supports: edge_attr, Lazy +- Use for: Very large graphs +- Example: `ClusterGCNConv(in_channels, out_channels)` + +**MFConv** - Multi-scale Feature Convolution +- Aggregates features at multiple scales +- Supports: SparseTensor, Lazy +- Use for: Multi-scale patterns +- Example: `MFConv(in_channels, out_channels)` + +**RGCNConv** - Relational Graph Convolution +- Handles multiple edge types +- Supports: SparseTensor, edge_weight, Lazy +- Use for: Knowledge graphs, heterogeneous graphs +- Example: `RGCNConv(in_channels, out_channels, num_relations=10)` + +**FAConv** - Frequency Adaptive Convolution +- Adaptive filtering in spectral domain +- Supports: SparseTensor, Lazy +- Use for: Spectral graph learning +- Example: `FAConv(in_channels, eps=0.1, dropout=0.5)` + +### Molecular and 3D Convolutions + +**SchNet** - Continuous-filter convolutional layer +- Designed for molecular dynamics +- Use for: Molecular property prediction, 3D molecules +- Example: `SchNet(hidden_channels=128, num_filters=64, num_interactions=6)` + +**DimeNet** - Directional Message Passing +- Uses directional information and angles +- Use for: 3D molecular structures, chemical properties +- Example: `DimeNet(hidden_channels=128, out_channels=1, num_blocks=6)` + +**PointTransformerConv** - Point cloud transformer +- Transformer for 3D point clouds +- Use for: 3D vision, point cloud segmentation +- Example: `PointTransformerConv(in_channels, out_channels)` + +### Hypergraph Convolutions + +**HypergraphConv** - Hypergraph convolution +- Operates on hyperedges (edges connecting multiple nodes) +- Supports: Lazy +- Use for: Multi-way relationships, chemical reactions +- Example: `HypergraphConv(in_channels, out_channels)` + +**HGTConv** - Heterogeneous Graph Transformer +- Transformer for heterogeneous graphs with multiple types +- Supports: Lazy +- Use for: Heterogeneous networks, knowledge graphs +- Example: `HGTConv(in_channels, out_channels, metadata, heads=8)` + +## Aggregation Operators + +**Aggr** - Base aggregation class +- Flexible aggregation across nodes + +**SumAggregation** - Sum aggregation +- Example: `SumAggregation()` + +**MeanAggregation** - Mean aggregation +- Example: `MeanAggregation()` + +**MaxAggregation** - Max aggregation +- Example: `MaxAggregation()` + +**SoftmaxAggregation** - Softmax-weighted 
aggregation +- Learnable attention weights +- Example: `SoftmaxAggregation(learn=True)` + +**PowerMeanAggregation** - Power mean aggregation +- Learnable power parameter +- Example: `PowerMeanAggregation(learn=True)` + +**LSTMAggregation** - LSTM-based aggregation +- Sequential processing of neighbors +- Example: `LSTMAggregation(in_channels, out_channels)` + +**SetTransformerAggregation** - Set Transformer aggregation +- Transformer for permutation-invariant aggregation +- Example: `SetTransformerAggregation(in_channels, out_channels)` + +**MultiAggregation** - Multiple aggregations +- Combines multiple aggregation methods +- Example: `MultiAggregation(['mean', 'max', 'std'])` + +## Pooling Layers + +### Global Pooling + +**global_mean_pool** - Global mean pooling +- Averages node features per graph +- Example: `global_mean_pool(x, batch)` + +**global_max_pool** - Global max pooling +- Max over node features per graph +- Example: `global_max_pool(x, batch)` + +**global_add_pool** - Global sum pooling +- Sums node features per graph +- Example: `global_add_pool(x, batch)` + +**global_sort_pool** - Global sort pooling +- Sorts and concatenates top-k nodes +- Example: `global_sort_pool(x, batch, k=30)` + +**GlobalAttention** - Global attention pooling +- Learnable attention weights for aggregation +- Example: `GlobalAttention(gate_nn)` + +**Set2Set** - Set2Set pooling +- LSTM-based attention mechanism +- Example: `Set2Set(in_channels, processing_steps=3)` + +### Hierarchical Pooling + +**TopKPooling** - Top-k pooling +- Keeps top-k nodes based on projection scores +- Example: `TopKPooling(in_channels, ratio=0.5)` + +**SAGPooling** - Self-Attention Graph Pooling +- Uses self-attention for node selection +- Example: `SAGPooling(in_channels, ratio=0.5)` + +**ASAPooling** - Adaptive Structure Aware Pooling +- Structure-aware node selection +- Example: `ASAPooling(in_channels, ratio=0.5)` + +**PANPooling** - Path Attention Pooling +- Attention over paths for pooling +- Example: `PANPooling(in_channels, ratio=0.5)` + +**EdgePooling** - Edge contraction pooling +- Pools by contracting edges +- Example: `EdgePooling(in_channels)` + +**MemPooling** - Memory-based pooling +- Learnable cluster assignments +- Example: `MemPooling(in_channels, out_channels, heads=4, num_clusters=10)` + +**avg_pool** / **max_pool** - Average/Max pool with clustering +- Pools nodes within clusters +- Example: `avg_pool(cluster, data)` + +## Normalization Layers + +**BatchNorm** - Batch normalization +- Normalizes features across batch +- Example: `BatchNorm(in_channels)` + +**LayerNorm** - Layer normalization +- Normalizes features per sample +- Example: `LayerNorm(in_channels)` + +**InstanceNorm** - Instance normalization +- Normalizes per sample and graph +- Example: `InstanceNorm(in_channels)` + +**GraphNorm** - Graph normalization +- Graph-specific normalization +- Example: `GraphNorm(in_channels)` + +**PairNorm** - Pair normalization +- Prevents oversmoothing in deep GNNs +- Example: `PairNorm(scale_individually=False)` + +**MessageNorm** - Message normalization +- Normalizes messages during passing +- Example: `MessageNorm(learn_scale=True)` + +**DiffGroupNorm** - Differentiable Group Normalization +- Learnable grouping for normalization +- Example: `DiffGroupNorm(in_channels, groups=10)` + +## Model Architectures + +### Pre-Built Models + +**GCN** - Complete Graph Convolutional Network +- Multi-layer GCN with dropout +- Example: `GCN(in_channels, hidden_channels, num_layers, out_channels)` + +**GraphSAGE** - 
Complete GraphSAGE model +- Multi-layer SAGE with dropout +- Example: `GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)` + +**GIN** - Complete Graph Isomorphism Network +- Multi-layer GIN for graph classification +- Example: `GIN(in_channels, hidden_channels, num_layers, out_channels)` + +**GAT** - Complete Graph Attention Network +- Multi-layer GAT with attention +- Example: `GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)` + +**PNA** - Principal Neighbourhood Aggregation +- Combines multiple aggregators and scalers +- Example: `PNA(in_channels, hidden_channels, num_layers, out_channels)` + +**EdgeCNN** - Edge Convolution CNN +- Dynamic graph CNN for point clouds +- Example: `EdgeCNN(out_channels, num_layers=3, k=20)` + +### Auto-Encoders + +**GAE** - Graph Auto-Encoder +- Encodes graphs into latent space +- Example: `GAE(encoder)` + +**VGAE** - Variational Graph Auto-Encoder +- Probabilistic graph encoding +- Example: `VGAE(encoder)` + +**ARGA** - Adversarially Regularized Graph Auto-Encoder +- GAE with adversarial regularization +- Example: `ARGA(encoder, discriminator)` + +**ARGVA** - Adversarially Regularized Variational Graph Auto-Encoder +- VGAE with adversarial regularization +- Example: `ARGVA(encoder, discriminator)` + +### Knowledge Graph Embeddings + +**TransE** - Translating embeddings +- Learns entity and relation embeddings +- Example: `TransE(num_nodes, num_relations, hidden_channels)` + +**RotatE** - Rotational embeddings +- Embeddings in complex space +- Example: `RotatE(num_nodes, num_relations, hidden_channels)` + +**ComplEx** - Complex embeddings +- Complex-valued embeddings +- Example: `ComplEx(num_nodes, num_relations, hidden_channels)` + +**DistMult** - Bilinear diagonal model +- Simplified bilinear model +- Example: `DistMult(num_nodes, num_relations, hidden_channels)` + +## Utility Layers + +**Sequential** - Sequential container +- Chains multiple layers +- Example: `Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])` + +**JumpingKnowledge** - Jumping knowledge connections +- Combines representations from all layers +- Modes: 'cat', 'max', 'lstm' +- Example: `JumpingKnowledge(mode='cat')` + +**DeepGCNLayer** - Deep GCN layer wrapper +- Enables very deep GNNs with skip connections +- Example: `DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)` + +**MLP** - Multi-layer perceptron +- Standard feedforward network +- Example: `MLP([in_channels, 64, 64, out_channels], dropout=0.5)` + +**Linear** - Lazy linear layer +- Linear transformation with lazy initialization +- Example: `Linear(in_channels, out_channels, bias=True)` + +## Dense Layers + +For dense (non-sparse) graph representations: + +**DenseGCNConv** - Dense GCN layer +**DenseSAGEConv** - Dense SAGE layer +**DenseGINConv** - Dense GIN layer +**DenseGraphConv** - Dense graph convolution + +These are useful when working with small, fully-connected, or densely represented graphs. + +## Usage Tips + +1. **Start simple**: Begin with GCNConv or GATConv for most tasks +2. **Consider data type**: Use molecular layers (SchNet, DimeNet) for 3D structures +3. **Check capabilities**: Match layer capabilities to your data (edge features, bipartite, etc.) +4. **Deep networks**: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs +5. **Large graphs**: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling +6. **Heterogeneous**: Use RGCNConv, HGTConv, or to_hetero() conversion +7. 
**Lazy initialization**: Use lazy layers when input dimensions vary or are unknown + +## Common Patterns + +### Basic GNN +```python +from torch_geometric.nn import GCNConv, global_mean_pool + +class GNN(torch.nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = GCNConv(in_channels, hidden_channels) + self.conv2 = GCNConv(hidden_channels, out_channels) + + def forward(self, x, edge_index, batch): + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index) + return global_mean_pool(x, batch) +``` + +### Deep GNN with Normalization +```python +class DeepGNN(torch.nn.Module): + def __init__(self, in_channels, hidden_channels, num_layers, out_channels): + super().__init__() + self.convs = torch.nn.ModuleList() + self.norms = torch.nn.ModuleList() + + self.convs.append(GCNConv(in_channels, hidden_channels)) + self.norms.append(LayerNorm(hidden_channels)) + + for _ in range(num_layers - 2): + self.convs.append(GCNConv(hidden_channels, hidden_channels)) + self.norms.append(LayerNorm(hidden_channels)) + + self.convs.append(GCNConv(hidden_channels, out_channels)) + self.jk = JumpingKnowledge(mode='cat') + + def forward(self, x, edge_index, batch): + xs = [] + for conv, norm in zip(self.convs[:-1], self.norms): + x = conv(x, edge_index) + x = norm(x) + x = F.relu(x) + xs.append(x) + + x = self.convs[-1](x, edge_index) + xs.append(x) + + x = self.jk(xs) + return global_mean_pool(x, batch) +``` diff --git a/scientific-packages/torch_geometric/references/transforms_reference.md b/scientific-packages/torch_geometric/references/transforms_reference.md new file mode 100644 index 0000000..5149e12 --- /dev/null +++ b/scientific-packages/torch_geometric/references/transforms_reference.md @@ -0,0 +1,679 @@ +# PyTorch Geometric Transforms Reference + +This document provides a comprehensive reference of all transforms available in `torch_geometric.transforms`. + +## Overview + +Transforms modify `Data` or `HeteroData` objects before or during training. 
Apply them via: + +```python +# During dataset loading +dataset = MyDataset(root='/tmp', transform=MyTransform()) + +# Apply to individual data +transform = MyTransform() +data = transform(data) + +# Compose multiple transforms +from torch_geometric.transforms import Compose +transform = Compose([Transform1(), Transform2(), Transform3()]) +``` + +## General Transforms + +### NormalizeFeatures +**Purpose**: Row-normalizes node features to sum to 1 +**Use case**: Feature scaling, probability-like features +```python +from torch_geometric.transforms import NormalizeFeatures +transform = NormalizeFeatures() +``` + +### ToDevice +**Purpose**: Transfers data to specified device (CPU/GPU) +**Use case**: GPU training, device management +```python +from torch_geometric.transforms import ToDevice +transform = ToDevice('cuda') +``` + +### RandomNodeSplit +**Purpose**: Creates train/val/test node masks +**Use case**: Node classification splits +**Parameters**: `split='train_rest'`, `num_splits`, `num_val`, `num_test` +```python +from torch_geometric.transforms import RandomNodeSplit +transform = RandomNodeSplit(num_val=0.1, num_test=0.2) +``` + +### RandomLinkSplit +**Purpose**: Creates train/val/test edge splits +**Use case**: Link prediction +**Parameters**: `num_val`, `num_test`, `is_undirected`, `split_labels` +```python +from torch_geometric.transforms import RandomLinkSplit +transform = RandomLinkSplit(num_val=0.1, num_test=0.2) +``` + +### IndexToMask +**Purpose**: Converts indices to boolean masks +**Use case**: Data preprocessing +```python +from torch_geometric.transforms import IndexToMask +transform = IndexToMask() +``` + +### MaskToIndex +**Purpose**: Converts boolean masks to indices +**Use case**: Data preprocessing +```python +from torch_geometric.transforms import MaskToIndex +transform = MaskToIndex() +``` + +### FixedPoints +**Purpose**: Samples a fixed number of points +**Use case**: Point cloud subsampling +**Parameters**: `num`, `replace`, `allow_duplicates` +```python +from torch_geometric.transforms import FixedPoints +transform = FixedPoints(1024) +``` + +### ToDense +**Purpose**: Converts to dense adjacency matrices +**Use case**: Small graphs, dense operations +```python +from torch_geometric.transforms import ToDense +transform = ToDense(num_nodes=100) +``` + +### ToSparseTensor +**Purpose**: Converts edge_index to SparseTensor +**Use case**: Efficient sparse operations +**Parameters**: `remove_edge_index`, `fill_cache` +```python +from torch_geometric.transforms import ToSparseTensor +transform = ToSparseTensor() +``` + +## Graph Structure Transforms + +### ToUndirected +**Purpose**: Converts directed graph to undirected +**Use case**: Undirected graph algorithms +**Parameters**: `reduce='add'` (how to handle duplicate edges) +```python +from torch_geometric.transforms import ToUndirected +transform = ToUndirected() +``` + +### AddSelfLoops +**Purpose**: Adds self-loops to all nodes +**Use case**: GCN-style convolutions +**Parameters**: `fill_value` (edge attribute for self-loops) +```python +from torch_geometric.transforms import AddSelfLoops +transform = AddSelfLoops() +``` + +### RemoveSelfLoops +**Purpose**: Removes all self-loops +**Use case**: Cleaning graph structure +```python +from torch_geometric.transforms import RemoveSelfLoops +transform = RemoveSelfLoops() +``` + +### RemoveIsolatedNodes +**Purpose**: Removes nodes without edges +**Use case**: Graph cleaning +```python +from torch_geometric.transforms import RemoveIsolatedNodes +transform = 
RemoveIsolatedNodes() +``` + +### RemoveDuplicatedEdges +**Purpose**: Removes duplicate edges +**Use case**: Graph cleaning +```python +from torch_geometric.transforms import RemoveDuplicatedEdges +transform = RemoveDuplicatedEdges() +``` + +### LargestConnectedComponents +**Purpose**: Keeps only the largest connected component +**Use case**: Focus on main graph structure +**Parameters**: `num_components` (how many components to keep) +```python +from torch_geometric.transforms import LargestConnectedComponents +transform = LargestConnectedComponents(num_components=1) +``` + +### KNNGraph +**Purpose**: Creates edges based on k-nearest neighbors +**Use case**: Point clouds, spatial data +**Parameters**: `k`, `loop`, `force_undirected`, `flow` +```python +from torch_geometric.transforms import KNNGraph +transform = KNNGraph(k=6) +``` + +### RadiusGraph +**Purpose**: Creates edges within a radius +**Use case**: Point clouds, spatial data +**Parameters**: `r`, `loop`, `max_num_neighbors`, `flow` +```python +from torch_geometric.transforms import RadiusGraph +transform = RadiusGraph(r=0.1) +``` + +### Delaunay +**Purpose**: Computes Delaunay triangulation +**Use case**: 2D/3D spatial graphs +```python +from torch_geometric.transforms import Delaunay +transform = Delaunay() +``` + +### FaceToEdge +**Purpose**: Converts mesh faces to edges +**Use case**: Mesh processing +```python +from torch_geometric.transforms import FaceToEdge +transform = FaceToEdge() +``` + +### LineGraph +**Purpose**: Converts graph to its line graph +**Use case**: Edge-centric analysis +**Parameters**: `force_directed` +```python +from torch_geometric.transforms import LineGraph +transform = LineGraph() +``` + +### GDC +**Purpose**: Graph Diffusion Convolution preprocessing +**Use case**: Improved message passing +**Parameters**: `self_loop_weight`, `normalization_in`, `normalization_out`, `diffusion_kwargs` +```python +from torch_geometric.transforms import GDC +transform = GDC(self_loop_weight=1, normalization_in='sym', + diffusion_kwargs=dict(method='ppr', alpha=0.15)) +``` + +### SIGN +**Purpose**: Scalable Inception Graph Neural Networks preprocessing +**Use case**: Efficient multi-scale features +**Parameters**: `K` (number of hops) +```python +from torch_geometric.transforms import SIGN +transform = SIGN(K=3) +``` + +## Feature Transforms + +### OneHotDegree +**Purpose**: One-hot encodes node degree +**Use case**: Degree as feature +**Parameters**: `max_degree`, `cat` (concatenate with existing features) +```python +from torch_geometric.transforms import OneHotDegree +transform = OneHotDegree(max_degree=100) +``` + +### LocalDegreeProfile +**Purpose**: Appends local degree profile +**Use case**: Structural node features +```python +from torch_geometric.transforms import LocalDegreeProfile +transform = LocalDegreeProfile() +``` + +### Constant +**Purpose**: Adds constant features to nodes +**Use case**: Featureless graphs +**Parameters**: `value`, `cat` +```python +from torch_geometric.transforms import Constant +transform = Constant(value=1.0) +``` + +### TargetIndegree +**Purpose**: Saves in-degree as target +**Use case**: Degree prediction +**Parameters**: `norm`, `max_value` +```python +from torch_geometric.transforms import TargetIndegree +transform = TargetIndegree(norm=False) +``` + +### AddRandomWalkPE +**Purpose**: Adds random walk positional encoding +**Use case**: Positional information +**Parameters**: `walk_length`, `attr_name` +```python +from torch_geometric.transforms import AddRandomWalkPE 
+transform = AddRandomWalkPE(walk_length=20) +``` + +### AddLaplacianEigenvectorPE +**Purpose**: Adds Laplacian eigenvector positional encoding +**Use case**: Spectral positional information +**Parameters**: `k` (number of eigenvectors), `attr_name` +```python +from torch_geometric.transforms import AddLaplacianEigenvectorPE +transform = AddLaplacianEigenvectorPE(k=10) +``` + +### AddMetaPaths +**Purpose**: Adds meta-path induced edges +**Use case**: Heterogeneous graphs +**Parameters**: `metapaths`, `drop_orig_edges`, `drop_unconnected_nodes` +```python +from torch_geometric.transforms import AddMetaPaths +metapaths = [[('author', 'paper'), ('paper', 'author')]] # Co-authorship +transform = AddMetaPaths(metapaths) +``` + +### SVDFeatureReduction +**Purpose**: Reduces feature dimensionality via SVD +**Use case**: Dimensionality reduction +**Parameters**: `out_channels` +```python +from torch_geometric.transforms import SVDFeatureReduction +transform = SVDFeatureReduction(out_channels=64) +``` + +## Vision/Spatial Transforms + +### Center +**Purpose**: Centers node positions +**Use case**: Point cloud preprocessing +```python +from torch_geometric.transforms import Center +transform = Center() +``` + +### NormalizeScale +**Purpose**: Normalizes positions to unit sphere +**Use case**: Point cloud normalization +```python +from torch_geometric.transforms import NormalizeScale +transform = NormalizeScale() +``` + +### NormalizeRotation +**Purpose**: Rotates to principal components +**Use case**: Rotation-invariant learning +**Parameters**: `max_points` +```python +from torch_geometric.transforms import NormalizeRotation +transform = NormalizeRotation() +``` + +### Distance +**Purpose**: Saves Euclidean distance as edge attribute +**Use case**: Spatial graphs +**Parameters**: `norm`, `max_value`, `cat` +```python +from torch_geometric.transforms import Distance +transform = Distance(norm=False, cat=False) +``` + +### Cartesian +**Purpose**: Saves relative Cartesian coordinates as edge attributes +**Use case**: Spatial relationships +**Parameters**: `norm`, `max_value`, `cat` +```python +from torch_geometric.transforms import Cartesian +transform = Cartesian(norm=False) +``` + +### Polar +**Purpose**: Saves polar coordinates as edge attributes +**Use case**: 2D spatial graphs +**Parameters**: `norm`, `max_value`, `cat` +```python +from torch_geometric.transforms import Polar +transform = Polar(norm=False) +``` + +### Spherical +**Purpose**: Saves spherical coordinates as edge attributes +**Use case**: 3D spatial graphs +**Parameters**: `norm`, `max_value`, `cat` +```python +from torch_geometric.transforms import Spherical +transform = Spherical(norm=False) +``` + +### LocalCartesian +**Purpose**: Saves coordinates in local coordinate system +**Use case**: Local spatial features +**Parameters**: `norm`, `cat` +```python +from torch_geometric.transforms import LocalCartesian +transform = LocalCartesian() +``` + +### PointPairFeatures +**Purpose**: Computes point pair features +**Use case**: 3D registration, correspondence +**Parameters**: `cat` +```python +from torch_geometric.transforms import PointPairFeatures +transform = PointPairFeatures() +``` + +## Data Augmentation + +### RandomJitter +**Purpose**: Randomly jitters node positions +**Use case**: Point cloud augmentation +**Parameters**: `translate`, `scale` +```python +from torch_geometric.transforms import RandomJitter +transform = RandomJitter(0.01) +``` + +### RandomFlip +**Purpose**: Randomly flips positions along axis +**Use case**: 
Geometric augmentation +**Parameters**: `axis`, `p` (probability) +```python +from torch_geometric.transforms import RandomFlip +transform = RandomFlip(axis=0, p=0.5) +``` + +### RandomScale +**Purpose**: Randomly scales positions +**Use case**: Scale augmentation +**Parameters**: `scales` (min, max) +```python +from torch_geometric.transforms import RandomScale +transform = RandomScale((0.9, 1.1)) +``` + +### RandomRotate +**Purpose**: Randomly rotates positions +**Use case**: Rotation augmentation +**Parameters**: `degrees` (range), `axis` (rotation axis) +```python +from torch_geometric.transforms import RandomRotate +transform = RandomRotate(degrees=15, axis=2) +``` + +### RandomShear +**Purpose**: Randomly shears positions +**Use case**: Geometric augmentation +**Parameters**: `shear` (range) +```python +from torch_geometric.transforms import RandomShear +transform = RandomShear(0.1) +``` + +### RandomTranslate +**Purpose**: Randomly translates positions +**Use case**: Translation augmentation +**Parameters**: `translate` (range) +```python +from torch_geometric.transforms import RandomTranslate +transform = RandomTranslate(0.1) +``` + +### LinearTransformation +**Purpose**: Applies linear transformation matrix +**Use case**: Custom geometric transforms +**Parameters**: `matrix` +```python +from torch_geometric.transforms import LinearTransformation +import torch +matrix = torch.eye(3) +transform = LinearTransformation(matrix) +``` + +## Mesh Processing + +### SamplePoints +**Purpose**: Samples points uniformly from mesh +**Use case**: Mesh to point cloud conversion +**Parameters**: `num`, `remove_faces`, `include_normals` +```python +from torch_geometric.transforms import SamplePoints +transform = SamplePoints(num=1024) +``` + +### GenerateMeshNormals +**Purpose**: Generates face/vertex normals +**Use case**: Mesh processing +```python +from torch_geometric.transforms import GenerateMeshNormals +transform = GenerateMeshNormals() +``` + +### FaceToEdge +**Purpose**: Converts mesh faces to edges +**Use case**: Mesh to graph conversion +**Parameters**: `remove_faces` +```python +from torch_geometric.transforms import FaceToEdge +transform = FaceToEdge() +``` + +## Sampling and Splitting + +### GridSampling +**Purpose**: Clusters points in voxel grid +**Use case**: Point cloud downsampling +**Parameters**: `size` (voxel size), `start`, `end` +```python +from torch_geometric.transforms import GridSampling +transform = GridSampling(size=0.1) +``` + +### FixedPoints +**Purpose**: Samples fixed number of points +**Use case**: Uniform point cloud size +**Parameters**: `num`, `replace`, `allow_duplicates` +```python +from torch_geometric.transforms import FixedPoints +transform = FixedPoints(num=2048, replace=False) +``` + +### RandomScale +**Purpose**: Randomly scales by sampling from range +**Use case**: Scale augmentation (already listed above) + +### VirtualNode +**Purpose**: Adds a virtual node connected to all nodes +**Use case**: Global information propagation +```python +from torch_geometric.transforms import VirtualNode +transform = VirtualNode() +``` + +## Specialized Transforms + +### ToSLIC +**Purpose**: Converts images to superpixel graphs (SLIC algorithm) +**Use case**: Image as graph +**Parameters**: `num_segments`, `compactness`, `add_seg`, `add_img` +```python +from torch_geometric.transforms import ToSLIC +transform = ToSLIC(num_segments=75) +``` + +### GCNNorm +**Purpose**: Applies GCN-style normalization to edges +**Use case**: Preprocessing for GCN +**Parameters**: 
`add_self_loops` +```python +from torch_geometric.transforms import GCNNorm +transform = GCNNorm(add_self_loops=True) +``` + +### LaplacianLambdaMax +**Purpose**: Computes largest Laplacian eigenvalue +**Use case**: ChebConv preprocessing +**Parameters**: `normalization`, `is_undirected` +```python +from torch_geometric.transforms import LaplacianLambdaMax +transform = LaplacianLambdaMax(normalization='sym') +``` + +### NormalizeRotation +**Purpose**: Rotates mesh/point cloud to align with principal axes +**Use case**: Canonical orientation +**Parameters**: `max_points` +```python +from torch_geometric.transforms import NormalizeRotation +transform = NormalizeRotation() +``` + +## Compose and Apply + +### Compose +**Purpose**: Chains multiple transforms +**Use case**: Complex preprocessing pipelines +```python +from torch_geometric.transforms import Compose +transform = Compose([ + Center(), + NormalizeScale(), + KNNGraph(k=6), + Distance(norm=False), +]) +``` + +### BaseTransform +**Purpose**: Base class for custom transforms +**Use case**: Implementing custom transforms +```python +from torch_geometric.transforms import BaseTransform + +class MyTransform(BaseTransform): + def __init__(self, param): + self.param = param + + def __call__(self, data): + # Modify data + data.x = data.x * self.param + return data +``` + +## Common Transform Combinations + +### Node Classification Preprocessing +```python +transform = Compose([ + NormalizeFeatures(), + RandomNodeSplit(num_val=0.1, num_test=0.2), +]) +``` + +### Point Cloud Processing +```python +transform = Compose([ + Center(), + NormalizeScale(), + RandomRotate(degrees=15, axis=2), + RandomJitter(0.01), + KNNGraph(k=6), + Distance(norm=False), +]) +``` + +### Mesh to Graph +```python +transform = Compose([ + FaceToEdge(remove_faces=True), + GenerateMeshNormals(), + Distance(norm=True), +]) +``` + +### Graph Structure Enhancement +```python +transform = Compose([ + ToUndirected(), + AddSelfLoops(), + RemoveIsolatedNodes(), + GCNNorm(), +]) +``` + +### Heterogeneous Graph Preprocessing +```python +transform = Compose([ + AddMetaPaths(metapaths=[ + [('author', 'paper'), ('paper', 'author')], + [('author', 'paper'), ('paper', 'conference'), ('conference', 'paper'), ('paper', 'author')] + ]), + RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.2), +]) +``` + +### Link Prediction +```python +transform = Compose([ + NormalizeFeatures(), + RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True), +]) +``` + +## Usage Tips + +1. **Order matters**: Apply structural transforms before feature transforms +2. **Caching**: Some transforms (like GDC) are expensive—apply once +3. **Augmentation**: Use Random* transforms during training only +4. **Compose sparingly**: Too many transforms slow down data loading +5. **Custom transforms**: Inherit from `BaseTransform` for custom logic +6. **Pre-transforms**: Apply expensive transforms once during dataset processing: + ```python + dataset = MyDataset(root='/tmp', pre_transform=ExpensiveTransform()) + ``` +7. 
**Dynamic transforms**: Apply cheap transforms during training: + ```python + dataset = MyDataset(root='/tmp', transform=CheapTransform()) + ``` + +## Performance Considerations + +**Expensive transforms** (apply as pre_transform): +- GDC +- SIGN +- KNNGraph (for large point clouds) +- AddLaplacianEigenvectorPE +- SVDFeatureReduction + +**Cheap transforms** (apply as transform): +- NormalizeFeatures +- ToUndirected +- AddSelfLoops +- Random* augmentations +- ToDevice + +**Example**: +```python +from torch_geometric.datasets import Planetoid +from torch_geometric.transforms import Compose, GDC, NormalizeFeatures + +# Expensive preprocessing done once +pre_transform = GDC( + self_loop_weight=1, + normalization_in='sym', + diffusion_kwargs=dict(method='ppr', alpha=0.15) +) + +# Cheap transform applied each time +transform = NormalizeFeatures() + +dataset = Planetoid( + root='/tmp/Cora', + name='Cora', + pre_transform=pre_transform, + transform=transform +) +``` diff --git a/scientific-packages/torch_geometric/scripts/benchmark_model.py b/scientific-packages/torch_geometric/scripts/benchmark_model.py new file mode 100644 index 0000000..cddf565 --- /dev/null +++ b/scientific-packages/torch_geometric/scripts/benchmark_model.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +Benchmark GNN models on standard datasets. + +This script provides a simple way to benchmark different GNN architectures +on common datasets and compare their performance. + +Usage: + python benchmark_model.py --models gcn gat --dataset Cora + python benchmark_model.py --models gcn --dataset Cora --epochs 200 --runs 10 +""" + +import argparse +import torch +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv +from torch_geometric.datasets import Planetoid, TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import global_mean_pool +import time +import numpy as np + + +class GCN(torch.nn.Module): + def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5): + super().__init__() + self.conv1 = GCNConv(num_features, hidden_channels) + self.conv2 = GCNConv(hidden_channels, num_classes) + self.dropout = dropout + + def forward(self, x, edge_index, batch=None): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + if batch is not None: + x = global_mean_pool(x, batch) + return F.log_softmax(x, dim=1) + + +class GAT(torch.nn.Module): + def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6): + super().__init__() + self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout) + self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1, + concat=False, dropout=dropout) + self.dropout = dropout + + def forward(self, x, edge_index, batch=None): + x = F.dropout(x, p=self.dropout, training=self.training) + x = F.elu(self.conv1(x, edge_index)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + if batch is not None: + x = global_mean_pool(x, batch) + return F.log_softmax(x, dim=1) + + +class GraphSAGE(torch.nn.Module): + def __init__(self, num_features, hidden_channels, num_classes, dropout=0.5): + super().__init__() + self.conv1 = SAGEConv(num_features, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, num_classes) + self.dropout = dropout + + def forward(self, x, edge_index, batch=None): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, 
p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + if batch is not None: + x = global_mean_pool(x, batch) + return F.log_softmax(x, dim=1) + + +MODELS = { + 'gcn': GCN, + 'gat': GAT, + 'graphsage': GraphSAGE, +} + + +def train_node_classification(model, data, optimizer): + """Train for node classification.""" + model.train() + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) + loss.backward() + optimizer.step() + return loss.item() + + +@torch.no_grad() +def test_node_classification(model, data): + """Test for node classification.""" + model.eval() + out = model(data.x, data.edge_index) + pred = out.argmax(dim=1) + + accs = [] + for mask in [data.train_mask, data.val_mask, data.test_mask]: + correct = (pred[mask] == data.y[mask]).sum() + accs.append(float(correct) / int(mask.sum())) + + return accs + + +def train_graph_classification(model, loader, optimizer, device): + """Train for graph classification.""" + model.train() + total_loss = 0 + + for data in loader: + data = data.to(device) + optimizer.zero_grad() + out = model(data.x, data.edge_index, data.batch) + loss = F.nll_loss(out, data.y) + loss.backward() + optimizer.step() + total_loss += loss.item() * data.num_graphs + + return total_loss / len(loader.dataset) + + +@torch.no_grad() +def test_graph_classification(model, loader, device): + """Test for graph classification.""" + model.eval() + correct = 0 + + for data in loader: + data = data.to(device) + out = model(data.x, data.edge_index, data.batch) + pred = out.argmax(dim=1) + correct += (pred == data.y).sum().item() + + return correct / len(loader.dataset) + + +def benchmark_node_classification(model_name, dataset_name, epochs, lr, weight_decay, device): + """Benchmark a model on node classification.""" + # Load dataset + dataset = Planetoid(root=f'/tmp/{dataset_name}', name=dataset_name) + data = dataset[0].to(device) + + # Create model + model_class = MODELS[model_name] + model = model_class( + num_features=dataset.num_features, + hidden_channels=64, + num_classes=dataset.num_classes + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + + # Training + start_time = time.time() + best_val_acc = 0 + best_test_acc = 0 + + for epoch in range(1, epochs + 1): + loss = train_node_classification(model, data, optimizer) + train_acc, val_acc, test_acc = test_node_classification(model, data) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + + train_time = time.time() - start_time + + return { + 'train_acc': train_acc, + 'val_acc': best_val_acc, + 'test_acc': best_test_acc, + 'train_time': train_time, + } + + +def benchmark_graph_classification(model_name, dataset_name, epochs, lr, device): + """Benchmark a model on graph classification.""" + # Load dataset + dataset = TUDataset(root=f'/tmp/{dataset_name}', name=dataset_name) + + # Split dataset + dataset = dataset.shuffle() + train_dataset = dataset[:int(len(dataset) * 0.8)] + test_dataset = dataset[int(len(dataset) * 0.8):] + + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=32) + + # Create model + model_class = MODELS[model_name] + model = model_class( + num_features=dataset.num_features, + hidden_channels=64, + num_classes=dataset.num_classes + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + # Training + start_time = time.time() + + for epoch in 
range(1, epochs + 1): + loss = train_graph_classification(model, train_loader, optimizer, device) + + # Final evaluation + train_acc = test_graph_classification(model, train_loader, device) + test_acc = test_graph_classification(model, test_loader, device) + train_time = time.time() - start_time + + return { + 'train_acc': train_acc, + 'test_acc': test_acc, + 'train_time': train_time, + } + + +def run_benchmark(args): + """Run benchmark experiments.""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # Determine task type + if args.dataset in ['Cora', 'CiteSeer', 'PubMed']: + task = 'node_classification' + else: + task = 'graph_classification' + + print(f"\\nDataset: {args.dataset}") + print(f"Task: {task}") + print(f"Models: {', '.join(args.models)}") + print(f"Epochs: {args.epochs}") + print(f"Runs: {args.runs}") + print("=" * 60) + + results = {model: [] for model in args.models} + + # Run experiments + for run in range(args.runs): + print(f"\\nRun {run + 1}/{args.runs}") + print("-" * 60) + + for model_name in args.models: + if model_name not in MODELS: + print(f"Unknown model: {model_name}") + continue + + print(f" Training {model_name.upper()}...", end=" ") + + try: + if task == 'node_classification': + result = benchmark_node_classification( + model_name, args.dataset, args.epochs, + args.lr, args.weight_decay, device + ) + print(f"Test Acc: {result['test_acc']:.4f}, " + f"Time: {result['train_time']:.2f}s") + else: + result = benchmark_graph_classification( + model_name, args.dataset, args.epochs, args.lr, device + ) + print(f"Test Acc: {result['test_acc']:.4f}, " + f"Time: {result['train_time']:.2f}s") + + results[model_name].append(result) + except Exception as e: + print(f"Error: {e}") + + # Print summary + print("\\n" + "=" * 60) + print("BENCHMARK RESULTS") + print("=" * 60) + + for model_name in args.models: + if not results[model_name]: + continue + + test_accs = [r['test_acc'] for r in results[model_name]] + times = [r['train_time'] for r in results[model_name]] + + print(f"\\n{model_name.upper()}") + print(f" Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}") + print(f" Training Time: {np.mean(times):.2f} ± {np.std(times):.2f}s") + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark GNN models") + parser.add_argument('--models', nargs='+', default=['gcn'], + help='Model types to benchmark (gcn, gat, graphsage)') + parser.add_argument('--dataset', type=str, default='Cora', + help='Dataset name (Cora, CiteSeer, PubMed, ENZYMES, PROTEINS)') + parser.add_argument('--epochs', type=int, default=200, + help='Number of training epochs') + parser.add_argument('--runs', type=int, default=5, + help='Number of runs to average over') + parser.add_argument('--lr', type=float, default=0.01, + help='Learning rate') + parser.add_argument('--weight-decay', type=float, default=5e-4, + help='Weight decay for node classification') + + args = parser.parse_args() + run_benchmark(args) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/torch_geometric/scripts/create_gnn_template.py b/scientific-packages/torch_geometric/scripts/create_gnn_template.py new file mode 100644 index 0000000..3882b4d --- /dev/null +++ b/scientific-packages/torch_geometric/scripts/create_gnn_template.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python3 +""" +Generate boilerplate code for common GNN architectures in PyTorch Geometric. 
+ +This script creates ready-to-use GNN model templates with training loops, +evaluation metrics, and proper data handling. + +Usage: + python create_gnn_template.py --model gcn --task node_classification --output my_model.py + python create_gnn_template.py --model gat --task graph_classification --output graph_classifier.py +""" + +import argparse +from pathlib import Path + + +TEMPLATES = { + 'node_classification': { + 'gcn': '''import torch +import torch.nn.functional as F +from torch_geometric.nn import GCNConv +from torch_geometric.datasets import Planetoid + + +class GCN(torch.nn.Module): + """Graph Convolutional Network for node classification.""" + + def __init__(self, num_features, hidden_channels, num_classes, num_layers=2, dropout=0.5): + super().__init__() + self.convs = torch.nn.ModuleList() + + # First layer + self.convs.append(GCNConv(num_features, hidden_channels)) + + # Hidden layers + for _ in range(num_layers - 2): + self.convs.append(GCNConv(hidden_channels, hidden_channels)) + + # Output layer + self.convs.append(GCNConv(hidden_channels, num_classes)) + + self.dropout = dropout + + def forward(self, data): + x, edge_index = data.x, data.edge_index + + # Apply conv layers with ReLU and dropout + for conv in self.convs[:-1]: + x = conv(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + # Final layer without activation + x = self.convs[-1](x, edge_index) + return F.log_softmax(x, dim=1) + + +def train(model, data, optimizer): + """Train the model for one epoch.""" + model.train() + optimizer.zero_grad() + out = model(data) + loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) + loss.backward() + optimizer.step() + return loss.item() + + +@torch.no_grad() +def test(model, data): + """Evaluate the model.""" + model.eval() + out = model(data) + pred = out.argmax(dim=1) + + accs = [] + for mask in [data.train_mask, data.val_mask, data.test_mask]: + correct = (pred[mask] == data.y[mask]).sum() + accs.append(int(correct) / int(mask.sum())) + + return accs + + +def main(): + # Load dataset + dataset = Planetoid(root='/tmp/Cora', name='Cora') + data = dataset[0] + + # Create model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = GCN( + num_features=dataset.num_features, + hidden_channels=64, + num_classes=dataset.num_classes, + num_layers=3, + dropout=0.5 + ).to(device) + data = data.to(device) + + # Setup optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + + # Training loop + print("Training GCN model...") + best_val_acc = 0 + for epoch in range(1, 201): + loss = train(model, data, optimizer) + train_acc, val_acc, test_acc = test(model, data) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + + if epoch % 10 == 0: + print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, ' + f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') + + print(f'\\nBest Test Accuracy: {best_test_acc:.4f}') + + +if __name__ == '__main__': + main() +''', + + 'gat': '''import torch +import torch.nn.functional as F +from torch_geometric.nn import GATConv +from torch_geometric.datasets import Planetoid + + +class GAT(torch.nn.Module): + """Graph Attention Network for node classification.""" + + def __init__(self, num_features, hidden_channels, num_classes, heads=8, dropout=0.6): + super().__init__() + + self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout) + self.conv2 = GATConv(hidden_channels * heads, num_classes, 
heads=1, + concat=False, dropout=dropout) + + self.dropout = dropout + + def forward(self, data): + x, edge_index = data.x, data.edge_index + + x = F.dropout(x, p=self.dropout, training=self.training) + x = F.elu(self.conv1(x, edge_index)) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + + return F.log_softmax(x, dim=1) + + +def train(model, data, optimizer): + """Train the model for one epoch.""" + model.train() + optimizer.zero_grad() + out = model(data) + loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) + loss.backward() + optimizer.step() + return loss.item() + + +@torch.no_grad() +def test(model, data): + """Evaluate the model.""" + model.eval() + out = model(data) + pred = out.argmax(dim=1) + + accs = [] + for mask in [data.train_mask, data.val_mask, data.test_mask]: + correct = (pred[mask] == data.y[mask]).sum() + accs.append(int(correct) / int(mask.sum())) + + return accs + + +def main(): + # Load dataset + dataset = Planetoid(root='/tmp/Cora', name='Cora') + data = dataset[0] + + # Create model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = GAT( + num_features=dataset.num_features, + hidden_channels=8, + num_classes=dataset.num_classes, + heads=8, + dropout=0.6 + ).to(device) + data = data.to(device) + + # Setup optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) + + # Training loop + print("Training GAT model...") + best_val_acc = 0 + for epoch in range(1, 201): + loss = train(model, data, optimizer) + train_acc, val_acc, test_acc = test(model, data) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + + if epoch % 10 == 0: + print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, ' + f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') + + print(f'\\nBest Test Accuracy: {best_test_acc:.4f}') + + +if __name__ == '__main__': + main() +''', + + 'graphsage': '''import torch +import torch.nn.functional as F +from torch_geometric.nn import SAGEConv +from torch_geometric.datasets import Planetoid + + +class GraphSAGE(torch.nn.Module): + """GraphSAGE for node classification.""" + + def __init__(self, num_features, hidden_channels, num_classes, num_layers=2, dropout=0.5): + super().__init__() + self.convs = torch.nn.ModuleList() + + self.convs.append(SAGEConv(num_features, hidden_channels)) + for _ in range(num_layers - 2): + self.convs.append(SAGEConv(hidden_channels, hidden_channels)) + self.convs.append(SAGEConv(hidden_channels, num_classes)) + + self.dropout = dropout + + def forward(self, data): + x, edge_index = data.x, data.edge_index + + for conv in self.convs[:-1]: + x = conv(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + x = self.convs[-1](x, edge_index) + return F.log_softmax(x, dim=1) + + +def train(model, data, optimizer): + model.train() + optimizer.zero_grad() + out = model(data) + loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) + loss.backward() + optimizer.step() + return loss.item() + + +@torch.no_grad() +def test(model, data): + model.eval() + out = model(data) + pred = out.argmax(dim=1) + + accs = [] + for mask in [data.train_mask, data.val_mask, data.test_mask]: + correct = (pred[mask] == data.y[mask]).sum() + accs.append(int(correct) / int(mask.sum())) + + return accs + + +def main(): + dataset = Planetoid(root='/tmp/Cora', name='Cora') + data = dataset[0] + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model 
= GraphSAGE( + num_features=dataset.num_features, + hidden_channels=64, + num_classes=dataset.num_classes, + num_layers=2, + dropout=0.5 + ).to(device) + data = data.to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + + print("Training GraphSAGE model...") + best_val_acc = 0 + for epoch in range(1, 201): + loss = train(model, data, optimizer) + train_acc, val_acc, test_acc = test(model, data) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + + if epoch % 10 == 0: + print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, ' + f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}') + + print(f'\\nBest Test Accuracy: {best_test_acc:.4f}') + + +if __name__ == '__main__': + main() +''', + }, + + 'graph_classification': { + 'gin': '''import torch +import torch.nn.functional as F +from torch_geometric.nn import GINConv, global_add_pool +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader + + +class GIN(torch.nn.Module): + """Graph Isomorphism Network for graph classification.""" + + def __init__(self, num_features, hidden_channels, num_classes, num_layers=3, dropout=0.5): + super().__init__() + + self.convs = torch.nn.ModuleList() + self.batch_norms = torch.nn.ModuleList() + + # Create MLP for first layer + nn = torch.nn.Sequential( + torch.nn.Linear(num_features, hidden_channels), + torch.nn.ReLU(), + torch.nn.Linear(hidden_channels, hidden_channels) + ) + self.convs.append(GINConv(nn)) + self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels)) + + # Hidden layers + for _ in range(num_layers - 2): + nn = torch.nn.Sequential( + torch.nn.Linear(hidden_channels, hidden_channels), + torch.nn.ReLU(), + torch.nn.Linear(hidden_channels, hidden_channels) + ) + self.convs.append(GINConv(nn)) + self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels)) + + # Output MLP + self.lin = torch.nn.Linear(hidden_channels, num_classes) + self.dropout = dropout + + def forward(self, data): + x, edge_index, batch = data.x, data.edge_index, data.batch + + for conv, batch_norm in zip(self.convs, self.batch_norms): + x = conv(x, edge_index) + x = batch_norm(x) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + # Global pooling + x = global_add_pool(x, batch) + + # Output layer + x = self.lin(x) + return F.log_softmax(x, dim=1) + + +def train(model, loader, optimizer, device): + """Train the model for one epoch.""" + model.train() + total_loss = 0 + + for data in loader: + data = data.to(device) + optimizer.zero_grad() + out = model(data) + loss = F.nll_loss(out, data.y) + loss.backward() + optimizer.step() + total_loss += loss.item() * data.num_graphs + + return total_loss / len(loader.dataset) + + +@torch.no_grad() +def test(model, loader, device): + """Evaluate the model.""" + model.eval() + correct = 0 + + for data in loader: + data = data.to(device) + out = model(data) + pred = out.argmax(dim=1) + correct += (pred == data.y).sum().item() + + return correct / len(loader.dataset) + + +def main(): + # Load dataset + dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') + print(f"Dataset: {dataset}") + print(f"Number of graphs: {len(dataset)}") + print(f"Number of features: {dataset.num_features}") + print(f"Number of classes: {dataset.num_classes}") + + # Shuffle and split + dataset = dataset.shuffle() + train_dataset = dataset[:int(len(dataset) * 0.8)] + test_dataset = dataset[int(len(dataset) * 0.8):] + + # Create data loaders + train_loader = 
DataLoader(train_dataset, batch_size=32, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=32) + + # Create model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = GIN( + num_features=dataset.num_features, + hidden_channels=64, + num_classes=dataset.num_classes, + num_layers=3, + dropout=0.5 + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + # Training loop + print("\\nTraining GIN model...") + for epoch in range(1, 101): + loss = train(model, train_loader, optimizer, device) + train_acc = test(model, train_loader, device) + test_acc = test(model, test_loader, device) + + if epoch % 10 == 0: + print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, ' + f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}') + + +if __name__ == '__main__': + main() +''', + }, +} + + +def generate_template(model_type: str, task: str, output_path: str): + """Generate a GNN template file.""" + if task not in TEMPLATES: + raise ValueError(f"Unknown task: {task}. Available: {list(TEMPLATES.keys())}") + + if model_type not in TEMPLATES[task]: + raise ValueError(f"Model {model_type} not available for task {task}. " + f"Available: {list(TEMPLATES[task].keys())}") + + template = TEMPLATES[task][model_type] + + # Write to file + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + + with open(output_file, 'w') as f: + f.write(template) + + print(f"✓ Generated {model_type.upper()} template for {task}") + print(f" Saved to: {output_path}") + print(f"\\nTo run the template:") + print(f" python {output_path}") + + +def list_templates(): + """List all available templates.""" + print("Available GNN Templates") + print("=" * 50) + for task, models in TEMPLATES.items(): + print(f"\\n{task.upper()}") + print("-" * 50) + for model in models.keys(): + print(f" - {model}") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Generate GNN model templates", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python create_gnn_template.py --model gcn --task node_classification --output gcn_model.py + python create_gnn_template.py --model gin --task graph_classification --output gin_model.py + python create_gnn_template.py --list + """ + ) + + parser.add_argument('--model', type=str, + help='Model type (gcn, gat, graphsage, gin)') + parser.add_argument('--task', type=str, + help='Task type (node_classification, graph_classification)') + parser.add_argument('--output', type=str, default='gnn_model.py', + help='Output file path (default: gnn_model.py)') + parser.add_argument('--list', action='store_true', + help='List all available templates') + + args = parser.parse_args() + + if args.list: + list_templates() + return + + if not args.model or not args.task: + parser.print_help() + print("\\n" + "=" * 50) + list_templates() + return + + try: + generate_template(args.model, args.task, args.output) + except ValueError as e: + print(f"Error: {e}") + print("\\nUse --list to see available templates") + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/torch_geometric/scripts/visualize_graph.py b/scientific-packages/torch_geometric/scripts/visualize_graph.py new file mode 100644 index 0000000..58b8783 --- /dev/null +++ b/scientific-packages/torch_geometric/scripts/visualize_graph.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +""" +Visualize PyTorch Geometric graph structures using networkx and matplotlib. 
+ +This script provides utilities to visualize Data objects, including: +- Graph structure (nodes and edges) +- Node features (as colors) +- Edge attributes (as edge colors/widths) +- Community/cluster assignments + +Usage: + python visualize_graph.py --dataset Cora --output graph.png + +Or import and use: + from scripts.visualize_graph import visualize_data + visualize_data(data, title="My Graph", show_labels=True) +""" + +import argparse +import matplotlib.pyplot as plt +import networkx as nx +import torch +from typing import Optional, Union +import numpy as np + + +def visualize_data( + data, + title: str = "Graph Visualization", + node_color_attr: Optional[str] = None, + edge_color_attr: Optional[str] = None, + show_labels: bool = False, + node_size: int = 300, + figsize: tuple = (12, 10), + layout: str = "spring", + output_path: Optional[str] = None, + max_nodes: Optional[int] = None, +): + """ + Visualize a PyTorch Geometric Data object. + + Args: + data: PyTorch Geometric Data object + title: Plot title + node_color_attr: Data attribute to use for node colors (e.g., 'y', 'train_mask') + edge_color_attr: Data attribute to use for edge colors + show_labels: Whether to show node labels + node_size: Size of nodes in visualization + figsize: Figure size (width, height) + layout: Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral') + output_path: Path to save figure (if None, displays interactively) + max_nodes: Maximum number of nodes to visualize (samples if exceeded) + """ + # Sample nodes if graph is too large + if max_nodes and data.num_nodes > max_nodes: + print(f"Graph has {data.num_nodes} nodes. Sampling {max_nodes} nodes for visualization.") + node_indices = torch.randperm(data.num_nodes)[:max_nodes] + data = data.subgraph(node_indices) + + # Convert to networkx graph + G = nx.Graph() if is_undirected(data.edge_index) else nx.DiGraph() + + # Add nodes + G.add_nodes_from(range(data.num_nodes)) + + # Add edges + edge_index = data.edge_index.cpu().numpy() + edges = list(zip(edge_index[0], edge_index[1])) + G.add_edges_from(edges) + + # Setup figure + fig, ax = plt.subplots(figsize=figsize) + + # Choose layout + if layout == "spring": + pos = nx.spring_layout(G, k=0.5, iterations=50) + elif layout == "circular": + pos = nx.circular_layout(G) + elif layout == "kamada_kawai": + pos = nx.kamada_kawai_layout(G) + elif layout == "spectral": + pos = nx.spectral_layout(G) + else: + raise ValueError(f"Unknown layout: {layout}") + + # Determine node colors + if node_color_attr and hasattr(data, node_color_attr): + node_colors = getattr(data, node_color_attr).cpu().numpy() + if node_colors.dtype == bool: + node_colors = node_colors.astype(int) + if len(node_colors.shape) > 1: + # Multi-dimensional features - use first dimension + node_colors = node_colors[:, 0] + else: + node_colors = 'skyblue' + + # Determine edge colors + if edge_color_attr and hasattr(data, edge_color_attr): + edge_colors = getattr(data, edge_color_attr).cpu().numpy() + if len(edge_colors.shape) > 1: + edge_colors = edge_colors[:, 0] + else: + edge_colors = 'gray' + + # Draw graph + nx.draw_networkx_nodes( + G, pos, + node_color=node_colors, + node_size=node_size, + cmap=plt.cm.viridis, + ax=ax + ) + + nx.draw_networkx_edges( + G, pos, + edge_color=edge_colors, + alpha=0.3, + arrows=isinstance(G, nx.DiGraph), + arrowsize=10, + ax=ax + ) + + if show_labels: + nx.draw_networkx_labels(G, pos, font_size=8, ax=ax) + + ax.set_title(title, fontsize=16, fontweight='bold') + ax.axis('off') + + # Add colorbar if 
using numeric node colors + if node_color_attr and isinstance(node_colors, np.ndarray): + sm = plt.cm.ScalarMappable( + cmap=plt.cm.viridis, + norm=plt.Normalize(vmin=node_colors.min(), vmax=node_colors.max()) + ) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04) + cbar.set_label(node_color_attr, rotation=270, labelpad=20) + + plt.tight_layout() + + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Figure saved to {output_path}") + else: + plt.show() + + plt.close() + + +def is_undirected(edge_index): + """Check if graph is undirected.""" + row, col = edge_index + num_edges = edge_index.size(1) + + # Create a set of edges and reverse edges + edges = set(zip(row.tolist(), col.tolist())) + reverse_edges = set(zip(col.tolist(), row.tolist())) + + # Check if all edges have their reverse + return edges == reverse_edges + + +def plot_degree_distribution(data, output_path: Optional[str] = None): + """Plot the degree distribution of the graph.""" + from torch_geometric.utils import degree + + row, col = data.edge_index + deg = degree(col, data.num_nodes).cpu().numpy() + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + + # Histogram + ax1.hist(deg, bins=50, edgecolor='black', alpha=0.7) + ax1.set_xlabel('Degree', fontsize=12) + ax1.set_ylabel('Frequency', fontsize=12) + ax1.set_title('Degree Distribution', fontsize=14, fontweight='bold') + ax1.grid(alpha=0.3) + + # Log-log plot + unique_degrees, counts = np.unique(deg, return_counts=True) + ax2.loglog(unique_degrees, counts, 'o-', alpha=0.7) + ax2.set_xlabel('Degree (log scale)', fontsize=12) + ax2.set_ylabel('Frequency (log scale)', fontsize=12) + ax2.set_title('Degree Distribution (Log-Log)', fontsize=14, fontweight='bold') + ax2.grid(alpha=0.3) + + plt.tight_layout() + + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Degree distribution saved to {output_path}") + else: + plt.show() + + plt.close() + + +def plot_graph_statistics(data, output_path: Optional[str] = None): + """Plot various graph statistics.""" + from torch_geometric.utils import degree, contains_self_loops, is_undirected as check_undirected + + # Compute statistics + row, col = data.edge_index + deg = degree(col, data.num_nodes).cpu().numpy() + + stats = { + 'Nodes': data.num_nodes, + 'Edges': data.num_edges, + 'Avg Degree': deg.mean(), + 'Max Degree': deg.max(), + 'Self-loops': contains_self_loops(data.edge_index), + 'Undirected': check_undirected(data.edge_index), + } + + if hasattr(data, 'num_node_features'): + stats['Node Features'] = data.num_node_features + if hasattr(data, 'num_edge_features') and data.edge_attr is not None: + stats['Edge Features'] = data.num_edge_features + if hasattr(data, 'y'): + if data.y.dim() == 1: + stats['Classes'] = int(data.y.max().item()) + 1 + + # Create text plot + fig, ax = plt.subplots(figsize=(8, 6)) + ax.axis('off') + + text = "Graph Statistics\n" + "=" * 40 + "\n\n" + for key, value in stats.items(): + text += f"{key:20s}: {value}\n" + + ax.text(0.1, 0.5, text, fontsize=14, family='monospace', + verticalalignment='center', transform=ax.transAxes) + + plt.tight_layout() + + if output_path: + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Statistics saved to {output_path}") + else: + plt.show() + + plt.close() + + # Print to console as well + print("\n" + text) + + +def main(): + parser = argparse.ArgumentParser(description="Visualize PyTorch Geometric graphs") + parser.add_argument('--dataset', type=str, 
default='Cora', + help='Dataset name (e.g., Cora, CiteSeer, ENZYMES)') + parser.add_argument('--output', type=str, default=None, + help='Output file path for visualization') + parser.add_argument('--node-color', type=str, default='y', + help='Attribute to use for node colors') + parser.add_argument('--layout', type=str, default='spring', + choices=['spring', 'circular', 'kamada_kawai', 'spectral'], + help='Graph layout algorithm') + parser.add_argument('--show-labels', action='store_true', + help='Show node labels') + parser.add_argument('--max-nodes', type=int, default=500, + help='Maximum nodes to visualize') + parser.add_argument('--stats', action='store_true', + help='Show graph statistics') + parser.add_argument('--degree', action='store_true', + help='Show degree distribution') + + args = parser.parse_args() + + # Load dataset + print(f"Loading dataset: {args.dataset}") + + try: + # Try Planetoid datasets + from torch_geometric.datasets import Planetoid + dataset = Planetoid(root=f'/tmp/{args.dataset}', name=args.dataset) + data = dataset[0] + except: + try: + # Try TUDataset + from torch_geometric.datasets import TUDataset + dataset = TUDataset(root=f'/tmp/{args.dataset}', name=args.dataset) + data = dataset[0] + except Exception as e: + print(f"Error loading dataset: {e}") + print("Supported datasets: Cora, CiteSeer, PubMed, ENZYMES, PROTEINS, etc.") + return + + print(f"Loaded {args.dataset}: {data.num_nodes} nodes, {data.num_edges} edges") + + # Generate visualizations + if args.stats: + stats_output = args.output.replace('.png', '_stats.png') if args.output else None + plot_graph_statistics(data, stats_output) + + if args.degree: + degree_output = args.output.replace('.png', '_degree.png') if args.output else None + plot_degree_distribution(data, degree_output) + + # Main visualization + visualize_data( + data, + title=f"{args.dataset} Graph", + node_color_attr=args.node_color, + show_labels=args.show_labels, + layout=args.layout, + output_path=args.output, + max_nodes=args.max_nodes + ) + + +if __name__ == '__main__': + main() diff --git a/scientific-packages/transformers/SKILL.md b/scientific-packages/transformers/SKILL.md new file mode 100644 index 0000000..611ef18 --- /dev/null +++ b/scientific-packages/transformers/SKILL.md @@ -0,0 +1,860 @@ +--- +name: transformers +description: Comprehensive toolkit for working with Hugging Face Transformers library for state-of-the-art machine learning across NLP, computer vision, audio, and multimodal tasks. Use this skill when working with pretrained models, fine-tuning transformers, implementing text generation, image classification, speech recognition, or any task involving transformer architectures like BERT, GPT, T5, Vision Transformers, CLIP, or Whisper. +--- + +# Transformers + +## Overview + +Transformers is Hugging Face's flagship library providing unified access to over 1 million pretrained models for machine learning across text, vision, audio, and multimodal domains. 
The library serves as a standardized model-definition framework compatible with PyTorch, TensorFlow, and JAX, emphasizing ease of use through three core components: + +- **Pipeline**: Simple, optimized inference API for common tasks +- **AutoClasses**: Automatic model/tokenizer selection from pretrained checkpoints +- **Trainer**: Full-featured training loop with distributed training, mixed precision, and optimization + +The library prioritizes accessibility with pretrained models that reduce computational costs and carbon footprint while providing compatibility across major training frameworks (PyTorch-Lightning, DeepSpeed, vLLM, etc.). + +## Quick Start with Pipelines + +Use pipelines for simple, efficient inference without managing models, tokenizers, or preprocessing manually. Pipelines abstract complexity into a single function call. + +### Basic Pipeline Usage + +```python +from transformers import pipeline + +# Text classification +classifier = pipeline("text-classification") +result = classifier("This restaurant is awesome") +# [{'label': 'POSITIVE', 'score': 0.9998}] + +# Text generation +generator = pipeline("text-generation", model="meta-llama/Llama-2-7b-hf") +generator("The secret to baking a good cake is", max_length=50) + +# Question answering +qa = pipeline("question-answering") +qa(question="What is extractive QA?", context="Extractive QA is...") + +# Image classification +img_classifier = pipeline("image-classification") +img_classifier("path/to/image.jpg") + +# Automatic speech recognition +transcriber = pipeline("automatic-speech-recognition") +transcriber("audio_file.mp3") +``` + +### Available Pipeline Tasks + +**NLP Tasks:** +- `text-classification`, `token-classification`, `question-answering` +- `fill-mask`, `summarization`, `translation` +- `text-generation`, `conversational` +- `zero-shot-classification`, `sentiment-analysis` + +**Vision Tasks:** +- `image-classification`, `image-segmentation`, `object-detection` +- `depth-estimation`, `image-to-image`, `zero-shot-image-classification` + +**Audio Tasks:** +- `automatic-speech-recognition`, `audio-classification` +- `text-to-audio`, `zero-shot-audio-classification` + +**Multimodal Tasks:** +- `visual-question-answering`, `document-question-answering` +- `image-to-text`, `zero-shot-object-detection` + +### Pipeline Best Practices + +**Device Management:** +```python +from transformers import pipeline, infer_device + +device = infer_device() # Auto-detect best device +pipe = pipeline("text-generation", model="...", device=device) +``` + +**Batch Processing:** +```python +# Process multiple inputs efficiently +results = classifier(["Text 1", "Text 2", "Text 3"]) + +# Use KeyDataset for large datasets +from transformers.pipelines.pt_utils import KeyDataset +from datasets import load_dataset + +dataset = load_dataset("imdb", split="test") +for result in pipe(KeyDataset(dataset, "text")): + print(result) +``` + +**Memory Optimization:** +```python +# Use half-precision for faster inference +pipe = pipeline("text-generation", model="...", + torch_dtype=torch.float16, device="cuda") +``` + +## Core Components + +### AutoClasses for Model Loading + +AutoClasses automatically select the correct architecture based on pretrained checkpoints. 
+ +```python +from transformers import ( + AutoModel, AutoTokenizer, AutoConfig, + AutoModelForCausalLM, AutoModelForSequenceClassification +) + +# Load any model by checkpoint name +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") +model = AutoModel.from_pretrained("bert-base-uncased") + +# Task-specific model classes +causal_lm = AutoModelForCausalLM.from_pretrained("gpt2") +classifier = AutoModelForSequenceClassification.from_pretrained( + "bert-base-uncased", + num_labels=3 +) + +# Load with device and dtype optimization +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + device_map="auto", # Automatically distribute across devices + torch_dtype="auto" # Use optimal dtype +) +``` + +**Key Parameters:** +- `device_map="auto"`: Optimal device allocation (CPU/GPU/multi-GPU) +- `torch_dtype`: Control precision (torch.float16, torch.bfloat16, "auto") +- `trust_remote_code`: Enable custom model code (use cautiously) +- `use_fast`: Enable Rust-backed fast tokenizers (default True) + +### Tokenization + +Tokenizers convert text to model-compatible tensor inputs. + +```python +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +# Basic tokenization +tokens = tokenizer.tokenize("Hello, how are you?") +# ['hello', ',', 'how', 'are', 'you', '?'] + +# Encoding (text → token IDs) +encoded = tokenizer("Hello, how are you?", return_tensors="pt") +# {'input_ids': tensor([[...]], 'attention_mask': tensor([[...]])} + +# Batch encoding with padding and truncation +batch = tokenizer( + ["Short text", "This is a much longer text..."], + padding=True, # Pad to longest in batch + truncation=True, # Truncate to model's max length + max_length=512, + return_tensors="pt" +) + +# Decoding (token IDs → text) +text = tokenizer.decode(encoded['input_ids'][0]) +``` + +**Special Tokens:** +```python +# Access special tokens +tokenizer.pad_token # Padding token +tokenizer.cls_token # Classification token +tokenizer.sep_token # Separator token +tokenizer.mask_token # Mask token (for MLM) + +# Add custom tokens +tokenizer.add_tokens(["[CUSTOM]"]) +tokenizer.add_special_tokens({'additional_special_tokens': ['[NEW]']}) + +# Resize model embeddings to match new vocabulary +model.resize_token_embeddings(len(tokenizer)) +``` + +### Image Processors + +For vision tasks, use image processors instead of tokenizers. + +```python +from transformers import AutoImageProcessor + +processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + +# Process single image +from PIL import Image +image = Image.open("path/to/image.jpg") +inputs = processor(image, return_tensors="pt") +# Returns: {'pixel_values': tensor([[...]])} + +# Batch processing +images = [Image.open(f"img{i}.jpg") for i in range(3)] +inputs = processor(images, return_tensors="pt") +``` + +### Processors for Multimodal Models + +Multimodal models use processors that combine image and text processing. 
+ +```python +from transformers import AutoProcessor + +processor = AutoProcessor.from_pretrained("microsoft/git-base") + +# Process image + text caption +inputs = processor( + images=image, + text="A description of the image", + return_tensors="pt", + padding=True +) +``` + +## Model Inference + +### Basic Inference Pattern + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load model and tokenizer +model = AutoModelForCausalLM.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +# Tokenize input +inputs = tokenizer("The future of AI is", return_tensors="pt") + +# Generate (for causal LM) +outputs = model.generate(**inputs, max_length=50) +text = tokenizer.decode(outputs[0]) + +# Or get model outputs directly +outputs = model(**inputs) +logits = outputs.logits # Shape: (batch_size, seq_len, vocab_size) +``` + +### Text Generation Strategies + +For generative models, control generation behavior with parameters: + +```python +# Greedy decoding (default) +output = model.generate(inputs, max_length=50) + +# Beam search (multiple hypothesis) +output = model.generate( + inputs, + max_length=50, + num_beams=5, # Keep top 5 beams + early_stopping=True +) + +# Sampling with temperature +output = model.generate( + inputs, + max_length=50, + do_sample=True, + temperature=0.7, # Lower = more focused, higher = more random + top_k=50, # Sample from top 50 tokens + top_p=0.95 # Nucleus sampling +) + +# Streaming generation +from transformers import TextStreamer + +streamer = TextStreamer(tokenizer) +model.generate(**inputs, streamer=streamer, max_length=100) +``` + +**Generation Parameters:** +- `max_length` / `max_new_tokens`: Control output length +- `num_beams`: Beam search width (1 = greedy) +- `temperature`: Randomness (0.7-1.0 typical) +- `top_k`: Sample from top k tokens +- `top_p`: Nucleus sampling threshold +- `repetition_penalty`: Discourage repetition (>1.0) + +Refer to `references/generation_strategies.md` for detailed information on choosing appropriate strategies. + +## Training and Fine-Tuning + +### Training Workflow Overview + +1. **Load dataset** → 2. **Preprocess** → 3. **Configure training** → 4. **Train** → 5. **Evaluate** → 6. **Save/Share** + +### Text Classification Example + +```python +from transformers import ( + AutoTokenizer, AutoModelForSequenceClassification, + TrainingArguments, Trainer, DataCollatorWithPadding +) +from datasets import load_dataset + +# 1. Load dataset +dataset = load_dataset("imdb") + +# 2. Preprocess +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +def preprocess(examples): + return tokenizer(examples["text"], truncation=True) + +tokenized = dataset.map(preprocess, batched=True) +data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + +# 3. Load model +model = AutoModelForSequenceClassification.from_pretrained( + "bert-base-uncased", + num_labels=2, + id2label={0: "negative", 1: "positive"}, + label2id={"negative": 0, "positive": 1} +) + +# 4. Configure training +training_args = TrainingArguments( + output_dir="./results", + learning_rate=2e-5, + per_device_train_batch_size=16, + per_device_eval_batch_size=16, + num_train_epochs=3, + weight_decay=0.01, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + push_to_hub=False, +) + +# 5. Train +trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized["train"], + eval_dataset=tokenized["test"], + tokenizer=tokenizer, + data_collator=data_collator, +) + +trainer.train() + +# 6. 
Evaluate and save +metrics = trainer.evaluate() +trainer.save_model("./my-finetuned-model") +trainer.push_to_hub() # Share to Hugging Face Hub +``` + +### Vision Task Fine-Tuning + +```python +from transformers import ( + AutoImageProcessor, AutoModelForImageClassification, + TrainingArguments, Trainer +) +from datasets import load_dataset + +# Load dataset +dataset = load_dataset("food101", split="train[:5000]") + +# Image preprocessing +processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + +def transform(examples): + examples["pixel_values"] = [ + processor(img.convert("RGB"), return_tensors="pt")["pixel_values"][0] + for img in examples["image"] + ] + return examples + +dataset = dataset.with_transform(transform) + +# Load model +model = AutoModelForImageClassification.from_pretrained( + "google/vit-base-patch16-224", + num_labels=101, # 101 food categories + ignore_mismatched_sizes=True +) + +# Training (similar pattern to text) +training_args = TrainingArguments( + output_dir="./vit-food101", + remove_unused_columns=False, # Keep image data + eval_strategy="epoch", + save_strategy="epoch", + learning_rate=5e-5, + per_device_train_batch_size=32, + num_train_epochs=3, +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=processor, +) + +trainer.train() +``` + +### Sequence-to-Sequence Tasks + +For tasks like summarization, translation, use Seq2SeqTrainer: + +```python +from transformers import ( + AutoTokenizer, AutoModelForSeq2SeqLM, + Seq2SeqTrainingArguments, Seq2SeqTrainer, + DataCollatorForSeq2Seq +) + +tokenizer = AutoTokenizer.from_pretrained("t5-small") +model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") + +def preprocess(examples): + # Prefix input for T5 + inputs = ["summarize: " + doc for doc in examples["text"]] + model_inputs = tokenizer(inputs, max_length=1024, truncation=True) + + # Tokenize targets + labels = tokenizer( + examples["summary"], + max_length=128, + truncation=True + ) + model_inputs["labels"] = labels["input_ids"] + return model_inputs + +tokenized_dataset = dataset.map(preprocess, batched=True) + +training_args = Seq2SeqTrainingArguments( + output_dir="./t5-summarization", + eval_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=8, + num_train_epochs=3, + predict_with_generate=True, # Important for seq2seq +) + +trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["test"], + tokenizer=tokenizer, + data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), +) + +trainer.train() +``` + +### Important TrainingArguments + +```python +TrainingArguments( + # Essential + output_dir="./results", + num_train_epochs=3, + per_device_train_batch_size=8, + learning_rate=2e-5, + + # Evaluation + eval_strategy="epoch", # or "steps" + eval_steps=500, # if eval_strategy="steps" + + # Checkpointing + save_strategy="epoch", + save_steps=500, + save_total_limit=2, # Keep only 2 best checkpoints + load_best_model_at_end=True, + metric_for_best_model="accuracy", + + # Optimization + gradient_accumulation_steps=4, + warmup_steps=500, + weight_decay=0.01, + max_grad_norm=1.0, + + # Mixed Precision + fp16=True, # For Nvidia GPUs + bf16=True, # For Ampere+ GPUs (better) + + # Logging + logging_steps=100, + report_to="tensorboard", # or "wandb", "mlflow" + + # Memory Optimization + gradient_checkpointing=True, + optim="adamw_torch", # or "adafactor" for memory + + # Distributed Training + 
ddp_find_unused_parameters=False, +) +``` + +Refer to `references/training_guide.md` for comprehensive training patterns and optimization strategies. + +## Performance Optimization + +### Model Quantization + +Reduce memory footprint while maintaining accuracy: + +```python +from transformers import AutoModelForCausalLM, BitsAndBytesConfig + +# 8-bit quantization +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + load_in_8bit=True, + device_map="auto" +) + +# 4-bit quantization (even smaller) +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=bnb_config, + device_map="auto" +) +``` + +**Quantization Methods:** +- **Bitsandbytes**: 4/8-bit on-the-fly quantization, supports PEFT fine-tuning +- **GPTQ**: 2/3/4/8-bit, requires calibration, very fast inference +- **AWQ**: 4-bit activation-aware, balanced speed/accuracy + +Refer to `references/quantization.md` for detailed comparison and usage patterns. + +### Training Optimization + +```python +# Gradient accumulation (simulate larger batch) +training_args = TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=8, # Effective batch = 4 * 8 = 32 +) + +# Gradient checkpointing (reduce memory, slower) +training_args = TrainingArguments( + gradient_checkpointing=True, +) + +# Mixed precision training +training_args = TrainingArguments( + bf16=True, # or fp16=True +) + +# Efficient optimizer +training_args = TrainingArguments( + optim="adafactor", # Lower memory than AdamW +) +``` + +**Key Strategies:** +- **Batch sizes**: Use powers of 2 (8, 16, 32, 64, 128) +- **Gradient accumulation**: Enables larger effective batch sizes +- **Gradient checkpointing**: Reduces memory ~60%, increases time ~20% +- **Mixed precision**: bf16 for Ampere+ GPUs, fp16 for older +- **torch.compile**: Optimize model graph (PyTorch 2.0+) + +## Advanced Features + +### Custom Training Loop + +For maximum control, bypass Trainer: + +```python +from torch.utils.data import DataLoader +from transformers import AdamW, get_scheduler + +# Prepare data +train_dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True) + +# Setup optimizer and scheduler +optimizer = AdamW(model.parameters(), lr=5e-5) +scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=len(train_dataloader) * num_epochs +) + +# Training loop +model.train() +for epoch in range(num_epochs): + for batch in train_dataloader: + batch = {k: v.to(device) for k, v in batch.items()} + + outputs = model(**batch) + loss = outputs.loss + loss.backward() + + optimizer.step() + scheduler.step() + optimizer.zero_grad() +``` + +### Parameter-Efficient Fine-Tuning (PEFT) + +Use PEFT library with transformers for efficient fine-tuning: + +```python +from peft import LoraConfig, get_peft_model + +# Configure LoRA +lora_config = LoraConfig( + r=16, # Low-rank dimension + lora_alpha=32, + target_modules=["q_proj", "v_proj"], # Which layers to adapt + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) + +# Apply to model +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") +model = get_peft_model(model, lora_config) + +# Now train as usual - only LoRA parameters train +trainer = Trainer(model=model, ...) 
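+
+# Optional sanity check (PEFT's print_trainable_parameters helper) before calling
+# train(): it reports how many parameters are actually trainable vs. the total.
+model.print_trainable_parameters()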
+trainer.train() +``` + +### Chat Templates + +Apply chat templates for instruction-tuned models: + +```python +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is machine learning?"}, +] + +# Format according to model's chat template +formatted = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) + +# Tokenize and generate +inputs = tokenizer(formatted, return_tensors="pt") +outputs = model.generate(**inputs, max_length=200) +response = tokenizer.decode(outputs[0]) +``` + +### Multi-GPU Training + +```python +# Automatic with Trainer - no code changes needed +# Just run with: accelerate launch train.py + +# Or use PyTorch DDP explicitly +training_args = TrainingArguments( + output_dir="./results", + ddp_find_unused_parameters=False, + # ... other args +) + +# For larger models, use FSDP +training_args = TrainingArguments( + output_dir="./results", + fsdp="full_shard auto_wrap", + fsdp_config={ + "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"], + }, +) +``` + +## Task-Specific Patterns + +### Question Answering (Extractive) + +```python +from transformers import pipeline + +qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad") + +result = qa( + question="What is extractive QA?", + context="Extractive QA extracts the answer from the given context..." +) +# {'answer': 'extracts the answer from the given context', 'score': 0.97, ...} +``` + +### Named Entity Recognition + +```python +ner = pipeline("token-classification", model="dslim/bert-base-NER") + +result = ner("My name is John and I live in New York") +# [{'entity': 'B-PER', 'word': 'John', ...}, {'entity': 'B-LOC', 'word': 'New York', ...}] +``` + +### Image Captioning + +```python +from transformers import AutoProcessor, AutoModelForCausalLM + +processor = AutoProcessor.from_pretrained("microsoft/git-base") +model = AutoModelForCausalLM.from_pretrained("microsoft/git-base") + +from PIL import Image +image = Image.open("image.jpg") + +inputs = processor(images=image, return_tensors="pt") +outputs = model.generate(**inputs, max_length=50) +caption = processor.batch_decode(outputs, skip_special_tokens=True)[0] +``` + +### Speech Recognition + +```python +transcriber = pipeline( + "automatic-speech-recognition", + model="openai/whisper-base" +) + +result = transcriber("audio.mp3") +# {'text': 'This is the transcribed text...'} + +# With timestamps +result = transcriber("audio.mp3", return_timestamps=True) +``` + +## Common Patterns and Best Practices + +### Saving and Loading Models + +```python +# Save entire model +model.save_pretrained("./my-model") +tokenizer.save_pretrained("./my-model") + +# Load later +model = AutoModel.from_pretrained("./my-model") +tokenizer = AutoTokenizer.from_pretrained("./my-model") + +# Push to Hugging Face Hub +model.push_to_hub("username/my-model") +tokenizer.push_to_hub("username/my-model") + +# Load from Hub +model = AutoModel.from_pretrained("username/my-model") +``` + +### Error Handling + +```python +from transformers import AutoModel +import torch + +try: + model = AutoModel.from_pretrained("model-name") +except OSError: + print("Model not found - check internet connection or model name") +except torch.cuda.OutOfMemoryError: + print("GPU memory exceeded - try quantization or smaller batch size") +``` + +### Device Management + +```python +import torch + +# Check device availability 
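+# (on Apple Silicon builds of PyTorch, torch.backends.mps.is_available() can be
+#  checked the same way to target the "mps" device)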
+device = "cuda" if torch.cuda.is_available() else "cpu" + +# Move model to device +model = model.to(device) + +# Or use device_map for automatic distribution +model = AutoModel.from_pretrained("model-name", device_map="auto") + +# For inputs +inputs = tokenizer(text, return_tensors="pt").to(device) +``` + +### Memory Management + +```python +import torch + +# Clear CUDA cache +torch.cuda.empty_cache() + +# Use context manager for inference +with torch.no_grad(): + outputs = model(**inputs) + +# Delete unused models +del model +torch.cuda.empty_cache() +``` + +## Resources + +This skill includes comprehensive reference documentation and example scripts: + +### scripts/ + +- `quick_inference.py`: Ready-to-use script for running inference with pipelines +- `fine_tune_classifier.py`: Complete example for fine-tuning a text classifier +- `generate_text.py`: Text generation with various strategies + +Execute scripts directly or read them as implementation templates. + +### references/ + +- `api_reference.md`: Comprehensive API documentation for key classes +- `training_guide.md`: Detailed training patterns, optimization, and troubleshooting +- `generation_strategies.md`: In-depth guide to text generation methods +- `quantization.md`: Model quantization techniques comparison and usage +- `task_patterns.md`: Quick reference for common task implementations + +Load reference files when you need detailed information on specific topics. References contain extensive examples, parameter explanations, and best practices. + +## Troubleshooting + +**Import errors:** +```bash +pip install transformers +pip install accelerate # For device_map="auto" +pip install bitsandbytes # For quantization +``` + +**CUDA out of memory:** +- Reduce batch size +- Enable gradient checkpointing +- Use gradient accumulation +- Try quantization (8-bit or 4-bit) +- Use smaller model variant + +**Slow training:** +- Enable mixed precision (fp16/bf16) +- Increase batch size (if memory allows) +- Use torch.compile (PyTorch 2.0+) +- Check data loading isn't bottleneck + +**Poor generation quality:** +- Adjust temperature (lower = more focused) +- Try different decoding strategies (beam search vs sampling) +- Increase max_length if outputs cut off +- Use repetition_penalty to reduce repetition + +For task-specific guidance, consult the appropriate reference file in the `references/` directory. diff --git a/scientific-packages/transformers/references/api_reference.md b/scientific-packages/transformers/references/api_reference.md new file mode 100644 index 0000000..d43397a --- /dev/null +++ b/scientific-packages/transformers/references/api_reference.md @@ -0,0 +1,699 @@ +# Transformers API Reference + +This document provides comprehensive API reference for the most commonly used classes and methods in the Transformers library. + +## Core Model Classes + +### PreTrainedModel + +Base class for all models. Handles loading, saving, and common model operations. 
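+
+If it is unclear which checkpoint weights were actually reused during loading, `from_pretrained` also accepts `output_loading_info=True` and returns a loading report alongside the model. A minimal sketch (the checkpoint name is only an example):
+
+```python
+from transformers import AutoModelForSequenceClassification
+
+model, loading_info = AutoModelForSequenceClassification.from_pretrained(
+    "bert-base-uncased",
+    num_labels=2,
+    output_loading_info=True,
+)
+
+# Freshly initialized weights (e.g. the new classification head) are listed under
+# "missing_keys"; checkpoint weights that were not used end up under "unexpected_keys".
+print(loading_info["missing_keys"])
+print(loading_info["unexpected_keys"])
+```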
+ +**Key Methods:** + +```python +from transformers import PreTrainedModel + +# Load pretrained model +model = ModelClass.from_pretrained( + pretrained_model_name_or_path, + config=None, # Custom config + cache_dir=None, # Custom cache location + force_download=False, # Force re-download + resume_download=False, # Resume interrupted download + proxies=None, # HTTP proxies + local_files_only=False, # Only use cached files + token=None, # HF auth token + revision="main", # Git branch/tag + trust_remote_code=False, # Allow custom model code + device_map=None, # Device allocation ("auto", "cpu", "cuda:0", etc.) + torch_dtype=None, # Model dtype (torch.float16, "auto", etc.) + low_cpu_mem_usage=False, # Reduce CPU memory during loading + **model_kwargs +) + +# Save model +model.save_pretrained( + save_directory, + save_config=True, # Save config.json + state_dict=None, # Custom state dict + save_function=torch.save, # Custom save function + push_to_hub=False, # Upload to Hub + max_shard_size="5GB", # Max checkpoint size + safe_serialization=True, # Use SafeTensors format + variant=None, # Model variant name +) + +# Generate text (for generative models) +outputs = model.generate( + inputs=None, # Input token IDs + max_length=20, # Max total length + max_new_tokens=None, # Max new tokens to generate + min_length=0, # Minimum length + do_sample=False, # Enable sampling + early_stopping=False, # Stop when num_beams finish + num_beams=1, # Beam search width + temperature=1.0, # Sampling temperature + top_k=50, # Top-k sampling + top_p=1.0, # Nucleus sampling + repetition_penalty=1.0, # Penalize repetition + length_penalty=1.0, # Beam search length penalty + no_repeat_ngram_size=0, # Block repeated n-grams + num_return_sequences=1, # Number of sequences to return + **model_kwargs +) + +# Resize token embeddings (after adding tokens) +new_embeddings = model.resize_token_embeddings( + new_num_tokens, + pad_to_multiple_of=None +) + +# Utility methods +num_params = model.num_parameters(only_trainable=False) +model.gradient_checkpointing_enable() # Enable gradient checkpointing +model.enable_input_require_grads() # For PEFT with frozen models +``` + +### AutoModel Classes + +Automatically instantiate the correct model architecture. 
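+
+The Auto class reads the checkpoint's configuration and returns the matching concrete architecture; a small illustration (the checkpoint name is chosen only as an example):
+
+```python
+from transformers import AutoModelForSequenceClassification
+
+model = AutoModelForSequenceClassification.from_pretrained(
+    "distilbert-base-uncased-finetuned-sst-2-english"
+)
+print(type(model).__name__)        # DistilBertForSequenceClassification
+print(model.config.architectures)  # ['DistilBertForSequenceClassification']
+```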
+ +**Available Classes:** + +- `AutoModel`: Base model (returns hidden states) +- `AutoModelForCausalLM`: Causal language modeling (GPT-style) +- `AutoModelForMaskedLM`: Masked language modeling (BERT-style) +- `AutoModelForSeq2SeqLM`: Sequence-to-sequence (T5, BART) +- `AutoModelForSequenceClassification`: Text classification +- `AutoModelForTokenClassification`: Token classification (NER) +- `AutoModelForQuestionAnswering`: Extractive QA +- `AutoModelForImageClassification`: Image classification +- `AutoModelForObjectDetection`: Object detection +- `AutoModelForSemanticSegmentation`: Semantic segmentation +- `AutoModelForAudioClassification`: Audio classification +- `AutoModelForSpeechSeq2Seq`: Speech-to-text +- `AutoModelForVision2Seq`: Image captioning, VQA + +**Usage:** + +```python +from transformers import AutoModel, AutoConfig + +# Load with default configuration +model = AutoModel.from_pretrained("bert-base-uncased") + +# Load with custom configuration +config = AutoConfig.from_pretrained("bert-base-uncased") +config.hidden_dropout_prob = 0.2 +model = AutoModel.from_pretrained("bert-base-uncased", config=config) + +# Register custom models +from transformers import AutoConfig, AutoModel + +AutoConfig.register("my-model", MyModelConfig) +AutoModel.register(MyModelConfig, MyModel) +``` + +## Tokenizer Classes + +### PreTrainedTokenizer / PreTrainedTokenizerFast + +Convert text to token IDs and vice versa. + +**Key Methods:** + +```python +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + use_fast=True, # Use fast (Rust) tokenizer if available + revision="main", + **kwargs +) + +# Encoding (text → token IDs) +encoded = tokenizer( + text, # String or List[str] + text_pair=None, # Second sequence for pairs + add_special_tokens=True, # Add [CLS], [SEP], etc. 
+ padding=False, # True, False, "longest", "max_length" + truncation=False, # True, False, "longest_first", "only_first", "only_second" + max_length=None, # Max sequence length + stride=0, # Overlap for split sequences + return_tensors=None, # "pt" (PyTorch), "tf" (TensorFlow), "np" (NumPy) + return_token_type_ids=None, # Return token type IDs + return_attention_mask=None, # Return attention mask + return_overflowing_tokens=False, # Return overflowing tokens + return_special_tokens_mask=False, # Return special token mask + return_offsets_mapping=False, # Return char-level offsets (fast only) + return_length=False, # Return sequence lengths + **kwargs +) + +# Decoding (token IDs → text) +text = tokenizer.decode( + token_ids, + skip_special_tokens=False, # Remove special tokens + clean_up_tokenization_spaces=True, # Clean up spacing +) + +# Batch decoding +texts = tokenizer.batch_decode( + sequences, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, +) + +# Tokenization (text → tokens) +tokens = tokenizer.tokenize(text, **kwargs) + +# Convert tokens to IDs +ids = tokenizer.convert_tokens_to_ids(tokens) + +# Convert IDs to tokens +tokens = tokenizer.convert_ids_to_tokens(ids) + +# Add new tokens +num_added = tokenizer.add_tokens(["[NEW_TOKEN1]", "[NEW_TOKEN2]"]) + +# Add special tokens +tokenizer.add_special_tokens({ + "bos_token": "[BOS]", + "eos_token": "[EOS]", + "unk_token": "[UNK]", + "sep_token": "[SEP]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "mask_token": "[MASK]", + "additional_special_tokens": ["[SPECIAL1]", "[SPECIAL2]"], +}) + +# Chat template formatting +formatted = tokenizer.apply_chat_template( + conversation, # List[Dict[str, str]] with "role" and "content" + chat_template=None, # Custom template + add_generation_prompt=False, # Add prompt for model to continue + tokenize=True, # Return token IDs + padding=False, + truncation=False, + max_length=None, + return_tensors=None, + return_dict=True, +) + +# Save tokenizer +tokenizer.save_pretrained(save_directory) + +# Get vocab size +vocab_size = len(tokenizer) + +# Get special tokens +pad_token = tokenizer.pad_token +pad_token_id = tokenizer.pad_token_id +# Similar for: bos, eos, unk, sep, cls, mask +``` + +**Special Token Attributes:** + +```python +tokenizer.bos_token # Beginning of sequence +tokenizer.eos_token # End of sequence +tokenizer.unk_token # Unknown token +tokenizer.sep_token # Separator token +tokenizer.pad_token # Padding token +tokenizer.cls_token # Classification token +tokenizer.mask_token # Mask token + +# Corresponding IDs +tokenizer.bos_token_id +tokenizer.eos_token_id +# ... etc +``` + +## Image Processors + +### AutoImageProcessor + +Preprocess images for vision models. + +**Key Methods:** + +```python +from transformers import AutoImageProcessor + +processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + +# Process images +inputs = processor( + images, # PIL Image, np.array, torch.Tensor, or List + return_tensors="pt", # "pt", "tf", "np", None + do_resize=True, # Resize to model size + size=None, # Target size dict + resample=None, # Resampling method + do_rescale=True, # Rescale pixel values + do_normalize=True, # Normalize with mean/std + image_mean=None, # Custom mean + image_std=None, # Custom std + do_center_crop=False, # Center crop + crop_size=None, # Crop size + **kwargs +) + +# Returns: BatchFeature with 'pixel_values' key +``` + +## Training Components + +### TrainingArguments + +Configuration for the Trainer class. 
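+
+Most training runs only need a handful of these arguments; a minimal sketch of a typical fine-tuning configuration (values are illustrative, not prescriptive):
+
+```python
+from transformers import TrainingArguments
+
+args = TrainingArguments(
+    output_dir="./results",
+    num_train_epochs=3,
+    per_device_train_batch_size=16,
+    learning_rate=2e-5,
+    eval_strategy="epoch",            # evaluate at the end of every epoch
+    save_strategy="epoch",            # must match eval_strategy for best-model loading
+    load_best_model_at_end=True,
+    metric_for_best_model="accuracy",
+    logging_steps=50,
+    fp16=True,                        # mixed precision on supported GPUs
+)
+```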
+ +**Essential Arguments:** + +```python +from transformers import TrainingArguments + +args = TrainingArguments( + # ===== Output & Logging ===== + output_dir="./results", # REQUIRED: Output directory + overwrite_output_dir=False, # Overwrite output directory + + # ===== Training Parameters ===== + num_train_epochs=3.0, # Number of epochs + max_steps=-1, # Max training steps (overrides epochs) + per_device_train_batch_size=8, # Train batch size per device + per_device_eval_batch_size=8, # Eval batch size per device + gradient_accumulation_steps=1, # Accumulation steps + + # ===== Learning Rate & Optimization ===== + learning_rate=5e-5, # Initial learning rate + weight_decay=0.0, # Weight decay + adam_beta1=0.9, # Adam beta1 + adam_beta2=0.999, # Adam beta2 + adam_epsilon=1e-8, # Adam epsilon + max_grad_norm=1.0, # Gradient clipping + optim="adamw_torch", # Optimizer ("adamw_torch", "adafactor", "adamw_8bit") + + # ===== Learning Rate Scheduler ===== + lr_scheduler_type="linear", # Scheduler type + warmup_steps=0, # Warmup steps + warmup_ratio=0.0, # Warmup ratio (alternative to steps) + + # ===== Evaluation ===== + eval_strategy="no", # "no", "steps", "epoch" + eval_steps=None, # Eval every N steps + eval_delay=0, # Delay first eval + eval_accumulation_steps=None, # Accumulate eval outputs + + # ===== Checkpointing ===== + save_strategy="steps", # "no", "steps", "epoch" + save_steps=500, # Save every N steps + save_total_limit=None, # Max checkpoints to keep + save_safetensors=True, # Save as SafeTensors + save_on_each_node=False, # Save on each node (distributed) + + # ===== Best Model Selection ===== + load_best_model_at_end=False, # Load best checkpoint at end + metric_for_best_model=None, # Metric to use + greater_is_better=None, # True if higher is better + + # ===== Logging ===== + logging_dir=None, # TensorBoard log directory + logging_strategy="steps", # "no", "steps", "epoch" + logging_steps=500, # Log every N steps + logging_first_step=False, # Log first step + logging_nan_inf_filter=True, # Filter NaN/Inf + + # ===== Mixed Precision ===== + fp16=False, # Use fp16 training + fp16_opt_level="O1", # Apex AMP optimization level + fp16_backend="auto", # "auto", "apex", "cpu_amp" + bf16=False, # Use bfloat16 training + bf16_full_eval=False, # Use bf16 for evaluation + tf32=None, # Use TF32 (Ampere+ GPUs) + + # ===== Memory Optimization ===== + gradient_checkpointing=False, # Enable gradient checkpointing + gradient_checkpointing_kwargs=None, # Kwargs for gradient checkpointing + torch_empty_cache_steps=None, # Clear cache every N steps + + # ===== Distributed Training ===== + local_rank=-1, # Local rank for distributed + ddp_backend=None, # "nccl", "gloo", "mpi", "ccl" + ddp_find_unused_parameters=None, # Find unused parameters + ddp_bucket_cap_mb=None, # DDP bucket size + fsdp="", # FSDP configuration + fsdp_config=None, # FSDP config dict + deepspeed=None, # DeepSpeed config + + # ===== Hub Integration ===== + push_to_hub=False, # Push to Hugging Face Hub + hub_model_id=None, # Hub model ID + hub_strategy="every_save", # "every_save", "checkpoint", "end" + hub_token=None, # Hub authentication token + hub_private_repo=False, # Make repo private + + # ===== Data Handling ===== + dataloader_num_workers=0, # DataLoader workers + dataloader_pin_memory=True, # Pin memory + dataloader_drop_last=False, # Drop last incomplete batch + dataloader_prefetch_factor=None, # Prefetch factor + remove_unused_columns=True, # Remove unused dataset columns + label_names=None, # Label column names + + # 
===== Other ===== + seed=42, # Random seed + data_seed=None, # Data sampling seed + jit_mode_eval=False, # Use PyTorch JIT for eval + use_ipex=False, # Use Intel Extension for PyTorch + torch_compile=False, # Use torch.compile() + torch_compile_backend=None, # Compile backend + torch_compile_mode=None, # Compile mode + include_inputs_for_metrics=False, # Pass inputs to compute_metrics + skip_memory_metrics=True, # Skip memory profiling +) +``` + +### Trainer + +Main training class with full training loop. + +**Key Methods:** + +```python +from transformers import Trainer + +trainer = Trainer( + model=None, # Model to train + args=None, # TrainingArguments + data_collator=None, # Data collator + train_dataset=None, # Training dataset + eval_dataset=None, # Evaluation dataset + tokenizer=None, # Tokenizer + model_init=None, # Function to instantiate model + compute_metrics=None, # Function to compute metrics + callbacks=None, # List of callbacks + optimizers=(None, None), # (optimizer, scheduler) tuple + preprocess_logits_for_metrics=None, # Preprocess logits before metrics +) + +# Train model +train_result = trainer.train( + resume_from_checkpoint=None, # Resume from checkpoint + trial=None, # Optuna/Ray trial + ignore_keys_for_eval=None, # Keys to ignore in eval +) + +# Evaluate model +eval_result = trainer.evaluate( + eval_dataset=None, # Eval dataset (default: self.eval_dataset) + ignore_keys=None, # Keys to ignore + metric_key_prefix="eval", # Prefix for metric names +) + +# Make predictions +predictions = trainer.predict( + test_dataset, # Test dataset + ignore_keys=None, # Keys to ignore + metric_key_prefix="test", # Metric prefix +) +# Returns: PredictionOutput(predictions, label_ids, metrics) + +# Save model +trainer.save_model(output_dir=None) + +# Push to Hub +trainer.push_to_hub( + commit_message="End of training", + blocking=True, + **kwargs +) + +# Hyperparameter search +best_trial = trainer.hyperparameter_search( + hp_space=None, # Hyperparameter search space + compute_objective=None, # Objective function + n_trials=20, # Number of trials + direction="minimize", # "minimize" or "maximize" + backend=None, # "optuna", "ray", "sigopt" + **kwargs +) + +# Create optimizer +optimizer = trainer.create_optimizer() + +# Create scheduler +scheduler = trainer.create_scheduler( + num_training_steps, + optimizer=None +) + +# Log metrics +trainer.log_metrics(split, metrics) +trainer.save_metrics(split, metrics) + +# Save checkpoint +trainer.save_state() + +# Access current step/epoch +current_step = trainer.state.global_step +current_epoch = trainer.state.epoch + +# Access training logs +logs = trainer.state.log_history +``` + +### Seq2SeqTrainer + +Specialized trainer for sequence-to-sequence models. + +```python +from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments + +# Use Seq2SeqTrainingArguments with additional parameters +training_args = Seq2SeqTrainingArguments( + output_dir="./results", + predict_with_generate=True, # Use generate() for evaluation + generation_max_length=None, # Max length for generation + generation_num_beams=None, # Num beams for generation + **other_training_arguments +) + +# Trainer usage is identical to Trainer +trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, +) +``` + +## Pipeline Classes + +### pipeline() + +Unified inference API for all tasks. 
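+
+For orientation, a minimal concrete call; the default checkpoint is selected by the library and may change between versions:
+
+```python
+from transformers import pipeline
+
+classifier = pipeline("sentiment-analysis")
+print(classifier("Transformers makes inference straightforward."))
+# [{'label': 'POSITIVE', 'score': 0.99...}]
+```
+
+The full constructor signature and its main options: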
+ +```python +from transformers import pipeline + +pipe = pipeline( + task=None, # Task name (required) + model=None, # Model name/path or model object + config=None, # Model config + tokenizer=None, # Tokenizer + feature_extractor=None, # Feature extractor + image_processor=None, # Image processor + framework=None, # "pt" or "tf" + revision=None, # Model revision + use_fast=True, # Use fast tokenizer + token=None, # HF token + device=None, # Device (-1 for CPU, 0+ for GPU) + device_map=None, # Device map for multi-GPU + torch_dtype=None, # Model dtype + trust_remote_code=False, # Allow custom code + model_kwargs=None, # Additional model kwargs + pipeline_class=None, # Custom pipeline class + **kwargs +) + +# Use pipeline +results = pipe( + inputs, # Input data + **task_specific_parameters +) +``` + +## Data Collators + +Batch and pad data for training. + +```python +from transformers import ( + DataCollatorWithPadding, # Dynamic padding for classification + DataCollatorForTokenClassification, # Padding for token classification + DataCollatorForSeq2Seq, # Padding for seq2seq + DataCollatorForLanguageModeling, # MLM/CLM data collation + default_data_collator, # Simple collator (no padding) +) + +# Text classification +data_collator = DataCollatorWithPadding( + tokenizer=tokenizer, + padding=True, + max_length=None, + pad_to_multiple_of=None, +) + +# Token classification +data_collator = DataCollatorForTokenClassification( + tokenizer=tokenizer, + padding=True, + max_length=None, + pad_to_multiple_of=None, + label_pad_token_id=-100, +) + +# Seq2Seq +data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=None, + padding=True, + max_length=None, + pad_to_multiple_of=None, + label_pad_token_id=-100, +) + +# Language modeling +data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=True, # Masked LM (False for causal LM) + mlm_probability=0.15, # Mask probability + pad_to_multiple_of=None, +) +``` + +## Optimization & Scheduling + +```python +from transformers import ( + AdamW, # AdamW optimizer + Adafactor, # Adafactor optimizer + get_scheduler, # Get LR scheduler + get_linear_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, +) + +# Create optimizer +optimizer = AdamW( + model.parameters(), + lr=5e-5, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +# Create scheduler +scheduler = get_scheduler( + name="linear", # "linear", "cosine", "polynomial", "constant" + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=total_steps, +) + +# Or use specific schedulers +scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, +) + +scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, + num_cycles=0.5, +) +``` + +## Configuration Classes + +```python +from transformers import AutoConfig + +# Load configuration +config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + **kwargs +) + +# Common configuration attributes +config.vocab_size # Vocabulary size +config.hidden_size # Hidden layer size +config.num_hidden_layers # Number of layers +config.num_attention_heads # Attention heads +config.intermediate_size # FFN intermediate size +config.hidden_dropout_prob # Dropout probability +config.attention_probs_dropout_prob # Attention dropout +config.max_position_embeddings # Max sequence length + +# Save configuration 
+config.save_pretrained(save_directory) + +# Create model from config +from transformers import AutoModel +model = AutoModel.from_config(config) +``` + +## Utility Functions + +```python +from transformers import ( + set_seed, # Set random seed + logging, # Logging utilities +) + +# Set seed for reproducibility +set_seed(42) + +# Configure logging +logging.set_verbosity_info() +logging.set_verbosity_warning() +logging.set_verbosity_error() +logging.set_verbosity_debug() + +# Get logger +logger = logging.get_logger(__name__) +``` + +## Model Outputs + +All models return model-specific output classes (subclasses of `ModelOutput`): + +```python +# Common output attributes +outputs.loss # Loss (if labels provided) +outputs.logits # Model logits +outputs.hidden_states # All hidden states (if output_hidden_states=True) +outputs.attentions # Attention weights (if output_attentions=True) + +# Seq2Seq specific +outputs.encoder_last_hidden_state +outputs.encoder_hidden_states +outputs.encoder_attentions +outputs.decoder_hidden_states +outputs.decoder_attentions +outputs.cross_attentions + +# Access as dict or tuple +logits = outputs.logits +logits = outputs["logits"] +loss, logits = outputs.to_tuple()[:2] +``` + +This reference covers the most commonly used API components. For complete documentation, refer to https://huggingface.co/docs/transformers. diff --git a/scientific-packages/transformers/references/generation_strategies.md b/scientific-packages/transformers/references/generation_strategies.md new file mode 100644 index 0000000..9ad4486 --- /dev/null +++ b/scientific-packages/transformers/references/generation_strategies.md @@ -0,0 +1,530 @@ +# Text Generation Strategies + +Comprehensive guide to text generation methods in Transformers for controlling output quality, creativity, and diversity. + +## Overview + +Text generation is the process of predicting tokens sequentially using a language model. The choice of generation strategy significantly impacts output quality, diversity, and computational cost. + +**When to use each strategy:** +- **Greedy**: Fast, deterministic, good for short outputs or when consistency is critical +- **Beam Search**: Better quality for tasks with clear "correct" answers (translation, summarization) +- **Sampling**: Creative, diverse outputs for open-ended generation (stories, dialogue) +- **Top-k/Top-p**: Balanced creativity and coherence + +## Basic Generation Methods + +### Greedy Decoding + +Selects the highest probability token at each step. Fast but prone to repetition and suboptimal sequences. + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +inputs = tokenizer("The future of AI", return_tensors="pt") + +# Greedy decoding (default) +outputs = model.generate(**inputs, max_new_tokens=50) +print(tokenizer.decode(outputs[0])) +``` + +**Characteristics:** +- Deterministic (always same output for same input) +- Fast (single forward pass per token) +- Prone to repetition in longer sequences +- Best for: Short generations, deterministic applications + +**Parameters:** +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, # Number of tokens to generate + min_length=10, # Minimum total length + pad_token_id=tokenizer.pad_token_id, +) +``` + +### Beam Search + +Maintains multiple hypotheses (beams) and selects the sequence with highest overall probability. 
+ +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, + num_beams=5, # Number of beams + early_stopping=True, # Stop when all beams finish + no_repeat_ngram_size=2, # Prevent 2-gram repetition +) +``` + +**Characteristics:** +- Higher quality than greedy for tasks with "correct" answers +- Slower than greedy (num_beams forward passes per step) +- Still can suffer from repetition +- Best for: Translation, summarization, QA generation + +**Advanced Parameters:** +```python +outputs = model.generate( + **inputs, + num_beams=5, + num_beam_groups=1, # Diverse beam search groups + diversity_penalty=0.0, # Penalty for similar beams + length_penalty=1.0, # >1: longer sequences, <1: shorter + early_stopping=True, # Stop when num_beams sequences finish + no_repeat_ngram_size=2, # Block repeating n-grams + num_return_sequences=1, # Return top-k sequences (≤ num_beams) +) +``` + +**Length Penalty:** +- `length_penalty > 1.0`: Favor longer sequences +- `length_penalty = 1.0`: No penalty +- `length_penalty < 1.0`: Favor shorter sequences + +### Sampling (Multinomial) + +Randomly sample tokens according to the probability distribution. + +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, + do_sample=True, # Enable sampling + temperature=1.0, # Sampling temperature + num_beams=1, # Must be 1 for sampling +) +``` + +**Characteristics:** +- Non-deterministic (different output each time) +- More diverse and creative than greedy/beam search +- Can produce incoherent output if not controlled +- Best for: Creative writing, dialogue, open-ended generation + +**Temperature Parameter:** +```python +# Low temperature (0.1-0.7): More focused, less random +outputs = model.generate(**inputs, do_sample=True, temperature=0.5) + +# Medium temperature (0.7-1.0): Balanced +outputs = model.generate(**inputs, do_sample=True, temperature=0.8) + +# High temperature (1.0-2.0): More random, more creative +outputs = model.generate(**inputs, do_sample=True, temperature=1.5) +``` + +- `temperature → 0`: Approaches greedy decoding +- `temperature = 1.0`: Sample from original distribution +- `temperature > 1.0`: Flatter distribution, more random +- `temperature < 1.0`: Sharper distribution, more confident + +## Advanced Sampling Methods + +### Top-k Sampling + +Sample from only the k most likely tokens. + +```python +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=50, + top_k=50, # Consider top 50 tokens + temperature=0.8, +) +``` + +**How it works:** +1. Filter to top-k most probable tokens +2. Renormalize probabilities +3. Sample from filtered distribution + +**Choosing k:** +- `k=1`: Equivalent to greedy decoding +- `k=10-50`: More focused, coherent output +- `k=100-500`: More diverse output +- Too high k: Includes low-probability tokens (noise) +- Too low k: Less diverse, may miss good alternatives + +### Top-p (Nucleus) Sampling + +Sample from the smallest set of tokens whose cumulative probability ≥ p. + +```python +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=50, + top_p=0.95, # Nucleus probability + temperature=0.8, +) +``` + +**How it works:** +1. Sort tokens by probability +2. Find smallest set with cumulative probability ≥ p +3. 
Sample from this set + +**Choosing p:** +- `p=0.9-0.95`: Good balance (recommended) +- `p=1.0`: Sample from full distribution +- Higher p: More diverse, might include unlikely tokens +- Lower p: More focused, like top-k with adaptive k + +**Top-p vs Top-k:** +- Top-p adapts to probability distribution shape +- Top-k is fixed regardless of distribution +- Top-p generally better for variable-quality contexts +- Can combine: `top_k=50, top_p=0.95` (apply both filters) + +### Combining Strategies + +```python +# Recommended for high-quality open-ended generation +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=100, + temperature=0.8, # Moderate temperature + top_k=50, # Limit to top 50 tokens + top_p=0.95, # Nucleus sampling + repetition_penalty=1.2, # Discourage repetition + no_repeat_ngram_size=3, # Block 3-gram repetition +) +``` + +## Controlling Generation Quality + +### Repetition Control + +Prevent models from repeating themselves: + +```python +outputs = model.generate( + **inputs, + max_new_tokens=100, + + # Method 1: Repetition penalty + repetition_penalty=1.2, # Penalize repeated tokens (>1.0) + + # Method 2: Block n-gram repetition + no_repeat_ngram_size=3, # Never repeat 3-grams + + # Method 3: Encoder repetition penalty (for seq2seq) + encoder_repetition_penalty=1.0, # Penalize input tokens +) +``` + +**Repetition Penalty Values:** +- `1.0`: No penalty +- `1.0-1.5`: Mild penalty (recommended: 1.1-1.3) +- `>1.5`: Strong penalty (may harm coherence) + +### Length Control + +```python +outputs = model.generate( + **inputs, + + # Hard constraints + min_length=20, # Minimum total length + max_length=100, # Maximum total length + max_new_tokens=50, # Maximum new tokens (excluding input) + + # Soft constraints (with beam search) + length_penalty=1.0, # Encourage longer/shorter outputs + + # Early stopping + early_stopping=True, # Stop when condition met +) +``` + +### Bad Words and Forced Tokens + +```python +# Prevent specific tokens +bad_words_ids = [ + tokenizer.encode("badword1", add_special_tokens=False), + tokenizer.encode("badword2", add_special_tokens=False), +] + +outputs = model.generate( + **inputs, + bad_words_ids=bad_words_ids, +) + +# Force specific tokens +force_words_ids = [ + tokenizer.encode("important", add_special_tokens=False), +] + +outputs = model.generate( + **inputs, + force_words_ids=force_words_ids, +) +``` + +## Streaming Generation + +Generate and process tokens as they're produced: + +```python +from transformers import TextStreamer, TextIteratorStreamer +from threading import Thread + +# Simple streaming (prints to stdout) +streamer = TextStreamer(tokenizer, skip_prompt=True) +outputs = model.generate(**inputs, streamer=streamer, max_new_tokens=100) + +# Iterator streaming (for custom processing) +streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) + +generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=100) +thread = Thread(target=model.generate, kwargs=generation_kwargs) +thread.start() + +for text in streamer: + print(text, end="", flush=True) + +thread.join() +``` + +## Advanced Techniques + +### Contrastive Search + +Balance coherence and diversity using contrastive objective: + +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, + penalty_alpha=0.6, # Contrastive penalty + top_k=4, # Consider top-4 tokens +) +``` + +**When to use:** +- Open-ended text generation +- Reduces repetition without sacrificing coherence +- Good alternative to sampling + +### Diverse Beam Search + +Generate 
multiple diverse outputs: + +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, + num_beams=10, + num_beam_groups=5, # 5 groups of 2 beams each + diversity_penalty=1.0, # Penalty for similar beams + num_return_sequences=5, # Return 5 diverse outputs +) +``` + +### Constrained Beam Search + +Force output to include specific phrases: + +```python +from transformers import PhrasalConstraint + +constraints = [ + PhrasalConstraint( + tokenizer("machine learning", add_special_tokens=False).input_ids + ), +] + +outputs = model.generate( + **inputs, + constraints=constraints, + num_beams=10, # Requires beam search +) +``` + +## Speculative Decoding + +Accelerate generation using a smaller draft model: + +```python +from transformers import AutoModelForCausalLM + +# Load main and assistant models +model = AutoModelForCausalLM.from_pretrained("large-model") +assistant_model = AutoModelForCausalLM.from_pretrained("small-model") + +# Generate with speculative decoding +outputs = model.generate( + **inputs, + assistant_model=assistant_model, + do_sample=True, + temperature=0.8, +) +``` + +**Benefits:** +- 2-3x faster generation +- Identical output distribution to regular generation +- Works with sampling and greedy decoding + +## Recipe: Recommended Settings by Task + +### Creative Writing / Dialogue + +```python +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=200, + temperature=0.9, + top_p=0.95, + top_k=50, + repetition_penalty=1.2, + no_repeat_ngram_size=3, +) +``` + +### Translation / Summarization + +```python +outputs = model.generate( + **inputs, + num_beams=5, + max_new_tokens=150, + early_stopping=True, + length_penalty=1.0, + no_repeat_ngram_size=2, +) +``` + +### Code Generation + +```python +outputs = model.generate( + **inputs, + max_new_tokens=300, + temperature=0.2, # Low temperature for correctness + top_p=0.95, + do_sample=True, +) +``` + +### Chatbot / Instruction Following + +```python +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=256, + temperature=0.7, + top_p=0.9, + repetition_penalty=1.15, +) +``` + +### Factual QA / Information Extraction + +```python +outputs = model.generate( + **inputs, + max_new_tokens=50, + num_beams=3, + early_stopping=True, + # Or greedy for very short answers: + # (no special parameters needed) +) +``` + +## Debugging Generation + +### Check Token Probabilities + +```python +outputs = model.generate( + **inputs, + max_new_tokens=20, + output_scores=True, # Return generation scores + return_dict_in_generate=True, # Return as dict +) + +# Access generation scores +scores = outputs.scores # Tuple of tensors (seq_len, vocab_size) + +# Get token probabilities +import torch +probs = torch.softmax(scores[0], dim=-1) +``` + +### Monitor Generation Process + +```python +from transformers import LogitsProcessor, LogitsProcessorList + +class DebugLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids, scores): + # Print top 5 tokens at each step + top_tokens = scores[0].topk(5) + print(f"Top 5 tokens: {top_tokens}") + return scores + +outputs = model.generate( + **inputs, + max_new_tokens=10, + logits_processor=LogitsProcessorList([DebugLogitsProcessor()]), +) +``` + +## Common Issues and Solutions + +**Issue: Repetitive output** +- Solution: Increase `repetition_penalty` (1.2-1.5), set `no_repeat_ngram_size=3` +- For sampling: Increase `temperature`, enable `top_p` + +**Issue: Incoherent output** +- Solution: Lower `temperature` (0.5-0.8), use beam search +- Set `top_k=50` or 
`top_p=0.9` to filter unlikely tokens + +**Issue: Too short output** +- Solution: Increase `min_length`, set `length_penalty > 1.0` (beam search) +- Check if EOS token is being generated early + +**Issue: Too slow generation** +- Solution: Use greedy instead of beam search +- Reduce `num_beams` +- Try speculative decoding with assistant model +- Use smaller model variant + +**Issue: Output doesn't follow format** +- Solution: Use constrained beam search +- Add format examples to prompt +- Use `bad_words_ids` to prevent format-breaking tokens + +## Performance Optimization + +```python +# Use half precision +model = AutoModelForCausalLM.from_pretrained( + "model-name", + torch_dtype=torch.float16, + device_map="auto" +) + +# Use KV cache optimization (default, but can be disabled) +outputs = model.generate(**inputs, use_cache=True) + +# Batch generation +inputs = tokenizer(["Prompt 1", "Prompt 2"], return_tensors="pt", padding=True) +outputs = model.generate(**inputs, max_new_tokens=50) + +# Static cache for longer sequences (if supported) +outputs = model.generate(**inputs, cache_implementation="static") +``` + +This guide covers the main generation strategies. For task-specific examples, see `task_patterns.md`. diff --git a/scientific-packages/transformers/references/quantization.md b/scientific-packages/transformers/references/quantization.md new file mode 100644 index 0000000..a6a3fc4 --- /dev/null +++ b/scientific-packages/transformers/references/quantization.md @@ -0,0 +1,504 @@ +# Model Quantization Guide + +Comprehensive guide to reducing model memory footprint through quantization while maintaining accuracy. + +## Overview + +Quantization reduces memory requirements by storing model weights in lower precision formats (int8, int4) instead of full precision (float32). This enables: +- Running larger models on limited hardware +- Faster inference (reduced memory bandwidth) +- Lower deployment costs +- Enabling fine-tuning of models that wouldn't fit in memory + +**Tradeoffs:** +- Slight accuracy loss (typically < 1-2%) +- Initial quantization overhead +- Some methods require calibration data + +## Quick Comparison + +| Method | Precision | Speed | Accuracy | Fine-tuning | Hardware | Setup | +|--------|-----------|-------|----------|-------------|----------|-------| +| **Bitsandbytes** | 4/8-bit | Fast | High | Yes (PEFT) | CUDA, CPU | Easy | +| **GPTQ** | 2-8-bit | Very Fast | High | Limited | CUDA, ROCm, Metal | Medium | +| **AWQ** | 4-bit | Very Fast | High | Yes (PEFT) | CUDA, ROCm | Medium | +| **GGUF** | 1-8-bit | Medium | Variable | No | CPU-optimized | Easy | +| **HQQ** | 1-8-bit | Fast | High | Yes | Multi-platform | Medium | + +## Bitsandbytes (BnB) + +On-the-fly quantization with excellent PEFT fine-tuning support. 
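+
+A quick way to confirm the savings from any configuration in this section is the built-in memory footprint helper; a minimal sketch (requires bitsandbytes, accelerate, and a CUDA GPU; "gpt2" stands in for any causal LM checkpoint):
+
+```python
+from transformers import AutoModelForCausalLM
+
+fp32_model = AutoModelForCausalLM.from_pretrained("gpt2")
+int8_model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True, device_map="auto")
+
+# get_memory_footprint() returns bytes
+print(f"fp32: {fp32_model.get_memory_footprint() / 1e6:.0f} MB")
+print(f"int8: {int8_model.get_memory_footprint() / 1e6:.0f} MB")
+```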
+ +### 8-bit Quantization + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + load_in_8bit=True, # Enable 8-bit quantization + device_map="auto", # Automatic device placement +) + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + +# Use normally +inputs = tokenizer("Hello, how are you?", return_tensors="pt").to("cuda") +outputs = model.generate(**inputs, max_new_tokens=50) +``` + +**Memory Savings:** +- 7B model: ~14GB → ~7GB (50% reduction) +- 13B model: ~26GB → ~13GB +- 70B model: ~140GB → ~70GB + +**Characteristics:** +- Fast inference +- Minimal accuracy loss +- Works with PEFT (LoRA, QLoRA) +- Supports CPU and CUDA GPUs + +### 4-bit Quantization (QLoRA) + +```python +from transformers import AutoModelForCausalLM, BitsAndBytesConfig +import torch + +# Configure 4-bit quantization +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, # Enable 4-bit quantization + bnb_4bit_quant_type="nf4", # Quantization type ("nf4" or "fp4") + bnb_4bit_compute_dtype=torch.float16, # Computation dtype + bnb_4bit_use_double_quant=True, # Nested quantization for more savings +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=bnb_config, + device_map="auto", +) +``` + +**Memory Savings:** +- 7B model: ~14GB → ~4GB (70% reduction) +- 13B model: ~26GB → ~7GB +- 70B model: ~140GB → ~35GB + +**Quantization Types:** +- `nf4`: Normal Float 4 (recommended, better quality) +- `fp4`: Float Point 4 (slightly more memory efficient) + +**Compute Dtype:** +```python +# For better quality +bnb_4bit_compute_dtype=torch.float16 + +# For best performance on Ampere+ GPUs +bnb_4bit_compute_dtype=torch.bfloat16 +``` + +**Double Quantization:** +```python +# Enable for additional ~0.4 bits/param savings +bnb_4bit_use_double_quant=True # Quantize the quantization constants +``` + +### Fine-tuning with QLoRA + +```python +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import torch + +# Load quantized model +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=bnb_config, + device_map="auto", +) + +# Prepare for training +model = prepare_model_for_kbit_training(model) + +# Configure LoRA +lora_config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) + +# Train normally +trainer = Trainer(model=model, args=training_args, ...) +trainer.train() +``` + +## GPTQ + +Post-training quantization requiring calibration, optimized for inference speed. 
+ +### Loading GPTQ Models + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig + +# Load pre-quantized GPTQ model +model = AutoModelForCausalLM.from_pretrained( + "TheBloke/Llama-2-7B-GPTQ", # Pre-quantized model + device_map="auto", + revision="gptq-4bit-32g-actorder_True", # Specific quantization config +) + +# Or quantize yourself +gptq_config = GPTQConfig( + bits=4, # 2, 3, 4, 8 bits + dataset="c4", # Calibration dataset + tokenizer=tokenizer, +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + device_map="auto", + quantization_config=gptq_config, +) + +# Save quantized model +model.save_pretrained("llama-2-7b-gptq") +``` + +**Configuration Options:** +```python +gptq_config = GPTQConfig( + bits=4, # Quantization bits + group_size=128, # Group size for quantization (128, 32, -1) + dataset="c4", # Calibration dataset + desc_act=False, # Activation order (can improve accuracy) + sym=True, # Symmetric quantization + damp_percent=0.1, # Dampening factor +) +``` + +**Characteristics:** +- Fastest inference among quantization methods +- Requires one-time calibration (slow) +- Best when using pre-quantized models from Hub +- Limited fine-tuning support +- Excellent for production deployment + +## AWQ (Activation-aware Weight Quantization) + +Protects important weights for better quality. + +### Loading AWQ Models + +```python +from transformers import AutoModelForCausalLM, AwqConfig + +# Load pre-quantized AWQ model +model = AutoModelForCausalLM.from_pretrained( + "TheBloke/Llama-2-7B-AWQ", + device_map="auto", +) + +# Or quantize yourself +awq_config = AwqConfig( + bits=4, # 4-bit quantization + group_size=128, # Quantization group size + zero_point=True, # Use zero-point quantization + version="GEMM", # Quantization version +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=awq_config, + device_map="auto", +) +``` + +**Characteristics:** +- Better accuracy than GPTQ at same bit width +- Excellent inference speed +- Supports PEFT fine-tuning +- Requires calibration data + +### Fine-tuning AWQ Models + +```python +from peft import LoraConfig, get_peft_model + +# AWQ models support LoRA fine-tuning +lora_config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) +trainer = Trainer(model=model, ...) +trainer.train() +``` + +## GGUF (GGML Format) + +CPU-optimized quantization format, popular in llama.cpp ecosystem. + +### Using GGUF Models + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load GGUF model +model = AutoModelForCausalLM.from_pretrained( + "TheBloke/Llama-2-7B-GGUF", + gguf_file="llama-2-7b.Q4_K_M.gguf", # Specific quantization file + device_map="auto", +) + +tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-GGUF") +``` + +**GGUF Quantization Types:** +- `Q4_0`: 4-bit, smallest, lowest quality +- `Q4_K_M`: 4-bit, medium quality (recommended) +- `Q5_K_M`: 5-bit, good quality +- `Q6_K`: 6-bit, high quality +- `Q8_0`: 8-bit, very high quality + +**Characteristics:** +- Optimized for CPU inference +- Wide range of bit depths (1-8) +- Good for Apple Silicon (M1/M2) +- No fine-tuning support +- Excellent for local/edge deployment + +## HQQ (Half-Quadratic Quantization) + +Flexible quantization with good accuracy retention. 
+ +### Using HQQ + +```python +from transformers import AutoModelForCausalLM, HqqConfig + +hqq_config = HqqConfig( + nbits=4, # Quantization bits + group_size=64, # Group size + quant_zero=False, # Quantize zero point + quant_scale=False, # Quantize scale + axis=0, # Quantization axis +) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=hqq_config, + device_map="auto", +) +``` + +**Characteristics:** +- Very fast quantization +- No calibration data needed +- Support for 1-8 bits +- Can serialize/deserialize +- Good accuracy vs size tradeoff + +## Choosing a Quantization Method + +### Decision Tree + +**For inference only:** +1. Need fastest inference? → **GPTQ or AWQ** (use pre-quantized models) +2. CPU-only deployment? → **GGUF** +3. Want easiest setup? → **Bitsandbytes 8-bit** +4. Need extreme compression? → **GGUF Q4_0 or HQQ 2-bit** + +**For fine-tuning:** +1. Limited VRAM? → **QLoRA (BnB 4-bit + LoRA)** +2. Want best accuracy? → **Bitsandbytes 8-bit + LoRA** +3. Need very large models? → **QLoRA with double quantization** + +**For production:** +1. Latency-critical? → **GPTQ or AWQ** +2. Cost-optimized? → **Bitsandbytes 8-bit** +3. CPU deployment? → **GGUF** + +## Memory Requirements + +Approximate memory for Llama-2 7B model: + +| Method | Memory | vs FP16 | +|--------|--------|---------| +| FP32 | 28GB | 2x | +| FP16 / BF16 | 14GB | 1x | +| 8-bit (BnB) | 7GB | 0.5x | +| 4-bit (QLoRA) | 3.5GB | 0.25x | +| 4-bit Double Quant | 3GB | 0.21x | +| GPTQ 4-bit | 4GB | 0.29x | +| AWQ 4-bit | 4GB | 0.29x | + +**Note:** Add ~1-2GB for inference activations, KV cache, and framework overhead. + +## Best Practices + +### For Training + +```python +# QLoRA recommended configuration +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, # BF16 if available + bnb_4bit_use_double_quant=True, +) + +# LoRA configuration +lora_config = LoraConfig( + r=16, # Rank (8, 16, 32, 64) + lora_alpha=32, # Scaling (typically 2*r) + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) +``` + +### For Inference + +```python +# High-speed inference +model = AutoModelForCausalLM.from_pretrained( + "TheBloke/Llama-2-7B-GPTQ", + device_map="auto", + torch_dtype=torch.float16, # Use FP16 for activations +) + +# Balanced quality/speed +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + load_in_8bit=True, + device_map="auto", +) + +# Maximum compression +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + ), + device_map="auto", +) +``` + +### Multi-GPU Setups + +```python +# Automatically distribute across GPUs +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b-hf", + load_in_4bit=True, + device_map="auto", # Automatic distribution + max_memory={0: "20GB", 1: "20GB"}, # Optional: limit per GPU +) + +# Manual device map +device_map = { + "model.embed_tokens": 0, + "model.layers.0": 0, + "model.layers.1": 0, + # ... distribute layers ... 
+ "model.norm": 1, + "lm_head": 1, +} + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b-hf", + load_in_4bit=True, + device_map=device_map, +) +``` + +## Troubleshooting + +**Issue: OOM during quantization** +```python +# Solution: Use low_cpu_mem_usage +model = AutoModelForCausalLM.from_pretrained( + "model-name", + quantization_config=config, + device_map="auto", + low_cpu_mem_usage=True, # Reduce CPU memory during loading +) +``` + +**Issue: Slow quantization** +```python +# GPTQ/AWQ take time to calibrate +# Solution: Use pre-quantized models from Hub +model = AutoModelForCausalLM.from_pretrained("TheBloke/Model-GPTQ") + +# Or use BnB for instant quantization +model = AutoModelForCausalLM.from_pretrained("model-name", load_in_4bit=True) +``` + +**Issue: Poor quality after quantization** +```python +# Try different quantization types +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", # Try "nf4" instead of "fp4" + bnb_4bit_compute_dtype=torch.bfloat16, # Use BF16 if available +) + +# Or use 8-bit instead of 4-bit +model = AutoModelForCausalLM.from_pretrained("model-name", load_in_8bit=True) +``` + +**Issue: Can't fine-tune quantized model** +```python +# Ensure using compatible quantization method +from peft import prepare_model_for_kbit_training + +model = prepare_model_for_kbit_training(model) + +# Only BnB and AWQ support PEFT fine-tuning +# GPTQ has limited support, GGUF doesn't support fine-tuning +``` + +## Performance Benchmarks + +Approximate generation speed (tokens/sec) for Llama-2 7B on A100 40GB: + +| Method | Speed | Memory | +|--------|-------|--------| +| FP16 | 100 tok/s | 14GB | +| 8-bit | 90 tok/s | 7GB | +| 4-bit QLoRA | 70 tok/s | 4GB | +| GPTQ 4-bit | 95 tok/s | 4GB | +| AWQ 4-bit | 95 tok/s | 4GB | + +**Note:** Actual performance varies by hardware, sequence length, and batch size. + +## Resources + +- **Pre-quantized models:** Search "GPTQ" or "AWQ" on Hugging Face Hub +- **BnB documentation:** https://github.com/TimDettmers/bitsandbytes +- **PEFT library:** https://github.com/huggingface/peft +- **QLoRA paper:** https://arxiv.org/abs/2305.14314 + +For task-specific quantization examples, see `training_guide.md`. diff --git a/scientific-packages/transformers/references/task_patterns.md b/scientific-packages/transformers/references/task_patterns.md new file mode 100644 index 0000000..3ebf89a --- /dev/null +++ b/scientific-packages/transformers/references/task_patterns.md @@ -0,0 +1,610 @@ +# Task-Specific Patterns + +Quick reference for implementing common tasks with Transformers. Each pattern includes the complete workflow from data loading to inference. + +## Text Classification + +Classify text into predefined categories (sentiment, topic, intent, etc.). + +```python +from transformers import ( + AutoTokenizer, AutoModelForSequenceClassification, + TrainingArguments, Trainer, DataCollatorWithPadding +) +from datasets import load_dataset + +# 1. Load data +dataset = load_dataset("imdb") + +# 2. Preprocess +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +def preprocess(examples): + return tokenizer(examples["text"], truncation=True, max_length=512) + +tokenized = dataset.map(preprocess, batched=True) + +# 3. Model +model = AutoModelForSequenceClassification.from_pretrained( + "bert-base-uncased", + num_labels=2, + id2label={0: "negative", 1: "positive"}, + label2id={"negative": 0, "positive": 1} +) + +# 4. 
Train +training_args = TrainingArguments( + output_dir="./results", + learning_rate=2e-5, + per_device_train_batch_size=16, + num_train_epochs=3, + eval_strategy="epoch", +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized["train"], + eval_dataset=tokenized["test"], + tokenizer=tokenizer, + data_collator=DataCollatorWithPadding(tokenizer=tokenizer), +) + +trainer.train() + +# 5. Inference +text = "This movie was fantastic!" +inputs = tokenizer(text, return_tensors="pt") +outputs = model(**inputs) +predictions = outputs.logits.argmax(-1) +print(model.config.id2label[predictions.item()]) # "positive" +``` + +## Token Classification (NER) + +Label each token in text (named entities, POS tags, etc.). + +```python +from transformers import AutoTokenizer, AutoModelForTokenClassification +from datasets import load_dataset + +# Load data (tokens and NER tags) +dataset = load_dataset("conll2003") + +tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + +def tokenize_and_align_labels(examples): + tokenized_inputs = tokenizer( + examples["tokens"], + truncation=True, + is_split_into_words=True + ) + + labels = [] + for i, label in enumerate(examples["ner_tags"]): + word_ids = tokenized_inputs.word_ids(batch_index=i) + label_ids = [] + previous_word_idx = None + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) # Special tokens + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + else: + label_ids.append(-100) # Subword tokens + previous_word_idx = word_idx + labels.append(label_ids) + + tokenized_inputs["labels"] = labels + return tokenized_inputs + +tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True) + +# Model +label_list = dataset["train"].features["ner_tags"].feature.names +model = AutoModelForTokenClassification.from_pretrained( + "bert-base-cased", + num_labels=len(label_list), + id2label={i: label for i, label in enumerate(label_list)}, + label2id={label: i for i, label in enumerate(label_list)} +) + +# Training similar to classification +# ... (use Trainer with DataCollatorForTokenClassification) +``` + +## Question Answering (Extractive) + +Extract answer spans from context. + +```python +from transformers import AutoTokenizer, AutoModelForQuestionAnswering + +tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad") +model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad") + +question = "What is the capital of France?" +context = "Paris is the capital and most populous city of France." + +inputs = tokenizer(question, context, return_tensors="pt") +outputs = model(**inputs) + +# Get answer span +answer_start = outputs.start_logits.argmax() +answer_end = outputs.end_logits.argmax() + 1 +answer = tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end]) +print(answer) # "Paris" +``` + +## Text Generation + +Generate text continuations. 
+ +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +prompt = "In the future, artificial intelligence will" +inputs = tokenizer(prompt, return_tensors="pt") + +outputs = model.generate( + **inputs, + max_new_tokens=100, + do_sample=True, + temperature=0.8, + top_p=0.95, + repetition_penalty=1.2, +) + +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) +``` + +## Summarization + +Condense long text into summaries. + +```python +from transformers import ( + AutoTokenizer, AutoModelForSeq2SeqLM, + Seq2SeqTrainingArguments, Seq2SeqTrainer, + DataCollatorForSeq2Seq +) + +tokenizer = AutoTokenizer.from_pretrained("t5-small") +model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") + +def preprocess(examples): + inputs = ["summarize: " + doc for doc in examples["document"]] + model_inputs = tokenizer(inputs, max_length=1024, truncation=True) + + labels = tokenizer( + examples["summary"], + max_length=128, + truncation=True + ) + model_inputs["labels"] = labels["input_ids"] + return model_inputs + +tokenized_dataset = dataset.map(preprocess, batched=True) + +# Training +training_args = Seq2SeqTrainingArguments( + output_dir="./results", + predict_with_generate=True, # Important for seq2seq + eval_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=8, + num_train_epochs=3, +) + +trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"], + tokenizer=tokenizer, + data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer), +) + +trainer.train() + +# Inference +text = "Long article text here..." +inputs = tokenizer("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True) +outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True) +summary = tokenizer.decode(outputs[0], skip_special_tokens=True) +``` + +## Translation + +Translate text between languages. + +```python +from transformers import pipeline + +translator = pipeline("translation_en_to_fr", model="Helsinki-NLP/opus-mt-en-fr") +result = translator("Hello, how are you?") +print(result[0]["translation_text"]) # "Bonjour, comment allez-vous?" + +# For fine-tuning, similar to summarization with Seq2SeqTrainer +``` + +## Image Classification + +Classify images into categories. 
+ +```python +from transformers import ( + AutoImageProcessor, AutoModelForImageClassification, + TrainingArguments, Trainer +) +from datasets import load_dataset +from PIL import Image + +# Load data +dataset = load_dataset("food101", split="train[:1000]") + +# Preprocess +processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + +def transform(examples): + examples["pixel_values"] = [ + processor(img.convert("RGB"), return_tensors="pt")["pixel_values"][0] + for img in examples["image"] + ] + return examples + +dataset = dataset.with_transform(transform) + +# Model +model = AutoModelForImageClassification.from_pretrained( + "google/vit-base-patch16-224", + num_labels=101, + ignore_mismatched_sizes=True +) + +# Training +training_args = TrainingArguments( + output_dir="./results", + remove_unused_columns=False, # Keep image data + eval_strategy="epoch", + learning_rate=5e-5, + per_device_train_batch_size=32, + num_train_epochs=3, +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=processor, +) + +trainer.train() + +# Inference +image = Image.open("food.jpg") +inputs = processor(image, return_tensors="pt") +outputs = model(**inputs) +predicted_class = outputs.logits.argmax(-1).item() +print(model.config.id2label[predicted_class]) +``` + +## Object Detection + +Detect and localize objects in images. + +```python +from transformers import pipeline +from PIL import Image + +detector = pipeline("object-detection", model="facebook/detr-resnet-50") + +image = Image.open("street.jpg") +results = detector(image) + +for result in results: + print(f"{result['label']}: {result['score']:.2f} at {result['box']}") + # car: 0.98 at {'xmin': 123, 'ymin': 456, 'xmax': 789, 'ymax': 1011} +``` + +## Image Segmentation + +Segment images into regions. + +```python +from transformers import pipeline + +segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic") + +image = "path/to/image.jpg" +segments = segmenter(image) + +for segment in segments: + print(f"{segment['label']}: {segment['score']:.2f}") + # Access mask: segment['mask'] +``` + +## Image Captioning + +Generate textual descriptions of images. + +```python +from transformers import AutoProcessor, AutoModelForCausalLM +from PIL import Image + +processor = AutoProcessor.from_pretrained("microsoft/git-base") +model = AutoModelForCausalLM.from_pretrained("microsoft/git-base") + +image = Image.open("photo.jpg") +inputs = processor(images=image, return_tensors="pt") + +outputs = model.generate(**inputs, max_length=50) +caption = processor.batch_decode(outputs, skip_special_tokens=True)[0] +print(caption) # "a dog sitting on grass" +``` + +## Speech Recognition (ASR) + +Transcribe speech to text. + +```python +from transformers import pipeline + +transcriber = pipeline( + "automatic-speech-recognition", + model="openai/whisper-base" +) + +result = transcriber("audio.mp3") +print(result["text"]) # "Hello, this is a test." + +# With timestamps +result = transcriber("audio.mp3", return_timestamps=True) +for chunk in result["chunks"]: + print(f"[{chunk['timestamp'][0]:.1f}s - {chunk['timestamp'][1]:.1f}s]: {chunk['text']}") +``` + +## Text-to-Speech + +Generate speech from text. 
+ +```python +from transformers import pipeline + +synthesizer = pipeline("text-to-speech", model="microsoft/speecht5_tts") + +result = synthesizer("Hello, how are you today?") +# result["audio"] contains the waveform +# result["sampling_rate"] contains the sample rate + +# Save audio +import scipy +scipy.io.wavfile.write("output.wav", result["sampling_rate"], result["audio"][0]) +``` + +## Visual Question Answering + +Answer questions about images. + +```python +from transformers import pipeline +from PIL import Image + +vqa = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa") + +image = Image.open("photo.jpg") +question = "What color is the car?" + +result = vqa(image=image, question=question) +print(result[0]["answer"]) # "red" +``` + +## Document Question Answering + +Extract information from documents (PDFs, images with text). + +```python +from transformers import pipeline + +doc_qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa") + +result = doc_qa( + image="invoice.png", + question="What is the total amount?" +) + +print(result["answer"]) # "$1,234.56" +``` + +## Zero-Shot Classification + +Classify without training data. + +```python +from transformers import pipeline + +classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + +text = "This is a delicious Italian restaurant with great pasta." +candidate_labels = ["food", "travel", "technology", "sports"] + +result = classifier(text, candidate_labels) +print(result["labels"][0]) # "food" +print(result["scores"][0]) # 0.95 +``` + +## Few-Shot Learning with LLMs + +Use large language models for few-shot tasks. + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + +# Few-shot prompt +prompt = """ +Classify the sentiment: positive, negative, or neutral. + +Text: "I love this product!" +Sentiment: positive + +Text: "This is terrible." +Sentiment: negative + +Text: "It's okay, nothing special." +Sentiment: neutral + +Text: "Best purchase ever!" +Sentiment:""" + +inputs = tokenizer(prompt, return_tensors="pt") +outputs = model.generate(**inputs, max_new_tokens=5, temperature=0.1) +response = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(response.split("Sentiment:")[-1].strip()) # "positive" +``` + +## Instruction-Following / Chat + +Use instruction-tuned models. + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is machine learning?"}, +] + +formatted = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) + +inputs = tokenizer(formatted, return_tensors="pt") +outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7) +response = tokenizer.decode(outputs[0], skip_special_tokens=True) + +# Extract assistant response +assistant_response = response.split("[/INST]")[-1].strip() +print(assistant_response) +``` + +## Embeddings / Semantic Search + +Generate embeddings for semantic similarity. 
+ +```python +from transformers import AutoTokenizer, AutoModel +import torch + +tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") +model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") + +def get_embedding(text): + inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True) + with torch.no_grad(): + outputs = model(**inputs) + + # Mean pooling + embeddings = outputs.last_hidden_state.mean(dim=1) + return embeddings + +# Get embeddings +text1 = "Machine learning is a subset of AI" +text2 = "AI includes machine learning" + +emb1 = get_embedding(text1) +emb2 = get_embedding(text2) + +# Compute similarity +similarity = torch.nn.functional.cosine_similarity(emb1, emb2) +print(f"Similarity: {similarity.item():.4f}") # ~0.85 +``` + +## Multimodal Understanding (CLIP) + +Connect vision and language. + +```python +from transformers import CLIPProcessor, CLIPModel +from PIL import Image + +model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") +processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + +image = Image.open("photo.jpg") +texts = ["a dog", "a cat", "a car", "a house"] + +inputs = processor(text=texts, images=image, return_tensors="pt", padding=True) +outputs = model(**inputs) + +# Get similarity scores +logits_per_image = outputs.logits_per_image +probs = logits_per_image.softmax(dim=1) + +for text, prob in zip(texts, probs[0]): + print(f"{text}: {prob.item():.4f}") +``` + +## Common Evaluation Metrics + +```python +from datasets import load_metric + +# Accuracy (classification) +metric = load_metric("accuracy") +predictions = [0, 1, 1, 0] +references = [0, 1, 0, 0] +result = metric.compute(predictions=predictions, references=references) + +# F1 Score (classification, NER) +metric = load_metric("f1") +result = metric.compute(predictions=predictions, references=references) + +# BLEU (translation) +metric = load_metric("bleu") +predictions = ["hello there general kenobi"] +references = [["hello there general kenobi", "hello there!"]] +result = metric.compute(predictions=predictions, references=references) + +# ROUGE (summarization) +metric = load_metric("rouge") +predictions = ["summary text"] +references = ["reference summary"] +result = metric.compute(predictions=predictions, references=references) +``` + +## Common Data Collators + +```python +from transformers import ( + DataCollatorWithPadding, + DataCollatorForTokenClassification, + DataCollatorForSeq2Seq, + DataCollatorForLanguageModeling, +) + +# Classification: dynamic padding +DataCollatorWithPadding(tokenizer=tokenizer) + +# NER: pad labels too +DataCollatorForTokenClassification(tokenizer=tokenizer) + +# Seq2Seq: pad inputs and labels +DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) + +# Language modeling: create MLM masks +DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15) +``` + +This covers the most common task patterns. For detailed parameter tuning, see `api_reference.md` and `generation_strategies.md`. diff --git a/scientific-packages/transformers/scripts/fine_tune_classifier.py b/scientific-packages/transformers/scripts/fine_tune_classifier.py new file mode 100755 index 0000000..6eba340 --- /dev/null +++ b/scientific-packages/transformers/scripts/fine_tune_classifier.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +""" +Complete example for fine-tuning a text classification model. + +This script demonstrates the full workflow: +1. Load dataset +2. Preprocess with tokenizer +3. 
Configure model +4. Train with Trainer +5. Evaluate and save + +Usage: + python fine_tune_classifier.py --model bert-base-uncased --dataset imdb --epochs 3 +""" + +import argparse +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + TrainingArguments, + Trainer, + DataCollatorWithPadding, +) +import evaluate +import numpy as np + + +def compute_metrics(eval_pred): + """Compute accuracy and F1 score.""" + metric_accuracy = evaluate.load("accuracy") + metric_f1 = evaluate.load("f1") + + logits, labels = eval_pred + predictions = np.argmax(logits, axis=-1) + + accuracy = metric_accuracy.compute(predictions=predictions, references=labels) + f1 = metric_f1.compute(predictions=predictions, references=labels) + + return {"accuracy": accuracy["accuracy"], "f1": f1["f1"]} + + +def main(): + parser = argparse.ArgumentParser(description="Fine-tune a text classification model") + parser.add_argument( + "--model", + type=str, + default="bert-base-uncased", + help="Pretrained model name or path", + ) + parser.add_argument( + "--dataset", + type=str, + default="imdb", + help="Dataset name from Hugging Face Hub", + ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Maximum samples to use (for quick testing)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./results", + help="Output directory for checkpoints", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Batch size per device", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=2e-5, + help="Learning rate", + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push model to Hugging Face Hub after training", + ) + + args = parser.parse_args() + + print("=" * 60) + print("Text Classification Fine-Tuning") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Dataset: {args.dataset}") + print(f"Epochs: {args.epochs}") + print(f"Batch size: {args.batch_size}") + print(f"Learning rate: {args.learning_rate}") + print("=" * 60) + + # 1. Load dataset + print("\n[1/5] Loading dataset...") + dataset = load_dataset(args.dataset) + + if args.max_samples: + dataset["train"] = dataset["train"].select(range(args.max_samples)) + dataset["test"] = dataset["test"].select(range(args.max_samples // 5)) + + print(f"Train samples: {len(dataset['train'])}") + print(f"Test samples: {len(dataset['test'])}") + + # 2. Preprocess + print("\n[2/5] Preprocessing data...") + tokenizer = AutoTokenizer.from_pretrained(args.model) + + def preprocess_function(examples): + return tokenizer(examples["text"], truncation=True, max_length=512) + + tokenized_dataset = dataset.map(preprocess_function, batched=True) + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + # 3. Load model + print("\n[3/5] Loading model...") + + # Determine number of labels + num_labels = len(set(dataset["train"]["label"])) + + model = AutoModelForSequenceClassification.from_pretrained( + args.model, + num_labels=num_labels, + ) + + print(f"Number of labels: {num_labels}") + print(f"Model parameters: {model.num_parameters():,}") + + # 4. 
Configure training + print("\n[4/5] Configuring training...") + training_args = TrainingArguments( + output_dir=args.output_dir, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + num_train_epochs=args.epochs, + weight_decay=0.01, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + push_to_hub=args.push_to_hub, + logging_steps=100, + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["test"], + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, + ) + + # 5. Train + print("\n[5/5] Training...") + print("-" * 60) + trainer.train() + + # Evaluate + print("\n" + "=" * 60) + print("Final Evaluation") + print("=" * 60) + metrics = trainer.evaluate() + + print(f"Accuracy: {metrics['eval_accuracy']:.4f}") + print(f"F1 Score: {metrics['eval_f1']:.4f}") + print(f"Loss: {metrics['eval_loss']:.4f}") + + # Save + print("\n" + "=" * 60) + print(f"Saving model to {args.output_dir}") + trainer.save_model(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + if args.push_to_hub: + print("Pushing to Hugging Face Hub...") + trainer.push_to_hub() + + print("=" * 60) + print("Training complete!") + print("=" * 60) + + # Quick inference example + print("\nQuick inference example:") + from transformers import pipeline + + classifier = pipeline( + "text-classification", + model=args.output_dir, + tokenizer=args.output_dir, + ) + + example_text = "This is a great example of how to use transformers!" + result = classifier(example_text) + print(f"Text: {example_text}") + print(f"Prediction: {result[0]['label']} (score: {result[0]['score']:.4f})") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/transformers/scripts/generate_text.py b/scientific-packages/transformers/scripts/generate_text.py new file mode 100755 index 0000000..f813a9c --- /dev/null +++ b/scientific-packages/transformers/scripts/generate_text.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +Text generation with various strategies. 
+ +This script demonstrates different generation strategies: +- Greedy decoding +- Beam search +- Sampling with temperature +- Top-k and top-p sampling + +Usage: + python generate_text.py --model gpt2 --prompt "The future of AI" --strategy sampling +""" + +import argparse +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def generate_with_greedy(model, tokenizer, prompt, max_length): + """Greedy decoding (deterministic).""" + print("\n" + "=" * 60) + print("GREEDY DECODING") + print("=" * 60) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=max_length, + pad_token_id=tokenizer.eos_token_id, + ) + + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt}") + print(f"\nGenerated:\n{text}") + + +def generate_with_beam_search(model, tokenizer, prompt, max_length, num_beams=5): + """Beam search for higher quality.""" + print("\n" + "=" * 60) + print(f"BEAM SEARCH (num_beams={num_beams})") + print("=" * 60) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=max_length, + num_beams=num_beams, + early_stopping=True, + no_repeat_ngram_size=2, + pad_token_id=tokenizer.eos_token_id, + ) + + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt}") + print(f"\nGenerated:\n{text}") + + +def generate_with_sampling(model, tokenizer, prompt, max_length, temperature=0.8): + """Sampling with temperature.""" + print("\n" + "=" * 60) + print(f"SAMPLING (temperature={temperature})") + print("=" * 60) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=max_length, + do_sample=True, + temperature=temperature, + pad_token_id=tokenizer.eos_token_id, + ) + + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt}") + print(f"\nGenerated:\n{text}") + + +def generate_with_top_k_top_p(model, tokenizer, prompt, max_length, top_k=50, top_p=0.95, temperature=0.8): + """Top-k and top-p (nucleus) sampling.""" + print("\n" + "=" * 60) + print(f"TOP-K TOP-P SAMPLING (k={top_k}, p={top_p}, temp={temperature})") + print("=" * 60) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=max_length, + do_sample=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=1.2, + no_repeat_ngram_size=3, + pad_token_id=tokenizer.eos_token_id, + ) + + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt}") + print(f"\nGenerated:\n{text}") + + +def generate_multiple(model, tokenizer, prompt, max_length, num_sequences=3): + """Generate multiple diverse sequences.""" + print("\n" + "=" * 60) + print(f"MULTIPLE SEQUENCES (n={num_sequences})") + print("=" * 60) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + outputs = model.generate( + **inputs, + max_new_tokens=max_length, + do_sample=True, + num_return_sequences=num_sequences, + temperature=0.9, + top_p=0.95, + pad_token_id=tokenizer.eos_token_id, + ) + + print(f"\nPrompt: {prompt}\n") + for i, output in enumerate(outputs, 1): + text = tokenizer.decode(output, skip_special_tokens=True) + print(f"\n--- Sequence {i} ---\n{text}\n") + + +def main(): + parser = argparse.ArgumentParser(description="Text generation with various strategies") + parser.add_argument( + "--model", + 
type=str, + default="gpt2", + help="Model name or path", + ) + parser.add_argument( + "--prompt", + type=str, + required=True, + help="Input prompt for generation", + ) + parser.add_argument( + "--strategy", + type=str, + default="all", + choices=["greedy", "beam", "sampling", "top_k_top_p", "multiple", "all"], + help="Generation strategy to use", + ) + parser.add_argument( + "--max-length", + type=int, + default=100, + help="Maximum number of new tokens to generate", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device (cuda, cpu, or auto)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature", + ) + parser.add_argument( + "--quantize", + action="store_true", + help="Use 8-bit quantization", + ) + + args = parser.parse_args() + + print("=" * 60) + print("Text Generation Demo") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Strategy: {args.strategy}") + print(f"Max length: {args.max_length}") + print(f"Device: {args.device}") + print("=" * 60) + + # Load model and tokenizer + print("\nLoading model...") + + if args.device == "auto": + device_map = "auto" + device = None + else: + device_map = None + device = args.device + + model_kwargs = {"device_map": device_map} if device_map else {} + + if args.quantize: + print("Using 8-bit quantization...") + model_kwargs["load_in_8bit"] = True + + model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(args.model) + + if device and not device_map: + model = model.to(device) + + print(f"Model loaded on: {model.device if hasattr(model, 'device') else 'multiple devices'}") + + # Generate based on strategy + strategies = { + "greedy": lambda: generate_with_greedy(model, tokenizer, args.prompt, args.max_length), + "beam": lambda: generate_with_beam_search(model, tokenizer, args.prompt, args.max_length), + "sampling": lambda: generate_with_sampling(model, tokenizer, args.prompt, args.max_length, args.temperature), + "top_k_top_p": lambda: generate_with_top_k_top_p(model, tokenizer, args.prompt, args.max_length), + "multiple": lambda: generate_multiple(model, tokenizer, args.prompt, args.max_length), + } + + if args.strategy == "all": + for strategy_fn in strategies.values(): + strategy_fn() + else: + strategies[args.strategy]() + + print("\n" + "=" * 60) + print("Generation complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/transformers/scripts/quick_inference.py b/scientific-packages/transformers/scripts/quick_inference.py new file mode 100755 index 0000000..8f931f5 --- /dev/null +++ b/scientific-packages/transformers/scripts/quick_inference.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +""" +Quick inference script using Transformers pipelines. + +This script demonstrates how to use various pipeline tasks for quick inference +without manually managing models, tokenizers, or preprocessing. + +Usage: + python quick_inference.py --task text-generation --model gpt2 --input "Hello world" + python quick_inference.py --task sentiment-analysis --input "I love this!" 
+""" + +import argparse +from transformers import pipeline, infer_device + + +def main(): + parser = argparse.ArgumentParser(description="Quick inference with Transformers pipelines") + parser.add_argument( + "--task", + type=str, + required=True, + help="Pipeline task (text-generation, sentiment-analysis, question-answering, etc.)", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model name or path (default: use task default)", + ) + parser.add_argument( + "--input", + type=str, + required=True, + help="Input text for inference", + ) + parser.add_argument( + "--context", + type=str, + default=None, + help="Context for question-answering tasks", + ) + parser.add_argument( + "--max-length", + type=int, + default=50, + help="Maximum generation length", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="Device (cuda, cpu, or auto-detect)", + ) + + args = parser.parse_args() + + # Auto-detect device if not specified + if args.device is None: + device = infer_device() + else: + device = args.device + + print(f"Using device: {device}") + print(f"Task: {args.task}") + print(f"Model: {args.model or 'default'}") + print("-" * 50) + + # Create pipeline + pipe = pipeline( + args.task, + model=args.model, + device=device, + ) + + # Run inference based on task + if args.task == "question-answering": + if args.context is None: + print("Error: --context required for question-answering") + return + result = pipe(question=args.input, context=args.context) + print(f"Question: {args.input}") + print(f"Context: {args.context}") + print(f"\nAnswer: {result['answer']}") + print(f"Score: {result['score']:.4f}") + + elif args.task == "text-generation": + result = pipe(args.input, max_length=args.max_length) + print(f"Prompt: {args.input}") + print(f"\nGenerated: {result[0]['generated_text']}") + + elif args.task in ["sentiment-analysis", "text-classification"]: + result = pipe(args.input) + print(f"Text: {args.input}") + print(f"\nLabel: {result[0]['label']}") + print(f"Score: {result[0]['score']:.4f}") + + else: + # Generic handling for other tasks + result = pipe(args.input) + print(f"Input: {args.input}") + print(f"\nResult: {result}") + + +if __name__ == "__main__": + main() diff --git a/scientific-packages/umap-learn/SKILL.md b/scientific-packages/umap-learn/SKILL.md new file mode 100644 index 0000000..08daf9c --- /dev/null +++ b/scientific-packages/umap-learn/SKILL.md @@ -0,0 +1,485 @@ +--- +name: umap-learn +description: Guide for using UMAP (Uniform Manifold Approximation and Projection) for dimensionality reduction, visualization, and clustering. Use this skill when working with high-dimensional data that needs to be reduced for visualization, machine learning pipelines, or clustering tasks. Triggers include requests for dimensionality reduction, manifold learning, data visualization in 2D/3D, UMAP-based clustering, or supervised feature engineering. +--- + +# UMAP-Learn + +## Overview + +UMAP (Uniform Manifold Approximation and Projection) is a dimensionality reduction technique designed for both visualization and general non-linear dimensionality reduction. It is faster than t-SNE while producing comparable or superior results, and uniquely scales well to higher embedding dimensions (beyond 2D/3D). UMAP preserves both local and global structure in data and supports supervised learning, making it versatile for visualization, clustering preprocessing, and feature engineering. 
+ +**Key capabilities:** +- Fast, scalable dimensionality reduction for visualization +- Supervised and semi-supervised learning with label information +- Effective preprocessing for density-based clustering (HDBSCAN) +- Transform new data using trained models +- Parametric embeddings via neural networks +- Inverse transforms for data reconstruction + +## Quick Start + +### Installation + +```bash +# Via conda +conda install -c conda-forge umap-learn + +# Via pip +pip install umap-learn +``` + +### Basic Usage + +UMAP follows scikit-learn conventions and can be used as a drop-in replacement for t-SNE or PCA. + +```python +import umap +from sklearn.preprocessing import StandardScaler + +# Prepare data (standardization is essential) +scaled_data = StandardScaler().fit_transform(data) + +# Method 1: Single step (fit and transform) +embedding = umap.UMAP().fit_transform(scaled_data) + +# Method 2: Separate steps (for reusing trained model) +reducer = umap.UMAP(random_state=42) +reducer.fit(scaled_data) +embedding = reducer.embedding_ # Access the trained embedding +``` + +**Critical preprocessing requirement:** Always standardize features to comparable scales before applying UMAP to ensure equal weighting across dimensions. + +### Typical Workflow + +```python +import umap +import matplotlib.pyplot as plt +from sklearn.preprocessing import StandardScaler + +# 1. Preprocess data +scaler = StandardScaler() +scaled_data = scaler.fit_transform(raw_data) + +# 2. Create and fit UMAP +reducer = umap.UMAP( + n_neighbors=15, + min_dist=0.1, + n_components=2, + metric='euclidean', + random_state=42 +) +embedding = reducer.fit_transform(scaled_data) + +# 3. Visualize +plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='Spectral', s=5) +plt.colorbar() +plt.title('UMAP Embedding') +plt.show() +``` + +## Parameter Tuning Guide + +UMAP has four primary parameters that control the embedding behavior. Understanding these is crucial for effective usage. + +### n_neighbors (default: 15) + +**Purpose:** Balances local versus global structure in the embedding. + +**How it works:** Controls the size of the local neighborhood UMAP examines when learning manifold structure. + +**Effects by value:** +- **Low values (2-5):** Emphasizes fine local detail but may fragment data into disconnected components +- **Medium values (15-20):** Balanced view of both local structure and global relationships (recommended starting point) +- **High values (50-200):** Prioritizes broad topological structure at the expense of fine-grained details + +**Recommendation:** Start with 15 and adjust based on results. Increase for more global structure, decrease for more local detail. + +### min_dist (default: 0.1) + +**Purpose:** Controls how tightly points cluster in the low-dimensional space. + +**How it works:** Sets the minimum distance apart that points are allowed to be in the output representation. + +**Effects by value:** +- **Low values (0.0-0.1):** Creates clumped embeddings useful for clustering; reveals fine topological details +- **High values (0.5-0.99):** Prevents tight packing; emphasizes broad topological preservation over local structure + +**Recommendation:** Use 0.0 for clustering applications, 0.1-0.3 for visualization, 0.5+ for loose structure. + +### n_components (default: 2) + +**Purpose:** Determines the dimensionality of the embedded output space. + +**Key feature:** Unlike t-SNE, UMAP scales well in the embedding dimension, enabling use beyond visualization. 
+ +**Common uses:** +- **2-3 dimensions:** Visualization +- **5-10 dimensions:** Clustering preprocessing (better preserves density than 2D) +- **10-50 dimensions:** Feature engineering for downstream ML models + +**Recommendation:** Use 2 for visualization, 5-10 for clustering, higher for ML pipelines. + +### metric (default: 'euclidean') + +**Purpose:** Specifies how distance is calculated between input data points. + +**Supported metrics:** +- **Minkowski variants:** euclidean, manhattan, chebyshev +- **Spatial metrics:** canberra, braycurtis, haversine +- **Correlation metrics:** cosine, correlation (good for text/document embeddings) +- **Binary data metrics:** hamming, jaccard, dice, russellrao, kulsinski, rogerstanimoto, sokalmichener, sokalsneath, yule +- **Custom metrics:** User-defined distance functions via Numba + +**Recommendation:** Use euclidean for numeric data, cosine for text/document vectors, hamming for binary data. + +### Parameter Tuning Example + +```python +# For visualization with emphasis on local structure +umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, metric='euclidean') + +# For clustering preprocessing +umap.UMAP(n_neighbors=30, min_dist=0.0, n_components=10, metric='euclidean') + +# For document embeddings +umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, metric='cosine') + +# For preserving global structure +umap.UMAP(n_neighbors=100, min_dist=0.5, n_components=2, metric='euclidean') +``` + +## Supervised and Semi-Supervised Dimension Reduction + +UMAP supports incorporating label information to guide the embedding process, enabling class separation while preserving internal structure. + +### Supervised UMAP + +Pass target labels via the `y` parameter when fitting: + +```python +# Supervised dimension reduction +embedding = umap.UMAP().fit_transform(data, y=labels) +``` + +**Key benefits:** +- Achieves cleanly separated classes +- Preserves internal structure within each class +- Maintains global relationships between classes + +**When to use:** When you have labeled data and want to separate known classes while keeping meaningful point embeddings. + +### Semi-Supervised UMAP + +For partial labels, mark unlabeled points with `-1` following scikit-learn convention: + +```python +# Create semi-supervised labels +semi_labels = labels.copy() +semi_labels[unlabeled_indices] = -1 + +# Fit with partial labels +embedding = umap.UMAP().fit_transform(data, y=semi_labels) +``` + +**When to use:** When labeling is expensive or you have more data than labels available. + +### Metric Learning with UMAP + +Train a supervised embedding on labeled data, then apply to new unlabeled data: + +```python +# Train on labeled data +mapper = umap.UMAP().fit(train_data, train_labels) + +# Transform unlabeled test data +test_embedding = mapper.transform(test_data) + +# Use as feature engineering for downstream classifier +from sklearn.svm import SVC +clf = SVC().fit(mapper.embedding_, train_labels) +predictions = clf.predict(test_embedding) +``` + +**When to use:** For supervised feature engineering in machine learning pipelines. + +## UMAP for Clustering + +UMAP serves as effective preprocessing for density-based clustering algorithms like HDBSCAN, overcoming the curse of dimensionality. + +### Best Practices for Clustering + +**Key principle:** Configure UMAP differently for clustering than for visualization. 
+ +**Recommended parameters:** +- **n_neighbors:** Increase to ~30 (default 15 is too local and can create artificial fine-grained clusters) +- **min_dist:** Set to 0.0 (pack points densely within clusters for clearer boundaries) +- **n_components:** Use 5-10 dimensions (maintains performance while improving density preservation vs. 2D) + +### Clustering Workflow + +```python +import umap +import hdbscan +from sklearn.preprocessing import StandardScaler + +# 1. Preprocess data +scaled_data = StandardScaler().fit_transform(data) + +# 2. UMAP with clustering-optimized parameters +reducer = umap.UMAP( + n_neighbors=30, + min_dist=0.0, + n_components=10, # Higher than 2 for better density preservation + metric='euclidean', + random_state=42 +) +embedding = reducer.fit_transform(scaled_data) + +# 3. Apply HDBSCAN clustering +clusterer = hdbscan.HDBSCAN( + min_cluster_size=15, + min_samples=5, + metric='euclidean' +) +labels = clusterer.fit_predict(embedding) + +# 4. Evaluate +from sklearn.metrics import adjusted_rand_score +score = adjusted_rand_score(true_labels, labels) +print(f"Adjusted Rand Score: {score:.3f}") +print(f"Number of clusters: {len(set(labels)) - (1 if -1 in labels else 0)}") +print(f"Noise points: {sum(labels == -1)}") +``` + +### Visualization After Clustering + +```python +# Create 2D embedding for visualization (separate from clustering) +vis_reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) +vis_embedding = vis_reducer.fit_transform(scaled_data) + +# Plot with cluster labels +import matplotlib.pyplot as plt +plt.scatter(vis_embedding[:, 0], vis_embedding[:, 1], c=labels, cmap='Spectral', s=5) +plt.colorbar() +plt.title('UMAP Visualization with HDBSCAN Clusters') +plt.show() +``` + +**Important caveat:** UMAP does not completely preserve density and can create artificial cluster divisions. Always validate and explore resulting clusters. + +## Transforming New Data + +UMAP enables preprocessing of new data through its `transform()` method, allowing trained models to project unseen data into the learned embedding space. + +### Basic Transform Usage + +```python +# Train on training data +trans = umap.UMAP(n_neighbors=15, random_state=42).fit(X_train) + +# Transform test data +test_embedding = trans.transform(X_test) +``` + +### Integration with Machine Learning Pipelines + +```python +from sklearn.svm import SVC +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +import umap + +# Split data +X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2) + +# Preprocess +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) +X_test_scaled = scaler.transform(X_test) + +# Train UMAP +reducer = umap.UMAP(n_components=10, random_state=42) +X_train_embedded = reducer.fit_transform(X_train_scaled) +X_test_embedded = reducer.transform(X_test_scaled) + +# Train classifier on embeddings +clf = SVC() +clf.fit(X_train_embedded, y_train) +accuracy = clf.score(X_test_embedded, y_test) +print(f"Test accuracy: {accuracy:.3f}") +``` + +### Important Considerations + +**Data consistency:** The transform method assumes the overall distribution in the higher-dimensional space is consistent between training and test data. When this assumption fails, consider using Parametric UMAP instead. + +**Performance:** Transform operations are efficient (typically <1 second), though initial calls may be slower due to Numba JIT compilation. 
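A small timing sketch can make the warm-up effect visible; the random arrays below are placeholders for real data, and the absolute numbers will vary by machine:

```python
import time
import numpy as np
import umap

X_train = np.random.random((2000, 50)).astype("float32")
X_new = np.random.random((500, 50)).astype("float32")

trans = umap.UMAP(n_neighbors=15, random_state=42).fit(X_train)

start = time.perf_counter()
trans.transform(X_new)  # first call may include residual Numba JIT compilation
print(f"first transform:  {time.perf_counter() - start:.2f} s")

start = time.perf_counter()
trans.transform(X_new)  # later calls reuse the compiled routines
print(f"second transform: {time.perf_counter() - start:.2f} s")
```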
+ +**Scikit-learn compatibility:** UMAP follows standard sklearn conventions and works seamlessly in pipelines: + +```python +from sklearn.pipeline import Pipeline + +pipeline = Pipeline([ + ('scaler', StandardScaler()), + ('umap', umap.UMAP(n_components=10)), + ('classifier', SVC()) +]) + +pipeline.fit(X_train, y_train) +predictions = pipeline.predict(X_test) +``` + +## Advanced Features + +### Parametric UMAP + +Parametric UMAP replaces direct embedding optimization with a learned neural network mapping function. + +**Key differences from standard UMAP:** +- Uses TensorFlow/Keras to train encoder networks +- Enables efficient transformation of new data +- Supports reconstruction via decoder networks (inverse transform) +- Allows custom architectures (CNNs for images, RNNs for sequences) + +**Installation:** +```bash +pip install umap-learn[parametric_umap] +# Requires TensorFlow 2.x +``` + +**Basic usage:** +```python +from umap.parametric_umap import ParametricUMAP + +# Default architecture (3-layer 100-neuron fully-connected network) +embedder = ParametricUMAP() +embedding = embedder.fit_transform(data) + +# Transform new data efficiently +new_embedding = embedder.transform(new_data) +``` + +**Custom architecture:** +```python +import tensorflow as tf + +# Define custom encoder +encoder = tf.keras.Sequential([ + tf.keras.layers.InputLayer(input_shape=(input_dim,)), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(64, activation='relu'), + tf.keras.layers.Dense(2) # Output dimension +]) + +embedder = ParametricUMAP(encoder=encoder, dims=(input_dim,)) +embedding = embedder.fit_transform(data) +``` + +**When to use Parametric UMAP:** +- Need efficient transformation of new data after training +- Require reconstruction capabilities (inverse transforms) +- Want to combine UMAP with autoencoders +- Working with complex data types (images, sequences) benefiting from specialized architectures + +**When to use standard UMAP:** +- Need simplicity and quick prototyping +- Dataset is small and computational efficiency isn't critical +- Don't require learned transformations for future data + +### Inverse Transforms + +Inverse transforms enable reconstruction of high-dimensional data from low-dimensional embeddings. 
+ +**Basic usage:** +```python +reducer = umap.UMAP() +embedding = reducer.fit_transform(data) + +# Reconstruct high-dimensional data from embedding coordinates +reconstructed = reducer.inverse_transform(embedding) +``` + +**Important limitations:** +- Computationally expensive operation +- Works poorly outside the convex hull of the embedding +- Accuracy decreases in regions with gaps between clusters + +**Use cases:** +- Understanding structure of embedded data +- Visualizing smooth transitions between clusters +- Exploring interpolations between data points +- Generating synthetic samples in embedding space + +**Example: Exploring embedding space:** +```python +import numpy as np + +# Create grid of points in embedding space +x = np.linspace(embedding[:, 0].min(), embedding[:, 0].max(), 10) +y = np.linspace(embedding[:, 1].min(), embedding[:, 1].max(), 10) +xx, yy = np.meshgrid(x, y) +grid_points = np.c_[xx.ravel(), yy.ravel()] + +# Reconstruct samples from grid +reconstructed_samples = reducer.inverse_transform(grid_points) +``` + +### AlignedUMAP + +For analyzing temporal or related datasets (e.g., time-series experiments, batch data): + +```python +from umap import AlignedUMAP + +# List of related datasets +datasets = [day1_data, day2_data, day3_data] + +# Create aligned embeddings +mapper = AlignedUMAP().fit(datasets) +aligned_embeddings = mapper.embeddings_ # List of embeddings +``` + +**When to use:** Comparing embeddings across related datasets while maintaining consistent coordinate systems. + +## Reproducibility + +To ensure reproducible results, always set the `random_state` parameter: + +```python +reducer = umap.UMAP(random_state=42) +``` + +UMAP uses stochastic optimization, so results will vary slightly between runs without a fixed random state. + +## Common Issues and Solutions + +**Issue:** Disconnected components or fragmented clusters +- **Solution:** Increase `n_neighbors` to emphasize more global structure + +**Issue:** Clusters too spread out or not well separated +- **Solution:** Decrease `min_dist` to allow tighter packing + +**Issue:** Poor clustering results +- **Solution:** Use clustering-specific parameters (n_neighbors=30, min_dist=0.0, n_components=5-10) + +**Issue:** Transform results differ significantly from training +- **Solution:** Ensure test data distribution matches training, or use Parametric UMAP + +**Issue:** Slow performance on large datasets +- **Solution:** Set `low_memory=True` (default), or consider dimensionality reduction with PCA first + +**Issue:** All points collapsed to single cluster +- **Solution:** Check data preprocessing (ensure proper scaling), increase `min_dist` + +## Resources + +### references/ + +Contains detailed API documentation: +- `api_reference.md`: Complete UMAP class parameters and methods + +Load these references when detailed parameter information or advanced method usage is needed. 
diff --git a/scientific-packages/umap-learn/references/api_reference.md b/scientific-packages/umap-learn/references/api_reference.md new file mode 100644 index 0000000..3e0dbef --- /dev/null +++ b/scientific-packages/umap-learn/references/api_reference.md @@ -0,0 +1,532 @@ +# UMAP API Reference + +## UMAP Class + +`umap.UMAP(n_neighbors=15, n_components=2, metric='euclidean', n_epochs=None, learning_rate=1.0, init='spectral', min_dist=0.1, spread=1.0, low_memory=True, set_op_mix_ratio=1.0, local_connectivity=1.0, repulsion_strength=1.0, negative_sample_rate=5, transform_queue_size=4.0, a=None, b=None, random_state=None, metric_kwds=None, angular_rp_forest=False, target_n_neighbors=-1, target_metric='categorical', target_metric_kwds=None, target_weight=0.5, transform_seed=42, transform_mode='embedding', force_approximation_algorithm=False, verbose=False, unique=False, densmap=False, dens_lambda=2.0, dens_frac=0.3, dens_var_shift=0.1, output_dens=False, disconnection_distance=None, precomputed_knn=(None, None, None))` + +Find low-dimensional embedding that approximates the underlying manifold of the data. + +### Core Parameters + +#### n_neighbors (int, default: 15) +Size of the local neighborhood used for manifold approximation. Larger values result in more global views of the manifold, while smaller values preserve more local structure. Generally in the range 2 to 100. + +**Tuning guidance:** +- Use 2-5 for very local structure +- Use 10-20 for balanced local/global structure (typical) +- Use 50-200 for emphasizing global structure + +#### n_components (int, default: 2) +Dimension of the embedding space. Unlike t-SNE, UMAP scales well with increasing embedding dimensions. + +**Common values:** +- 2-3: Visualization +- 5-10: Clustering preprocessing +- 10-100: Feature engineering for downstream ML + +#### metric (str or callable, default: 'euclidean') +Distance metric to use. Accepts: +- Any metric from scipy.spatial.distance +- Any metric from sklearn.metrics +- Custom callable distance functions (must be compiled with Numba) + +**Common metrics:** +- `'euclidean'`: Standard Euclidean distance (default) +- `'manhattan'`: L1 distance +- `'cosine'`: Cosine distance (good for text/document vectors) +- `'correlation'`: Correlation distance +- `'hamming'`: Hamming distance (for binary data) +- `'jaccard'`: Jaccard distance (for binary/set data) +- `'dice'`: Dice distance +- `'canberra'`: Canberra distance +- `'braycurtis'`: Bray-Curtis distance +- `'chebyshev'`: Chebyshev distance +- `'minkowski'`: Minkowski distance (specify p with metric_kwds) +- `'precomputed'`: Use precomputed distance matrix + +#### min_dist (float, default: 0.1) +Effective minimum distance between embedded points. Controls how tightly points are packed together. Smaller values result in clumpier embeddings. + +**Tuning guidance:** +- Use 0.0 for clustering applications +- Use 0.1-0.3 for visualization (balanced) +- Use 0.5-0.99 for loose structure preservation + +#### spread (float, default: 1.0) +Effective scale of embedded points. Combined with `min_dist` to control clumped vs. spread-out embeddings. Determines how spread out the clusters are in the embedding space. + +### Training Parameters + +#### n_epochs (int, default: None) +Number of training epochs. If None, automatically determined based on dataset size (typically 200-500 epochs). 
+ +**Manual tuning:** +- Smaller datasets may need 500+ epochs +- Larger datasets may converge with 200 epochs +- More epochs = better optimization but slower training + +#### learning_rate (float, default: 1.0) +Initial learning rate for the SGD optimizer. Higher values lead to faster convergence but may overshoot optimal solutions. + +#### init (str or np.ndarray, default: 'spectral') +Initialization method for the embedding: +- `'spectral'`: Use spectral embedding (default, usually best) +- `'random'`: Random initialization +- `'pca'`: Initialize with PCA +- numpy array: Custom initialization (shape: (n_samples, n_components)) + +### Advanced Structural Parameters + +#### local_connectivity (int, default: 1.0) +Number of nearest neighbors assumed to be locally connected. Higher values give more connected manifolds. + +#### set_op_mix_ratio (float, default: 1.0) +Interpolation between union and intersection when constructing fuzzy set unions. Value of 1.0 uses pure union, 0.0 uses pure intersection. + +#### repulsion_strength (float, default: 1.0) +Weighting applied to negative samples in low-dimensional embedding optimization. Higher values push embedded points further apart. + +#### negative_sample_rate (int, default: 5) +Number of negative samples to select per positive sample. Higher values lead to greater repulsion between points and more spread-out embeddings but increase computational cost. + +### Supervised Learning Parameters + +#### target_n_neighbors (int, default: -1) +Number of nearest neighbors to use when constructing target simplicial set. If -1, uses n_neighbors value. + +#### target_metric (str, default: 'categorical') +Distance metric for target values (labels): +- `'categorical'`: For classification tasks +- Any other metric for regression tasks + +#### target_weight (float, default: 0.5) +Weight applied to target information vs. data structure. Range 0.0 to 1.0: +- 0.0: Pure unsupervised embedding (ignores labels) +- 0.5: Balanced (default) +- 1.0: Pure supervised embedding (only considers labels) + +### Transform Parameters + +#### transform_queue_size (float, default: 4.0) +Size of the nearest neighbor search queue for transform operations. Larger values improve transform accuracy but increase memory usage and computation time. + +#### transform_seed (int, default: 42) +Random seed for transform operations. Ensures reproducibility of transform results. + +#### transform_mode (str, default: 'embedding') +Method for transforming new data: +- `'embedding'`: Standard approach (default) +- `'graph'`: Use nearest neighbor graph + +### Performance Parameters + +#### low_memory (bool, default: True) +Whether to use a memory-efficient implementation. Set to False only if memory is not a constraint and you want faster performance. + +#### verbose (bool, default: False) +Whether to print progress messages during fitting. + +#### unique (bool, default: False) +Whether to consider only unique data points. Set to True if you know your data contains many duplicates to improve performance. + +#### force_approximation_algorithm (bool, default: False) +Force use of approximate nearest neighbor search even for small datasets. Can improve performance on large datasets. + +#### angular_rp_forest (bool, default: False) +Whether to use angular random projection forest for nearest neighbor search. Can improve performance for normalized data in high dimensions. + +### DensMAP Parameters + +DensMAP is a variant that preserves local density information. 
+ +#### densmap (bool, default: False) +Whether to use the DensMAP algorithm instead of standard UMAP. Preserves local density in addition to topological structure. + +#### dens_lambda (float, default: 2.0) +Weight of density preservation term in DensMAP optimization. Higher values emphasize density preservation. + +#### dens_frac (float, default: 0.3) +Fraction of dataset used for density estimation in DensMAP. + +#### dens_var_shift (float, default: 0.1) +Regularization parameter for density estimation in DensMAP. + +#### output_dens (bool, default: False) +Whether to output local density estimates in addition to the embedding. Results stored in `rad_orig_` and `rad_emb_` attributes. + +### Other Parameters + +#### a (float, default: None) +Parameter controlling embedding. If None, determined automatically from min_dist and spread. + +#### b (float, default: None) +Parameter controlling embedding. If None, determined automatically from min_dist and spread. + +#### random_state (int, RandomState instance, or None, default: None) +Random state for reproducibility. Set to an integer for reproducible results. + +#### metric_kwds (dict, default: None) +Additional keyword arguments for the distance metric. + +#### disconnection_distance (float, default: None) +Distance threshold for considering points disconnected. If None, uses max distance in the graph. + +#### precomputed_knn (tuple, default: (None, None, None)) +Precomputed k-nearest neighbors as (knn_indices, knn_dists, knn_search_index). Useful for reusing expensive computations. + +## Methods + +### fit(X, y=None) +Fit the UMAP model to the data. + +**Parameters:** +- `X`: array-like, shape (n_samples, n_features) - Training data +- `y`: array-like, shape (n_samples,), optional - Target values for supervised dimension reduction + +**Returns:** +- `self`: Fitted UMAP object + +**Attributes set:** +- `embedding_`: The embedded representation of training data +- `graph_`: Fuzzy simplicial set approximation to the manifold +- `_raw_data`: Copy of the training data +- `_small_data`: Whether the dataset is considered small +- `_metric_kwds`: Processed metric keyword arguments +- `_n_neighbors`: Actual n_neighbors used +- `_initial_alpha`: Initial learning rate +- `_a`, `_b`: Curve parameters + +### fit_transform(X, y=None) +Fit the model and return the embedded representation. + +**Parameters:** +- `X`: array-like, shape (n_samples, n_features) - Training data +- `y`: array-like, shape (n_samples,), optional - Target values for supervised dimension reduction + +**Returns:** +- `X_new`: array, shape (n_samples, n_components) - Embedded data + +### transform(X) +Transform new data into the existing embedded space. + +**Parameters:** +- `X`: array-like, shape (n_samples, n_features) - New data to transform + +**Returns:** +- `X_new`: array, shape (n_samples, n_components) - Embedded representation of new data + +**Important notes:** +- The model must be fitted before calling transform +- Transform quality depends on similarity between training and test distributions +- For significantly different data distributions, consider Parametric UMAP + +### inverse_transform(X) +Transform data from the embedded space back to the original data space. 
+ +**Parameters:** +- `X`: array-like, shape (n_samples, n_components) - Embedded data points + +**Returns:** +- `X_new`: array, shape (n_samples, n_features) - Reconstructed data in original space + +**Important notes:** +- Computationally expensive operation +- Works poorly outside the convex hull of the training embedding +- Reconstruction quality varies by region + +### update(X) +Update the model with new data. Allows incremental fitting. + +**Parameters:** +- `X`: array-like, shape (n_samples, n_features) - New data to incorporate + +**Returns:** +- `self`: Updated UMAP object + +**Note:** Experimental feature, may not preserve all properties of batch training. + +## Attributes + +### embedding_ +array, shape (n_samples, n_components) - The embedded representation of the training data. + +### graph_ +scipy.sparse.csr_matrix - The weighted adjacency matrix of the fuzzy simplicial set approximation to the manifold. + +### _raw_data +array - Copy of the raw training data. + +### _sparse_data +bool - Whether the training data was sparse. + +### _small_data +bool - Whether the dataset was considered small (uses different algorithm for small datasets). + +### _input_hash +str - Hash of the input data for caching purposes. + +### _knn_indices +array - Indices of k-nearest neighbors for each training point. + +### _knn_dists +array - Distances to k-nearest neighbors for each training point. + +### _rp_forest +list - Random projection forest used for approximate nearest neighbor search. + +## ParametricUMAP Class + +`umap.ParametricUMAP(encoder=None, decoder=None, parametric_reconstruction=False, autoencoder_loss=False, reconstruction_validation=None, dims=None, batch_size=None, n_training_epochs=1, loss_report_frequency=10, optimizer=None, keras_fit_kwargs={}, **kwargs)` + +Parametric UMAP using neural networks to learn the embedding function. + +### Additional Parameters (beyond UMAP) + +#### encoder (tensorflow.keras.Model, default: None) +Keras model for encoding data to embeddings. If None, uses default 3-layer architecture with 100 neurons per layer. + +#### decoder (tensorflow.keras.Model, default: None) +Keras model for decoding embeddings back to data space. Only used if parametric_reconstruction=True. + +#### parametric_reconstruction (bool, default: False) +Whether to use parametric reconstruction. Requires decoder model. + +#### autoencoder_loss (bool, default: False) +Whether to include reconstruction loss in the optimization. Requires decoder model. + +#### reconstruction_validation (tuple, default: None) +Validation data (X_val, y_val) for monitoring reconstruction loss during training. + +#### dims (tuple, default: None) +Input dimensions for the encoder network. Required if providing custom encoder. + +#### batch_size (int, default: None) +Batch size for neural network training. If None, determined automatically. + +#### n_training_epochs (int, default: 1) +Number of training epochs for the neural networks. More epochs improve quality but increase training time. + +#### loss_report_frequency (int, default: 10) +How often to report loss during training. + +#### optimizer (tensorflow.keras.optimizers.Optimizer, default: None) +Keras optimizer for training. If None, uses Adam with learning_rate parameter. + +#### keras_fit_kwargs (dict, default: {}) +Additional keyword arguments passed to the Keras fit() method. + +### Methods + +Same as UMAP class, but transform() and inverse_transform() use learned neural networks for faster inference. 
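A minimal sketch of that behavior with the default encoder (the random arrays stand in for real data):

```python
import numpy as np
from umap.parametric_umap import ParametricUMAP

X_train = np.random.random((1000, 20)).astype("float32")
X_new = np.random.random((200, 20)).astype("float32")

# Default encoder: 3-layer, 100-neuron fully-connected network
embedder = ParametricUMAP(n_components=2, n_training_epochs=1, random_state=42)
embedding = embedder.fit_transform(X_train)

# transform() is a single forward pass through the trained encoder,
# so new data is embedded without rerunning the UMAP optimization
new_embedding = embedder.transform(X_new)
print(embedding.shape, new_embedding.shape)  # (1000, 2) (200, 2)
```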
+ +## Utility Functions + +### umap.nearest_neighbors(X, n_neighbors, metric, metric_kwds={}, angular=False, random_state=None) +Compute k-nearest neighbors for the data. + +**Returns:** (knn_indices, knn_dists, rp_forest) + +### umap.fuzzy_simplicial_set(X, n_neighbors, random_state, metric, metric_kwds={}, knn_indices=None, knn_dists=None, angular=False, set_op_mix_ratio=1.0, local_connectivity=1.0, apply_set_operations=True, verbose=False, return_dists=None) +Construct fuzzy simplicial set representation of the data. + +**Returns:** Fuzzy simplicial set as sparse matrix + +### umap.simplicial_set_embedding(data, graph, n_components, initial_alpha, a, b, gamma, negative_sample_rate, n_epochs, init, random_state, metric, metric_kwds, densmap, densmap_kwds, output_dens, output_metric, output_metric_kwds, euclidean_output, parallel=False, verbose=False) +Perform the optimization to find a low-dimensional embedding. + +**Returns:** Embedding array + +### umap.find_ab_params(spread, min_dist) +Fit a, b params for the UMAP curve from spread and min_dist. + +**Returns:** (a, b) tuple + +## AlignedUMAP Class + +`umap.AlignedUMAP(n_neighbors=15, n_components=2, metric='euclidean', alignment_regularisation=1e-2, alignment_window_size=3, **kwargs)` + +UMAP variant for aligning multiple related datasets. + +### Additional Parameters + +#### alignment_regularisation (float, default: 1e-2) +Strength of alignment regularization between datasets. + +#### alignment_window_size (int, default: 3) +Number of adjacent datasets to align. + +### Methods + +#### fit(X) +Fit model to multiple datasets. + +**Parameters:** +- `X`: list of arrays - List of datasets to align + +**Returns:** +- `self`: Fitted model + +### Attributes + +#### embeddings_ +list of arrays - List of aligned embeddings, one per input dataset. 
+ +## Usage Examples + +### Basic Usage with All Common Parameters + +```python +import umap + +# Standard 2D visualization embedding +reducer = umap.UMAP( + n_neighbors=15, # Balance local/global structure + n_components=2, # Output dimensions + metric='euclidean', # Distance metric + min_dist=0.1, # Minimum distance between points + spread=1.0, # Scale of embedded points + random_state=42, # Reproducibility + n_epochs=200, # Training iterations (None = auto) + learning_rate=1.0, # SGD learning rate + init='spectral', # Initialization method + low_memory=True, # Memory-efficient mode + verbose=True # Print progress +) + +embedding = reducer.fit_transform(data) +``` + +### Supervised Learning + +```python +# Train with labels for class separation +reducer = umap.UMAP( + n_neighbors=15, + target_weight=0.5, # Balance data structure vs labels + target_metric='categorical', # Metric for labels + random_state=42 +) + +embedding = reducer.fit_transform(data, y=labels) +``` + +### Clustering Preprocessing + +```python +# Optimized for clustering +reducer = umap.UMAP( + n_neighbors=30, # More global structure + min_dist=0.0, # Allow tight packing + n_components=10, # Higher dimensions for density + metric='euclidean', + random_state=42 +) + +embedding = reducer.fit_transform(data) +``` + +### Custom Distance Metric + +```python +from numba import njit + +@njit() +def custom_distance(x, y): + """Custom distance function (must be Numba-compatible)""" + result = 0.0 + for i in range(x.shape[0]): + result += abs(x[i] - y[i]) + return result + +reducer = umap.UMAP(metric=custom_distance) +embedding = reducer.fit_transform(data) +``` + +### Parametric UMAP with Custom Architecture + +```python +import tensorflow as tf +from umap.parametric_umap import ParametricUMAP + +# Define custom encoder +encoder = tf.keras.Sequential([ + tf.keras.layers.InputLayer(input_shape=(input_dim,)), + tf.keras.layers.Dense(256, activation='relu'), + tf.keras.layers.Dropout(0.3), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dropout(0.3), + tf.keras.layers.Dense(2) # Output dimension +]) + +# Define decoder for reconstruction +decoder = tf.keras.Sequential([ + tf.keras.layers.InputLayer(input_shape=(2,)), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(256, activation='relu'), + tf.keras.layers.Dense(input_dim) +]) + +# Train parametric UMAP with autoencoder +embedder = ParametricUMAP( + encoder=encoder, + decoder=decoder, + dims=(input_dim,), + parametric_reconstruction=True, + autoencoder_loss=True, + n_training_epochs=10, + batch_size=128, + n_neighbors=15, + min_dist=0.1, + random_state=42 +) + +embedding = embedder.fit_transform(data) +new_embedding = embedder.transform(new_data) +reconstructed = embedder.inverse_transform(embedding) +``` + +### DensMAP for Density Preservation + +```python +# Preserve local density information +reducer = umap.UMAP( + densmap=True, # Enable DensMAP + dens_lambda=2.0, # Weight of density preservation + dens_frac=0.3, # Fraction for density estimation + output_dens=True, # Output density estimates + n_neighbors=15, + min_dist=0.1, + random_state=42 +) + +embedding = reducer.fit_transform(data) + +# Access density estimates +original_density = reducer.rad_orig_ # Density in original space +embedded_density = reducer.rad_emb_ # Density in embedded space +``` + +### Aligned UMAP for Time Series + +```python +from umap import AlignedUMAP + +# Multiple related datasets (e.g., different time points) +datasets = [day1_data, day2_data, day3_data, 
day4_data] + +# Align embeddings +mapper = AlignedUMAP( + n_neighbors=15, + alignment_regularisation=1e-2, # Alignment strength + alignment_window_size=2, # Align with adjacent datasets + n_components=2, + random_state=42 +) + +mapper.fit(datasets) + +# Access aligned embeddings +aligned_embeddings = mapper.embeddings_ +# aligned_embeddings[0] is day1 embedding +# aligned_embeddings[1] is day2 embedding, etc. +``` diff --git a/scientific-packages/zarr-python/SKILL.md b/scientific-packages/zarr-python/SKILL.md new file mode 100644 index 0000000..9ab7395 --- /dev/null +++ b/scientific-packages/zarr-python/SKILL.md @@ -0,0 +1,785 @@ +--- +name: zarr-python +description: Toolkit for working with Zarr, a Python library for chunked, compressed N-dimensional arrays optimized for cloud storage and large-scale scientific computing. Use this skill when working with large datasets that need efficient storage and parallel access, multidimensional arrays requiring chunking and compression, cloud-native data workflows (S3, GCS), or when integrating with NumPy, Dask, and Xarray for scientific computing tasks. +--- + +# Zarr Python + +## Overview + +Zarr is a Python library for storage of large N-dimensional arrays that are chunked and compressed. It provides a NumPy-like API but divides data into manageable chunks stored separately, enabling efficient parallel I/O, cloud-native workflows, and seamless integration with the scientific Python ecosystem (NumPy, Dask, Xarray). + +**Key capabilities:** +- Create and manipulate N-dimensional arrays with NumPy-like semantics +- Configure chunking strategies for optimal parallel access and performance +- Apply compression algorithms (Blosc, Zstandard, Gzip, etc.) to reduce storage +- Use flexible storage backends: local filesystem, memory, ZIP files, or cloud storage (S3, GCS) +- Organize data hierarchically using groups (similar to HDF5) +- Integrate seamlessly with Dask for parallel computing and Xarray for labeled arrays + +## Quick Start + +### Installation + +```python +# Using pip +pip install zarr + +# Using conda +conda install --channel conda-forge zarr +``` + +Requires Python 3.11+. 
For cloud storage support, install additional packages: +```python +pip install s3fs # For S3 +pip install gcsfs # For Google Cloud Storage +``` + +### Basic Array Creation + +```python +import zarr +import numpy as np + +# Create a 2D array with chunking and compression +z = zarr.create_array( + store="data/my_array.zarr", + shape=(10000, 10000), + chunks=(1000, 1000), + dtype="f4" +) + +# Write data using NumPy-style indexing +z[:, :] = np.random.random((10000, 10000)) + +# Read data +data = z[0:100, 0:100] # Returns NumPy array +``` + +## Core Operations + +### Creating Arrays + +Zarr provides multiple convenience functions for array creation: + +```python +# Create empty array +z = zarr.zeros(shape=(10000, 10000), chunks=(1000, 1000), dtype='f4', + store='data.zarr') + +# Create filled arrays +z = zarr.ones((5000, 5000), chunks=(500, 500)) +z = zarr.full((1000, 1000), fill_value=42, chunks=(100, 100)) + +# Create from existing data +data = np.arange(10000).reshape(100, 100) +z = zarr.array(data, chunks=(10, 10), store='data.zarr') + +# Create like another array +z2 = zarr.zeros_like(z) # Matches shape, chunks, dtype of z +``` + +### Opening Existing Arrays + +```python +# Open array (read/write mode by default) +z = zarr.open_array('data.zarr', mode='r+') + +# Read-only mode +z = zarr.open_array('data.zarr', mode='r') + +# The open() function auto-detects arrays vs groups +z = zarr.open('data.zarr') # Returns Array or Group +``` + +### Reading and Writing Data + +Zarr arrays support NumPy-like indexing: + +```python +# Write entire array +z[:] = 42 + +# Write slices +z[0, :] = np.arange(100) +z[10:20, 50:60] = np.random.random((10, 10)) + +# Read data (returns NumPy array) +data = z[0:100, 0:100] +row = z[5, :] + +# Advanced indexing +z.vindex[[0, 5, 10], [2, 8, 15]] # Coordinate indexing +z.oindex[0:10, [5, 10, 15]] # Orthogonal indexing +z.blocks[0, 0] # Block/chunk indexing +``` + +### Resizing and Appending + +```python +# Resize array +z.resize(15000, 15000) # Expands or shrinks dimensions + +# Append data along an axis +z.append(np.random.random((1000, 10000)), axis=0) # Adds rows +``` + +## Chunking Strategies + +Chunking is critical for performance. Choose chunk sizes and shapes based on access patterns. + +### Chunk Size Guidelines + +- **Minimum chunk size**: 1 MB recommended for optimal performance +- **Balance**: Larger chunks = fewer metadata operations; smaller chunks = better parallel access +- **Memory consideration**: Entire chunks must fit in memory during compression + +```python +# Configure chunk size (aim for ~1MB per chunk) +# For float32 data: 1MB = 262,144 elements = 512×512 array +z = zarr.zeros( + shape=(10000, 10000), + chunks=(512, 512), # ~1MB chunks + dtype='f4' +) +``` + +### Aligning Chunks with Access Patterns + +**Critical**: Chunk shape dramatically affects performance based on how data is accessed. + +```python +# If accessing rows frequently (first dimension) +z = zarr.zeros((10000, 10000), chunks=(10, 10000)) # Chunk spans columns + +# If accessing columns frequently (second dimension) +z = zarr.zeros((10000, 10000), chunks=(10000, 10)) # Chunk spans rows + +# For mixed access patterns (balanced approach) +z = zarr.zeros((10000, 10000), chunks=(1000, 1000)) # Square chunks +``` + +**Performance example**: For a (200, 200, 200) array, reading along the first dimension: +- Using chunks (1, 200, 200): ~107ms +- Using chunks (200, 200, 1): ~1.65ms (65× faster!) 
+
+### Sharding for Large-Scale Storage
+
+When arrays have millions of small chunks, use sharding to group chunks into larger storage objects:
+
+```python
+import zarr
+
+# Create array with sharding
+z = zarr.create_array(
+    store='data.zarr',
+    shape=(100000, 100000),
+    chunks=(100, 100),      # Small chunks for access
+    shards=(1000, 1000),    # Groups 100 chunks per shard
+    dtype='f4'
+)
+```
+
+**Benefits**:
+- Reduces file system overhead from millions of small files
+- Improves cloud storage performance (fewer object requests)
+- Prevents filesystem block size waste
+
+**Important**: Entire shards must fit in memory before writing.
+
+## Compression
+
+Zarr applies compression per chunk to reduce storage while maintaining fast access.
+
+### Configuring Compression
+
+```python
+import zarr
+from zarr.codecs.blosc import BloscCodec
+from zarr.codecs import BytesCodec, GzipCodec, ZstdCodec
+
+# Default: Blosc with Zstandard
+z = zarr.zeros((1000, 1000), chunks=(100, 100))  # Uses default compression
+
+# Configure Blosc codec
+z = zarr.create_array(
+    store='data.zarr',
+    shape=(1000, 1000),
+    chunks=(100, 100),
+    dtype='f4',
+    codecs=[BloscCodec(cname='zstd', clevel=5, shuffle='shuffle')]
+)
+
+# Available Blosc compressors: 'blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd'
+
+# Use Gzip compression
+z = zarr.create_array(
+    store='data.zarr',
+    shape=(1000, 1000),
+    chunks=(100, 100),
+    dtype='f4',
+    codecs=[GzipCodec(level=6)]
+)
+
+# Disable compression
+z = zarr.create_array(
+    store='data.zarr',
+    shape=(1000, 1000),
+    chunks=(100, 100),
+    dtype='f4',
+    codecs=[BytesCodec()]  # No compression
+)
+```
+
+### Compression Performance Tips
+
+- **Blosc** (default): Fast compression/decompression, good for interactive workloads
+- **Zstandard**: Better compression ratios, slightly slower than LZ4
+- **Gzip**: Maximum compression, slower performance
+- **LZ4**: Fastest compression, lower ratios
+- **Shuffle**: Enable shuffle filter for better compression on numeric data
+
+```python
+# Optimal for numeric scientific data
+codecs=[BloscCodec(cname='zstd', clevel=5, shuffle='shuffle')]
+
+# Optimal for speed
+codecs=[BloscCodec(cname='lz4', clevel=1)]
+
+# Optimal for compression ratio
+codecs=[GzipCodec(level=9)]
+```
+
+## Storage Backends
+
+Zarr supports multiple storage backends through a flexible storage interface.
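+
+Because the store is just a parameter, the same array-creation code can be pointed at any backend. A minimal sketch (the `write_demo_array` helper is hypothetical, shown only to illustrate the store-agnostic interface):
+
+```python
+import numpy as np
+import zarr
+from zarr.storage import LocalStore, MemoryStore
+
+def write_demo_array(store):
+    # Identical code path regardless of where the chunks end up
+    z = zarr.open_array(store=store, mode='w', shape=(1000, 1000),
+                        chunks=(100, 100), dtype='f4')
+    z[:] = np.random.random((1000, 1000))
+    return z
+
+write_demo_array(LocalStore('data/demo.zarr'))  # persisted to disk
+write_demo_array(MemoryStore())                 # held only in memory
+```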
+ +### Local Filesystem (Default) + +```python +from zarr.storage import LocalStore + +# Explicit store creation +store = LocalStore('data/my_array.zarr') +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) + +# Or use string path (creates LocalStore automatically) +z = zarr.open_array('data/my_array.zarr', mode='w', shape=(1000, 1000), + chunks=(100, 100)) +``` + +### In-Memory Storage + +```python +from zarr.storage import MemoryStore + +# Create in-memory store +store = MemoryStore() +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) + +# Data exists only in memory, not persisted +``` + +### ZIP File Storage + +```python +from zarr.storage import ZipStore + +# Write to ZIP file +store = ZipStore('data.zip', mode='w') +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +z[:] = np.random.random((1000, 1000)) +store.close() # IMPORTANT: Must close ZipStore + +# Read from ZIP file +store = ZipStore('data.zip', mode='r') +z = zarr.open_array(store=store) +data = z[:] +store.close() +``` + +### Cloud Storage (S3, GCS) + +```python +import s3fs +import zarr + +# S3 storage +s3 = s3fs.S3FileSystem(anon=False) # Use credentials +store = s3fs.S3Map(root='my-bucket/path/to/array.zarr', s3=s3) +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +z[:] = data + +# Google Cloud Storage +import gcsfs +gcs = gcsfs.GCSFileSystem(project='my-project') +store = gcsfs.GCSMap(root='my-bucket/path/to/array.zarr', gcs=gcs) +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +``` + +**Cloud Storage Best Practices**: +- Use consolidated metadata to reduce latency: `zarr.consolidate_metadata(store)` +- Align chunk sizes with cloud object sizing (typically 5-100 MB optimal) +- Enable parallel writes using Dask for large-scale data +- Consider sharding to reduce number of objects + +## Groups and Hierarchies + +Groups organize multiple arrays hierarchically, similar to directories or HDF5 groups. 
+ +### Creating and Using Groups + +```python +# Create root group +root = zarr.group(store='data/hierarchy.zarr') + +# Create sub-groups +temperature = root.create_group('temperature') +precipitation = root.create_group('precipitation') + +# Create arrays within groups +temp_array = temperature.create_array( + name='t2m', + shape=(365, 720, 1440), + chunks=(1, 720, 1440), + dtype='f4' +) + +precip_array = precipitation.create_array( + name='prcp', + shape=(365, 720, 1440), + chunks=(1, 720, 1440), + dtype='f4' +) + +# Access using paths +array = root['temperature/t2m'] + +# Visualize hierarchy +print(root.tree()) +# Output: +# / +# ├── temperature +# │ └── t2m (365, 720, 1440) f4 +# └── precipitation +# └── prcp (365, 720, 1440) f4 +``` + +### H5py-Compatible API + +Zarr provides an h5py-compatible interface for familiar HDF5 users: + +```python +# Create group with h5py-style methods +root = zarr.group('data.zarr') +dataset = root.create_dataset('my_data', shape=(1000, 1000), chunks=(100, 100), + dtype='f4') + +# Access like h5py +grp = root.require_group('subgroup') +arr = grp.require_dataset('array', shape=(500, 500), chunks=(50, 50), dtype='i4') +``` + +## Attributes and Metadata + +Attach custom metadata to arrays and groups using attributes: + +```python +# Add attributes to array +z = zarr.zeros((1000, 1000), chunks=(100, 100)) +z.attrs['description'] = 'Temperature data in Kelvin' +z.attrs['units'] = 'K' +z.attrs['created'] = '2024-01-15' +z.attrs['processing_version'] = 2.1 + +# Attributes are stored as JSON +print(z.attrs['units']) # Output: K + +# Add attributes to groups +root = zarr.group('data.zarr') +root.attrs['project'] = 'Climate Analysis' +root.attrs['institution'] = 'Research Institute' + +# Attributes persist with the array/group +z2 = zarr.open('data.zarr') +print(z2.attrs['description']) +``` + +**Important**: Attributes must be JSON-serializable (strings, numbers, lists, dicts, booleans, null). 
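+
+NumPy scalars and arrays are generally not JSON-serializable as-is, so convert them to plain Python types before storing them as attributes. A small illustrative sketch (the attribute names are hypothetical; the conversions are plain Python/NumPy, not zarr-specific):
+
+```python
+import numpy as np
+import zarr
+
+z = zarr.zeros((100, 100), chunks=(10, 10))
+
+mean_val = np.float32(2.5)               # NumPy scalar
+wavelengths = np.array([450, 550, 650])  # NumPy array
+
+z.attrs['mean_value'] = float(mean_val)           # cast scalar to a Python float
+z.attrs['wavelengths_nm'] = wavelengths.tolist()  # convert array to a list
+```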
+ +## Integration with NumPy, Dask, and Xarray + +### NumPy Integration + +Zarr arrays implement the NumPy array interface: + +```python +import numpy as np +import zarr + +z = zarr.zeros((1000, 1000), chunks=(100, 100)) + +# Use NumPy functions directly +result = np.sum(z, axis=0) # NumPy operates on Zarr array +mean = np.mean(z[:100, :100]) + +# Convert to NumPy array +numpy_array = z[:] # Loads entire array into memory +``` + +### Dask Integration + +Dask provides lazy, parallel computation on Zarr arrays: + +```python +import dask.array as da +import zarr + +# Create large Zarr array +z = zarr.open('data.zarr', mode='w', shape=(100000, 100000), + chunks=(1000, 1000), dtype='f4') + +# Load as Dask array (lazy, no data loaded) +dask_array = da.from_zarr('data.zarr') + +# Perform computations (parallel, out-of-core) +result = dask_array.mean(axis=0).compute() # Parallel computation + +# Write Dask array to Zarr +large_array = da.random.random((100000, 100000), chunks=(1000, 1000)) +da.to_zarr(large_array, 'output.zarr') +``` + +**Benefits**: +- Process datasets larger than memory +- Automatic parallel computation across chunks +- Efficient I/O with chunked storage + +### Xarray Integration + +Xarray provides labeled, multidimensional arrays with Zarr backend: + +```python +import xarray as xr +import zarr + +# Open Zarr store as Xarray Dataset (lazy loading) +ds = xr.open_zarr('data.zarr') + +# Dataset includes coordinates and metadata +print(ds) + +# Access variables +temperature = ds['temperature'] + +# Perform labeled operations +subset = ds.sel(time='2024-01', lat=slice(30, 60)) + +# Write Xarray Dataset to Zarr +ds.to_zarr('output.zarr') + +# Create from scratch with coordinates +ds = xr.Dataset( + { + 'temperature': (['time', 'lat', 'lon'], data), + 'precipitation': (['time', 'lat', 'lon'], data2) + }, + coords={ + 'time': pd.date_range('2024-01-01', periods=365), + 'lat': np.arange(-90, 91, 1), + 'lon': np.arange(-180, 180, 1) + } +) +ds.to_zarr('climate_data.zarr') +``` + +**Benefits**: +- Named dimensions and coordinates +- Label-based indexing and selection +- Integration with pandas for time series +- NetCDF-like interface familiar to climate/geospatial scientists + +## Parallel Computing and Synchronization + +### Thread-Safe Operations + +```python +from zarr import ThreadSynchronizer +import zarr + +# For multi-threaded writes +synchronizer = ThreadSynchronizer() +z = zarr.open_array('data.zarr', mode='r+', shape=(10000, 10000), + chunks=(1000, 1000), synchronizer=synchronizer) + +# Safe for concurrent writes from multiple threads +# (when writes don't span chunk boundaries) +``` + +### Process-Safe Operations + +```python +from zarr import ProcessSynchronizer +import zarr + +# For multi-process writes +synchronizer = ProcessSynchronizer('sync_data.sync') +z = zarr.open_array('data.zarr', mode='r+', shape=(10000, 10000), + chunks=(1000, 1000), synchronizer=synchronizer) + +# Safe for concurrent writes from multiple processes +``` + +**Note**: +- Concurrent reads require no synchronization +- Synchronization only needed for writes that may span chunk boundaries +- Each process/thread writing to separate chunks needs no synchronization + +## Consolidated Metadata + +For hierarchical stores with many arrays, consolidate metadata into a single file to reduce I/O operations: + +```python +import zarr + +# After creating arrays/groups +root = zarr.group('data.zarr') +# ... create multiple arrays/groups ... 
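+# (illustrative only) e.g. a group with a couple of arrays, so there is
+# metadata spread across several objects to consolidate; names are hypothetical
+grp = root.create_group('climate')
+grp.create_array(name='temperature', shape=(365, 720, 1440),
+                 chunks=(1, 720, 1440), dtype='f4')
+grp.create_array(name='pressure', shape=(365, 720, 1440),
+                 chunks=(1, 720, 1440), dtype='f4')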
+ +# Consolidate metadata +zarr.consolidate_metadata('data.zarr') + +# Open with consolidated metadata (faster, especially on cloud storage) +root = zarr.open_consolidated('data.zarr') +``` + +**Benefits**: +- Reduces metadata read operations from N (one per array) to 1 +- Critical for cloud storage (reduces latency) +- Speeds up `tree()` operations and group traversal + +**Cautions**: +- Metadata can become stale if arrays update without re-consolidation +- Not suitable for frequently-updated datasets +- Multi-writer scenarios may have inconsistent reads + +## Performance Optimization + +### Checklist for Optimal Performance + +1. **Chunk Size**: Aim for 1-10 MB per chunk + ```python + # For float32: 1MB = 262,144 elements + chunks = (512, 512) # 512×512×4 bytes = ~1MB + ``` + +2. **Chunk Shape**: Align with access patterns + ```python + # Row-wise access → chunk spans columns: (small, large) + # Column-wise access → chunk spans rows: (large, small) + # Random access → balanced: (medium, medium) + ``` + +3. **Compression**: Choose based on workload + ```python + # Interactive/fast: BloscCodec(cname='lz4') + # Balanced: BloscCodec(cname='zstd', clevel=5) + # Maximum compression: GzipCodec(level=9) + ``` + +4. **Storage Backend**: Match to environment + ```python + # Local: LocalStore (default) + # Cloud: S3Map/GCSMap with consolidated metadata + # Temporary: MemoryStore + ``` + +5. **Sharding**: Use for large-scale datasets + ```python + # When you have millions of small chunks + shards=(10*chunk_size, 10*chunk_size) + ``` + +6. **Parallel I/O**: Use Dask for large operations + ```python + import dask.array as da + dask_array = da.from_zarr('data.zarr') + result = dask_array.compute(scheduler='threads', num_workers=8) + ``` + +### Profiling and Debugging + +```python +# Print detailed array information +print(z.info) + +# Output includes: +# - Type, shape, chunks, dtype +# - Compression codec and level +# - Storage size (compressed vs uncompressed) +# - Storage location + +# Check storage size +print(f"Compressed size: {z.nbytes_stored / 1e6:.2f} MB") +print(f"Uncompressed size: {z.nbytes / 1e6:.2f} MB") +print(f"Compression ratio: {z.nbytes / z.nbytes_stored:.2f}x") +``` + +## Common Patterns and Best Practices + +### Pattern: Time Series Data + +```python +# Store time series with time as first dimension +# This allows efficient appending of new time steps +z = zarr.open('timeseries.zarr', mode='a', + shape=(0, 720, 1440), # Start with 0 time steps + chunks=(1, 720, 1440), # One time step per chunk + dtype='f4') + +# Append new time steps +new_data = np.random.random((1, 720, 1440)) +z.append(new_data, axis=0) +``` + +### Pattern: Large Matrix Operations + +```python +import dask.array as da + +# Create large matrix in Zarr +z = zarr.open('matrix.zarr', mode='w', + shape=(100000, 100000), + chunks=(1000, 1000), + dtype='f8') + +# Use Dask for parallel computation +dask_z = da.from_zarr('matrix.zarr') +result = (dask_z @ dask_z.T).compute() # Parallel matrix multiply +``` + +### Pattern: Cloud-Native Workflow + +```python +import s3fs +import zarr + +# Write to S3 +s3 = s3fs.S3FileSystem() +store = s3fs.S3Map(root='s3://my-bucket/data.zarr', s3=s3) + +# Create array with appropriate chunking for cloud +z = zarr.open_array(store=store, mode='w', + shape=(10000, 10000), + chunks=(500, 500), # ~1MB chunks + dtype='f4') +z[:] = data + +# Consolidate metadata for faster reads +zarr.consolidate_metadata(store) + +# Read from S3 (anywhere, anytime) +store_read = 
s3fs.S3Map(root='s3://my-bucket/data.zarr', s3=s3) +z_read = zarr.open_consolidated(store_read) +subset = z_read[0:100, 0:100] +``` + +### Pattern: Format Conversion + +```python +# HDF5 to Zarr +import h5py +import zarr + +with h5py.File('data.h5', 'r') as h5: + dataset = h5['dataset_name'] + z = zarr.array(dataset[:], + chunks=(1000, 1000), + store='data.zarr') + +# NumPy to Zarr +import numpy as np +data = np.load('data.npy') +z = zarr.array(data, chunks='auto', store='data.zarr') + +# Zarr to NetCDF (via Xarray) +import xarray as xr +ds = xr.open_zarr('data.zarr') +ds.to_netcdf('data.nc') +``` + +## Common Issues and Solutions + +### Issue: Slow Performance + +**Diagnosis**: Check chunk size and alignment +```python +print(z.chunks) # Are chunks appropriate size? +print(z.info) # Check compression ratio +``` + +**Solutions**: +- Increase chunk size to 1-10 MB +- Align chunks with access pattern +- Try different compression codecs +- Use Dask for parallel operations + +### Issue: High Memory Usage + +**Cause**: Loading entire array or large chunks into memory + +**Solutions**: +```python +# Don't load entire array +# Bad: data = z[:] +# Good: Process in chunks +for i in range(0, z.shape[0], 1000): + chunk = z[i:i+1000, :] + process(chunk) + +# Or use Dask for automatic chunking +import dask.array as da +dask_z = da.from_zarr('data.zarr') +result = dask_z.mean().compute() # Processes in chunks +``` + +### Issue: Cloud Storage Latency + +**Solutions**: +```python +# 1. Consolidate metadata +zarr.consolidate_metadata(store) +z = zarr.open_consolidated(store) + +# 2. Use appropriate chunk sizes (5-100 MB for cloud) +chunks = (2000, 2000) # Larger chunks for cloud + +# 3. Enable sharding +shards = (10000, 10000) # Groups many chunks +``` + +### Issue: Concurrent Write Conflicts + +**Solution**: Use synchronizers or ensure non-overlapping writes +```python +from zarr import ProcessSynchronizer + +sync = ProcessSynchronizer('sync.sync') +z = zarr.open_array('data.zarr', mode='r+', synchronizer=sync) + +# Or design workflow so each process writes to separate chunks +``` + +## Additional Resources + +For detailed API documentation, advanced usage, and the latest updates: + +- **Official Documentation**: https://zarr.readthedocs.io/ +- **Zarr Specifications**: https://zarr-specs.readthedocs.io/ +- **GitHub Repository**: https://github.com/zarr-developers/zarr-python +- **Community Chat**: https://gitter.im/zarr-developers/community + +**Related Libraries**: +- **Xarray**: https://docs.xarray.dev/ (labeled arrays) +- **Dask**: https://docs.dask.org/ (parallel computing) +- **NumCodecs**: https://numcodecs.readthedocs.io/ (compression codecs) diff --git a/scientific-packages/zarr-python/references/api_reference.md b/scientific-packages/zarr-python/references/api_reference.md new file mode 100644 index 0000000..71f9957 --- /dev/null +++ b/scientific-packages/zarr-python/references/api_reference.md @@ -0,0 +1,515 @@ +# Zarr Python Quick Reference + +This reference provides a concise overview of commonly used Zarr functions, parameters, and patterns for quick lookup during development. + +## Array Creation Functions + +### `zarr.zeros()` / `zarr.ones()` / `zarr.empty()` +```python +zarr.zeros(shape, chunks=None, dtype='f8', store=None, compressor='default', + fill_value=0, order='C', filters=None) +``` +Create arrays filled with zeros, ones, or empty (uninitialized) values. 
+ +**Key parameters:** +- `shape`: Tuple defining array dimensions (e.g., `(1000, 1000)`) +- `chunks`: Tuple defining chunk dimensions (e.g., `(100, 100)`), or `None` for no chunking +- `dtype`: NumPy data type (e.g., `'f4'`, `'i8'`, `'bool'`) +- `store`: Storage location (string path, Store object, or `None` for memory) +- `compressor`: Compression codec or `None` for no compression + +### `zarr.create_array()` / `zarr.create()` +```python +zarr.create_array(store, shape, chunks, dtype='f8', compressor='default', + fill_value=0, order='C', filters=None, overwrite=False) +``` +Create a new array with explicit control over all parameters. + +### `zarr.array()` +```python +zarr.array(data, chunks=None, dtype=None, compressor='default', store=None) +``` +Create array from existing data (NumPy array, list, etc.). + +**Example:** +```python +import numpy as np +data = np.random.random((1000, 1000)) +z = zarr.array(data, chunks=(100, 100), store='data.zarr') +``` + +### `zarr.open_array()` / `zarr.open()` +```python +zarr.open_array(store, mode='a', shape=None, chunks=None, dtype=None, + compressor='default', fill_value=0) +``` +Open existing array or create new one. + +**Mode options:** +- `'r'`: Read-only +- `'r+'`: Read-write, file must exist +- `'a'`: Read-write, create if doesn't exist (default) +- `'w'`: Create new, overwrite if exists +- `'w-'`: Create new, fail if exists + +## Storage Classes + +### LocalStore (Default) +```python +from zarr.storage import LocalStore + +store = LocalStore('path/to/data.zarr') +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +``` + +### MemoryStore +```python +from zarr.storage import MemoryStore + +store = MemoryStore() # Data only in memory +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +``` + +### ZipStore +```python +from zarr.storage import ZipStore + +# Write +store = ZipStore('data.zip', mode='w') +z = zarr.open_array(store=store, mode='w', shape=(1000, 1000), chunks=(100, 100)) +z[:] = data +store.close() # MUST close + +# Read +store = ZipStore('data.zip', mode='r') +z = zarr.open_array(store=store) +data = z[:] +store.close() +``` + +### Cloud Storage (S3/GCS) +```python +# S3 +import s3fs +s3 = s3fs.S3FileSystem(anon=False) +store = s3fs.S3Map(root='bucket/path/data.zarr', s3=s3) + +# GCS +import gcsfs +gcs = gcsfs.GCSFileSystem(project='my-project') +store = gcsfs.GCSMap(root='bucket/path/data.zarr', gcs=gcs) +``` + +## Compression Codecs + +### Blosc Codec (Default) +```python +from zarr.codecs.blosc import BloscCodec + +codec = BloscCodec( + cname='zstd', # Compressor: 'blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd' + clevel=5, # Compression level: 0-9 + shuffle='shuffle' # Shuffle filter: 'noshuffle', 'shuffle', 'bitshuffle' +) + +z = zarr.create_array(store='data.zarr', shape=(1000, 1000), chunks=(100, 100), + dtype='f4', codecs=[codec]) +``` + +**Blosc compressor characteristics:** +- `'lz4'`: Fastest compression, lower ratio +- `'zstd'`: Balanced (default), good ratio and speed +- `'zlib'`: Good compatibility, moderate performance +- `'lz4hc'`: Better ratio than lz4, slower +- `'snappy'`: Fast, moderate ratio +- `'blosclz'`: Blosc's default + +### Other Codecs +```python +from zarr.codecs import GzipCodec, ZstdCodec, BytesCodec + +# Gzip compression (maximum ratio, slower) +GzipCodec(level=6) # Level 0-9 + +# Zstandard compression +ZstdCodec(level=3) # Level 1-22 + +# No compression +BytesCodec() +``` + +## Array Indexing and Selection + +### Basic Indexing 
(NumPy-style) +```python +z = zarr.zeros((1000, 1000), chunks=(100, 100)) + +# Read +row = z[0, :] # Single row +col = z[:, 0] # Single column +block = z[10:20, 50:60] # Slice +element = z[5, 10] # Single element + +# Write +z[0, :] = 42 +z[10:20, 50:60] = np.random.random((10, 10)) +``` + +### Advanced Indexing +```python +# Coordinate indexing (point selection) +z.vindex[[0, 5, 10], [2, 8, 15]] # Specific coordinates + +# Orthogonal indexing (outer product) +z.oindex[0:10, [5, 10, 15]] # Rows 0-9, columns 5, 10, 15 + +# Block/chunk indexing +z.blocks[0, 0] # First chunk +z.blocks[0:2, 0:2] # First four chunks +``` + +## Groups and Hierarchies + +### Creating Groups +```python +# Create root group +root = zarr.group(store='data.zarr') + +# Create nested groups +grp1 = root.create_group('group1') +grp2 = grp1.create_group('subgroup') + +# Create arrays in groups +arr = grp1.create_array(name='data', shape=(1000, 1000), + chunks=(100, 100), dtype='f4') + +# Access by path +arr2 = root['group1/data'] +``` + +### Group Methods +```python +root = zarr.group('data.zarr') + +# h5py-compatible methods +dataset = root.create_dataset('data', shape=(1000, 1000), chunks=(100, 100)) +subgrp = root.require_group('subgroup') # Create if doesn't exist + +# Visualize structure +print(root.tree()) + +# List contents +print(list(root.keys())) +print(list(root.groups())) +print(list(root.arrays())) +``` + +## Array Attributes and Metadata + +### Working with Attributes +```python +z = zarr.zeros((1000, 1000), chunks=(100, 100)) + +# Set attributes +z.attrs['units'] = 'meters' +z.attrs['description'] = 'Temperature data' +z.attrs['created'] = '2024-01-15' +z.attrs['version'] = 1.2 +z.attrs['tags'] = ['climate', 'temperature'] + +# Read attributes +print(z.attrs['units']) +print(dict(z.attrs)) # All attributes as dict + +# Update/delete +z.attrs['version'] = 2.0 +del z.attrs['tags'] +``` + +**Note:** Attributes must be JSON-serializable. 
+ +## Array Properties and Methods + +### Properties +```python +z = zarr.zeros((1000, 1000), chunks=(100, 100), dtype='f4') + +z.shape # (1000, 1000) +z.chunks # (100, 100) +z.dtype # dtype('float32') +z.size # 1000000 +z.nbytes # 4000000 (uncompressed size in bytes) +z.nbytes_stored # Actual compressed size on disk +z.nchunks # 100 (number of chunks) +z.cdata_shape # Shape in terms of chunks: (10, 10) +``` + +### Methods +```python +# Information +print(z.info) # Detailed information about array +print(z.info_items()) # Info as list of tuples + +# Resizing +z.resize(1500, 1500) # Change dimensions + +# Appending +z.append(new_data, axis=0) # Add data along axis + +# Copying +z2 = z.copy(store='new_location.zarr') +``` + +## Chunking Guidelines + +### Chunk Size Calculation +```python +# For float32 (4 bytes per element): +# 1 MB = 262,144 elements +# 10 MB = 2,621,440 elements + +# Examples for 1 MB chunks: +(512, 512) # For 2D: 512 × 512 × 4 = 1,048,576 bytes +(128, 128, 128) # For 3D: 128 × 128 × 128 × 4 = 8,388,608 bytes ≈ 8 MB +(64, 256, 256) # For 3D: 64 × 256 × 256 × 4 = 16,777,216 bytes ≈ 16 MB +``` + +### Chunking Strategies by Access Pattern + +**Time series (sequential access along first dimension):** +```python +chunks=(1, 720, 1440) # One time step per chunk +``` + +**Row-wise access:** +```python +chunks=(10, 10000) # Small rows, span columns +``` + +**Column-wise access:** +```python +chunks=(10000, 10) # Span rows, small columns +``` + +**Random access:** +```python +chunks=(500, 500) # Balanced square chunks +``` + +**3D volumetric data:** +```python +chunks=(64, 64, 64) # Cubic chunks for isotropic access +``` + +## Integration APIs + +### NumPy Integration +```python +import numpy as np + +z = zarr.zeros((1000, 1000), chunks=(100, 100)) + +# Use NumPy functions +result = np.sum(z, axis=0) +mean = np.mean(z) +std = np.std(z) + +# Convert to NumPy +arr = z[:] # Loads entire array into memory +``` + +### Dask Integration +```python +import dask.array as da + +# Load Zarr as Dask array +dask_array = da.from_zarr('data.zarr') + +# Compute operations in parallel +result = dask_array.mean(axis=0).compute() + +# Write Dask array to Zarr +large_array = da.random.random((100000, 100000), chunks=(1000, 1000)) +da.to_zarr(large_array, 'output.zarr') +``` + +### Xarray Integration +```python +import xarray as xr + +# Open Zarr as Xarray Dataset +ds = xr.open_zarr('data.zarr') + +# Write Xarray to Zarr +ds.to_zarr('output.zarr') + +# Create with coordinates +ds = xr.Dataset( + {'temperature': (['time', 'lat', 'lon'], data)}, + coords={ + 'time': pd.date_range('2024-01-01', periods=365), + 'lat': np.arange(-90, 91, 1), + 'lon': np.arange(-180, 180, 1) + } +) +ds.to_zarr('climate.zarr') +``` + +## Parallel Computing + +### Synchronizers +```python +from zarr import ThreadSynchronizer, ProcessSynchronizer + +# Multi-threaded writes +sync = ThreadSynchronizer() +z = zarr.open_array('data.zarr', mode='r+', synchronizer=sync) + +# Multi-process writes +sync = ProcessSynchronizer('sync.sync') +z = zarr.open_array('data.zarr', mode='r+', synchronizer=sync) +``` + +**Note:** Synchronization only needed for: +- Concurrent writes that may span chunk boundaries +- Not needed for reads (always safe) +- Not needed if each process writes to separate chunks + +## Metadata Consolidation + +```python +# Consolidate metadata (after creating all arrays/groups) +zarr.consolidate_metadata('data.zarr') + +# Open with consolidated metadata (faster, especially on cloud) +root = 
zarr.open_consolidated('data.zarr') +``` + +**Benefits:** +- Reduces I/O from N operations to 1 +- Critical for cloud storage (reduces latency) +- Speeds up hierarchy traversal + +**Cautions:** +- Can become stale if data updates +- Re-consolidate after modifications +- Not for frequently-updated datasets + +## Common Patterns + +### Time Series with Growing Data +```python +# Start with empty first dimension +z = zarr.open('timeseries.zarr', mode='a', + shape=(0, 720, 1440), + chunks=(1, 720, 1440), + dtype='f4') + +# Append new time steps +for new_timestep in data_stream: + z.append(new_timestep, axis=0) +``` + +### Processing Large Arrays in Chunks +```python +z = zarr.open('large_data.zarr', mode='r') + +# Process without loading entire array +for i in range(0, z.shape[0], 1000): + chunk = z[i:i+1000, :] + result = process(chunk) + save(result) +``` + +### Format Conversion Pipeline +```python +# HDF5 → Zarr +import h5py +with h5py.File('data.h5', 'r') as h5: + z = zarr.array(h5['dataset'][:], chunks=(1000, 1000), store='data.zarr') + +# Zarr → NumPy file +z = zarr.open('data.zarr', mode='r') +np.save('data.npy', z[:]) + +# Zarr → NetCDF (via Xarray) +ds = xr.open_zarr('data.zarr') +ds.to_netcdf('data.nc') +``` + +## Performance Optimization Quick Checklist + +1. **Chunk size**: 1-10 MB per chunk +2. **Chunk shape**: Align with access pattern +3. **Compression**: + - Fast: `BloscCodec(cname='lz4', clevel=1)` + - Balanced: `BloscCodec(cname='zstd', clevel=5)` + - Maximum: `GzipCodec(level=9)` +4. **Cloud storage**: + - Larger chunks (5-100 MB) + - Consolidate metadata + - Consider sharding +5. **Parallel I/O**: Use Dask for large operations +6. **Memory**: Process in chunks, don't load entire arrays + +## Debugging and Profiling + +```python +z = zarr.open('data.zarr', mode='r') + +# Detailed information +print(z.info) + +# Size statistics +print(f"Uncompressed: {z.nbytes / 1e6:.2f} MB") +print(f"Compressed: {z.nbytes_stored / 1e6:.2f} MB") +print(f"Ratio: {z.nbytes / z.nbytes_stored:.1f}x") + +# Chunk information +print(f"Chunks: {z.chunks}") +print(f"Number of chunks: {z.nchunks}") +print(f"Chunk grid: {z.cdata_shape}") +``` + +## Common Data Types + +```python +# Integers +'i1', 'i2', 'i4', 'i8' # Signed: 8, 16, 32, 64-bit +'u1', 'u2', 'u4', 'u8' # Unsigned: 8, 16, 32, 64-bit + +# Floats +'f2', 'f4', 'f8' # 16, 32, 64-bit (half, single, double precision) + +# Others +'bool' # Boolean +'c8', 'c16' # Complex: 64, 128-bit +'S10' # Fixed-length string (10 bytes) +'U10' # Unicode string (10 characters) +``` + +## Version Compatibility + +Zarr-Python version 3.x supports both: +- **Zarr v2 format**: Legacy format, widely compatible +- **Zarr v3 format**: New format with sharding, improved metadata + +Check format version: +```python +# Zarr automatically detects format version +z = zarr.open('data.zarr', mode='r') +# Format info available in metadata +``` + +## Error Handling + +```python +try: + z = zarr.open_array('data.zarr', mode='r') +except zarr.errors.PathNotFoundError: + print("Array does not exist") +except zarr.errors.ReadOnlyError: + print("Cannot write to read-only array") +except Exception as e: + print(f"Unexpected error: {e}") +```