mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-29 07:43:46 +08:00
Compare commits
103 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a89e01aba | ||
|
|
2621ee329d | ||
|
|
57bde764fe | ||
|
|
4fc6ac7727 | ||
|
|
95a3b74b3b | ||
|
|
1e00b1536e | ||
|
|
312f18ae60 | ||
|
|
4fb9c053f7 | ||
|
|
4515ca6268 | ||
|
|
09d9aa3bb2 | ||
|
|
cf1d4aac5d | ||
|
|
d4ca5984ca | ||
|
|
a643493a32 | ||
|
|
c85faf039a | ||
|
|
ae60fcf620 | ||
|
|
78331e1b37 | ||
|
|
ab4aff4670 | ||
|
|
6560f1d779 | ||
|
|
49567890a6 | ||
|
|
ec10daba7e | ||
|
|
280a53f95e | ||
|
|
9347d99355 | ||
|
|
90de96a99b | ||
|
|
8d82c83a1a | ||
|
|
7e8deebf96 | ||
|
|
7763491813 | ||
|
|
16e47a1755 | ||
|
|
a077cee836 | ||
|
|
7caef7df68 | ||
|
|
bf4267161f | ||
|
|
6ac2a15e39 | ||
|
|
41f272c2bd | ||
|
|
02574ba19d | ||
|
|
ea638c5618 | ||
|
|
8e7a791871 | ||
|
|
3bb0ee77be | ||
|
|
e5fc882746 | ||
|
|
65b39d45d6 | ||
|
|
c078c98ad2 | ||
|
|
2e80732340 | ||
|
|
2fc3e6a88e | ||
|
|
d94f21c51f | ||
|
|
19c0b390ee | ||
|
|
54cab8e4b5 | ||
|
|
ad2dfc3446 | ||
|
|
63f257d81e | ||
|
|
8be6c6c307 | ||
|
|
cc99fdb57d | ||
|
|
50fdaf1b04 | ||
|
|
82663ee1de | ||
|
|
2873d0e39d | ||
|
|
0e4939147f | ||
|
|
5b7081cbff | ||
|
|
ffad3d81b0 | ||
|
|
1225ddecf1 | ||
|
|
4ad4f9970f | ||
|
|
63a4293f1a | ||
|
|
f124e28509 | ||
|
|
c56fa43747 | ||
|
|
86d8878eeb | ||
|
|
d57129ca3f | ||
|
|
537edff2a1 | ||
|
|
6ddea4786e | ||
|
|
094d5aa9f1 | ||
|
|
862445f531 | ||
|
|
b8c4d2bae1 | ||
|
|
27d6ee387f | ||
|
|
f32b3f8b42 | ||
|
|
ed2d1c4aeb | ||
|
|
7e3fae3ad1 | ||
|
|
97c03a11e5 | ||
|
|
b6bb261a71 | ||
|
|
c0de4269c0 | ||
|
|
aa5018612f | ||
|
|
f12f9c2904 | ||
|
|
55af02f2d4 | ||
|
|
346010a1bf | ||
|
|
5b37b62147 | ||
|
|
82901e9fbe | ||
|
|
752c32949b | ||
|
|
981823a71a | ||
|
|
26d4fde324 | ||
|
|
1f03feda5c | ||
|
|
44285238c4 | ||
|
|
af246ac562 | ||
|
|
04d528c4bc | ||
|
|
b83942845c | ||
|
|
6cefe6f4cc | ||
|
|
a5c2ed9bb6 | ||
|
|
295370a730 | ||
|
|
4fde8cb6de | ||
|
|
564bd39835 | ||
|
|
7240da6e6c | ||
|
|
1871693348 | ||
|
|
0e03bbcf38 | ||
|
|
ca6fd369d7 | ||
|
|
808fa11206 | ||
|
|
d923063309 | ||
|
|
43dc67a316 | ||
|
|
ee73143b53 | ||
|
|
b1df506eba | ||
|
|
bb1c9f4573 | ||
|
|
6cb25aea28 |
@@ -1,147 +1,160 @@
|
||||
|
||||
{
|
||||
"name": "claude-scientific-skills",
|
||||
"owner": {
|
||||
"name": "Timothy Kassis",
|
||||
"email": "timothy.kassis@k-dense.ai"
|
||||
"name": "K-Dense Inc.",
|
||||
"email": "contact@k-dense.ai"
|
||||
},
|
||||
"metadata": {
|
||||
"description": "Claude scientific skills from K-Dense Inc",
|
||||
"version": "1.54.0"
|
||||
"version": "2.11.2"
|
||||
},
|
||||
"plugins": [
|
||||
{
|
||||
"name": "scientific-packages",
|
||||
"description": "Collection of python scientific packages",
|
||||
"name": "scientific-skills",
|
||||
"description": "Collection of scientific skills",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./scientific-packages/anndata",
|
||||
"./scientific-packages/arboreto",
|
||||
"./scientific-packages/astropy",
|
||||
"./scientific-packages/biomni",
|
||||
"./scientific-packages/biopython",
|
||||
"./scientific-packages/bioservices",
|
||||
"./scientific-packages/cellxgene-census",
|
||||
"./scientific-packages/cobrapy",
|
||||
"./scientific-packages/dask",
|
||||
"./scientific-packages/datamol",
|
||||
"./scientific-packages/deepchem",
|
||||
"./scientific-packages/deeptools",
|
||||
"./scientific-packages/diffdock",
|
||||
"./scientific-packages/esm",
|
||||
"./scientific-packages/etetoolkit",
|
||||
"./scientific-packages/flowio",
|
||||
"./scientific-packages/gget",
|
||||
"./scientific-packages/matchms",
|
||||
"./scientific-packages/matplotlib",
|
||||
"./scientific-packages/medchem",
|
||||
"./scientific-packages/molfeat",
|
||||
"./scientific-packages/paper-2-web",
|
||||
"./scientific-packages/polars",
|
||||
"./scientific-packages/pydeseq2",
|
||||
"./scientific-packages/pymatgen",
|
||||
"./scientific-packages/pymc",
|
||||
"./scientific-packages/pymoo",
|
||||
"./scientific-packages/pyopenms",
|
||||
"./scientific-packages/pysam",
|
||||
"./scientific-packages/pytdc",
|
||||
"./scientific-packages/pytorch-lightning",
|
||||
"./scientific-packages/rdkit",
|
||||
"./scientific-packages/reportlab",
|
||||
"./scientific-packages/scanpy",
|
||||
"./scientific-packages/scvi-tools",
|
||||
"./scientific-packages/scikit-bio",
|
||||
"./scientific-packages/scikit-learn",
|
||||
"./scientific-packages/seaborn",
|
||||
"./scientific-packages/shap",
|
||||
"./scientific-packages/statsmodels",
|
||||
"./scientific-packages/torch_geometric",
|
||||
"./scientific-packages/torchdrug",
|
||||
"./scientific-packages/tooluniverse",
|
||||
"./scientific-packages/transformers",
|
||||
"./scientific-packages/umap-learn",
|
||||
"./scientific-packages/zarr-python"
|
||||
"./scientific-skills/adaptyv",
|
||||
"./scientific-skills/aeon",
|
||||
"./scientific-skills/anndata",
|
||||
"./scientific-skills/arboreto",
|
||||
"./scientific-skills/astropy",
|
||||
"./scientific-skills/biomni",
|
||||
"./scientific-skills/biopython",
|
||||
"./scientific-skills/bioservices",
|
||||
"./scientific-skills/cellxgene-census",
|
||||
"./scientific-skills/cirq",
|
||||
"./scientific-skills/cobrapy",
|
||||
"./scientific-skills/dask",
|
||||
"./scientific-skills/datacommons-client",
|
||||
"./scientific-skills/datamol",
|
||||
"./scientific-skills/deepchem",
|
||||
"./scientific-skills/deeptools",
|
||||
"./scientific-skills/denario",
|
||||
"./scientific-skills/diffdock",
|
||||
"./scientific-skills/esm",
|
||||
"./scientific-skills/etetoolkit",
|
||||
"./scientific-skills/flowio",
|
||||
"./scientific-skills/fluidsim",
|
||||
"./scientific-skills/geniml",
|
||||
"./scientific-skills/geopandas",
|
||||
"./scientific-skills/gget",
|
||||
"./scientific-skills/gtars",
|
||||
"./scientific-skills/histolab",
|
||||
"./scientific-skills/hypogenic",
|
||||
"./scientific-skills/lamindb",
|
||||
"./scientific-skills/markitdown",
|
||||
"./scientific-skills/matchms",
|
||||
"./scientific-skills/matplotlib",
|
||||
"./scientific-skills/medchem",
|
||||
"./scientific-skills/modal",
|
||||
"./scientific-skills/molfeat",
|
||||
"./scientific-skills/neurokit2",
|
||||
"./scientific-skills/networkx",
|
||||
"./scientific-skills/paper-2-web",
|
||||
"./scientific-skills/pathml",
|
||||
"./scientific-skills/pennylane",
|
||||
"./scientific-skills/perplexity-search",
|
||||
"./scientific-skills/plotly",
|
||||
"./scientific-skills/polars",
|
||||
"./scientific-skills/pydeseq2",
|
||||
"./scientific-skills/pydicom",
|
||||
"./scientific-skills/pyhealth",
|
||||
"./scientific-skills/pylabrobot",
|
||||
"./scientific-skills/pymatgen",
|
||||
"./scientific-skills/pymc",
|
||||
"./scientific-skills/pymoo",
|
||||
"./scientific-skills/pyopenms",
|
||||
"./scientific-skills/pufferlib",
|
||||
"./scientific-skills/pysam",
|
||||
"./scientific-skills/pytdc",
|
||||
"./scientific-skills/pytorch-lightning",
|
||||
"./scientific-skills/qiskit",
|
||||
"./scientific-skills/qutip",
|
||||
"./scientific-skills/rdkit",
|
||||
"./scientific-skills/scanpy",
|
||||
"./scientific-skills/scikit-bio",
|
||||
"./scientific-skills/scikit-learn",
|
||||
"./scientific-skills/scikit-survival",
|
||||
"./scientific-skills/scvi-tools",
|
||||
"./scientific-skills/seaborn",
|
||||
"./scientific-skills/shap",
|
||||
"./scientific-skills/simpy",
|
||||
"./scientific-skills/stable-baselines3",
|
||||
"./scientific-skills/statsmodels",
|
||||
"./scientific-skills/sympy",
|
||||
"./scientific-skills/torch_geometric",
|
||||
"./scientific-skills/torchdrug",
|
||||
"./scientific-skills/transformers",
|
||||
"./scientific-skills/umap-learn",
|
||||
"./scientific-skills/vaex",
|
||||
"./scientific-skills/zarr-python",
|
||||
"./scientific-skills/alphafold-database",
|
||||
"./scientific-skills/biorxiv-database",
|
||||
"./scientific-skills/brenda-database",
|
||||
"./scientific-skills/chembl-database",
|
||||
"./scientific-skills/clinicaltrials-database",
|
||||
"./scientific-skills/clinpgx-database",
|
||||
"./scientific-skills/clinvar-database",
|
||||
"./scientific-skills/cosmic-database",
|
||||
"./scientific-skills/drugbank-database",
|
||||
"./scientific-skills/ena-database",
|
||||
"./scientific-skills/ensembl-database",
|
||||
"./scientific-skills/fda-database",
|
||||
"./scientific-skills/gene-database",
|
||||
"./scientific-skills/geo-database",
|
||||
"./scientific-skills/gwas-database",
|
||||
"./scientific-skills/hmdb-database",
|
||||
"./scientific-skills/kegg-database",
|
||||
"./scientific-skills/metabolomics-workbench-database",
|
||||
"./scientific-skills/openalex-database",
|
||||
"./scientific-skills/opentargets-database",
|
||||
"./scientific-skills/pdb-database",
|
||||
"./scientific-skills/pubchem-database",
|
||||
"./scientific-skills/pubmed-database",
|
||||
"./scientific-skills/reactome-database",
|
||||
"./scientific-skills/string-database",
|
||||
"./scientific-skills/uniprot-database",
|
||||
"./scientific-skills/uspto-database",
|
||||
"./scientific-skills/zinc-database",
|
||||
"./scientific-skills/exploratory-data-analysis",
|
||||
"./scientific-skills/hypothesis-generation",
|
||||
"./scientific-skills/literature-review",
|
||||
"./scientific-skills/peer-review",
|
||||
"./scientific-skills/scholar-evaluation",
|
||||
"./scientific-skills/scientific-brainstorming",
|
||||
"./scientific-skills/scientific-critical-thinking",
|
||||
"./scientific-skills/scientific-writing",
|
||||
"./scientific-skills/statistical-analysis",
|
||||
"./scientific-skills/scientific-visualization",
|
||||
"./scientific-skills/citation-management",
|
||||
"./scientific-skills/clinical-decision-support",
|
||||
"./scientific-skills/clinical-reports",
|
||||
"./scientific-skills/generate-image",
|
||||
"./scientific-skills/latex-posters",
|
||||
"./scientific-skills/market-research-reports",
|
||||
"./scientific-skills/pptx-posters",
|
||||
"./scientific-skills/research-grants",
|
||||
"./scientific-skills/research-lookup",
|
||||
"./scientific-skills/scientific-schematics",
|
||||
"./scientific-skills/scientific-slides",
|
||||
"./scientific-skills/treatment-plans",
|
||||
"./scientific-skills/venue-templates",
|
||||
"./scientific-skills/document-skills/docx",
|
||||
"./scientific-skills/document-skills/pdf",
|
||||
"./scientific-skills/document-skills/pptx",
|
||||
"./scientific-skills/document-skills/xlsx",
|
||||
"./scientific-skills/benchling-integration",
|
||||
"./scientific-skills/dnanexus-integration",
|
||||
"./scientific-skills/labarchive-integration",
|
||||
"./scientific-skills/latchbio-integration",
|
||||
"./scientific-skills/omero-integration",
|
||||
"./scientific-skills/opentrons-integration",
|
||||
"./scientific-skills/protocolsio-integration",
|
||||
"./scientific-skills/get-available-resources",
|
||||
"./scientific-skills/iso-13485-certification"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "scientific-databases",
|
||||
"description": "Collection of scientific databases",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./scientific-databases/alphafold-database",
|
||||
"./scientific-databases/biorxiv-database",
|
||||
"./scientific-databases/chembl-database",
|
||||
"./scientific-databases/clinpgx-database",
|
||||
"./scientific-databases/clinvar-database",
|
||||
"./scientific-databases/clinicaltrials-database",
|
||||
"./scientific-databases/cosmic-database",
|
||||
"./scientific-databases/ena-database",
|
||||
"./scientific-databases/ensembl-database",
|
||||
"./scientific-databases/fda-database",
|
||||
"./scientific-databases/gene-database",
|
||||
"./scientific-databases/geo-database",
|
||||
"./scientific-databases/gwas-database",
|
||||
"./scientific-databases/hmdb-database",
|
||||
"./scientific-databases/kegg-database",
|
||||
"./scientific-databases/metabolomics-workbench-database",
|
||||
"./scientific-databases/opentargets-database",
|
||||
"./scientific-databases/pdb-database",
|
||||
"./scientific-databases/pubchem-database",
|
||||
"./scientific-databases/pubmed-database",
|
||||
"./scientific-databases/reactome-database",
|
||||
"./scientific-databases/string-database",
|
||||
"./scientific-databases/uniprot-database",
|
||||
"./scientific-databases/uspto-database",
|
||||
"./scientific-databases/zinc-database"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "scientific-thinking",
|
||||
"description": "Collection of scientific thinking methodologies",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./scientific-thinking/exploratory-data-analysis",
|
||||
"./scientific-thinking/hypothesis-generation",
|
||||
"./scientific-thinking/peer-review",
|
||||
"./scientific-thinking/scientific-brainstorming",
|
||||
"./scientific-thinking/scientific-critical-thinking",
|
||||
"./scientific-thinking/scientific-writing",
|
||||
"./scientific-thinking/statistical-analysis",
|
||||
"./scientific-thinking/scientific-visualization",
|
||||
"./scientific-thinking/document-skills/docx",
|
||||
"./scientific-thinking/document-skills/pdf",
|
||||
"./scientific-thinking/document-skills/pptx",
|
||||
"./scientific-thinking/document-skills/xlsx"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "scientific-integrations",
|
||||
"description": "Collection of scientific platform integrations",
|
||||
"source": "./",
|
||||
"strict": false,
|
||||
"skills": [
|
||||
"./scientific-integrations/benchling-integration",
|
||||
"./scientific-integrations/dnanexus-integration",
|
||||
"./scientific-integrations/labarchive-integration",
|
||||
"./scientific-integrations/latchbio-integration",
|
||||
"./scientific-integrations/omero-integration",
|
||||
"./scientific-integrations/opentrons-integration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "scientific-context-initialization",
|
||||
"description": "Always Auto-invoked skill that creates/updates workspace AGENT.md to instruct the agent to always search for existing skills before attempting any scientific task",
|
||||
"source": "./scientific-helpers/scientific-context-initialization",
|
||||
"strict": false
|
||||
},
|
||||
{
|
||||
"name": "get-available-resources",
|
||||
"description": "Detects and reports available system resources (CPU cores, GPUs, memory, disk space) to inform computational approach decisions",
|
||||
"source": "./scientific-helpers/get-available-resources",
|
||||
"strict": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
69
LICENSE.md
69
LICENSE.md
@@ -1,56 +1,21 @@
|
||||
# PolyForm Noncommercial License 1.0.0
|
||||
<https://polyformproject.org/licenses/noncommercial/1.0.0>
|
||||
MIT License
|
||||
|
||||
> **Required Notice:** Copyright © K-Dense Inc. (https://k-dense.ai)
|
||||
Copyright (c) 2025 K-Dense Inc.
|
||||
|
||||
## Acceptance
|
||||
In order to get any license under these terms, you must agree to them as both strict obligations and conditions to all your licenses.
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
## Copyright License
|
||||
The licensor (**K-Dense Inc.**) grants you a copyright license for the software to do everything you might do with the software that would otherwise infringe the licensor's copyright in it for any permitted purpose. However, you may only distribute the software according to [Distribution License](#distribution-license) and make changes or new works based on the software according to [Changes and New Works License](#changes-and-new-works-license).
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
## Distribution License
|
||||
The licensor (**K-Dense Inc.**) grants you an additional copyright license to distribute copies of the software. Your license to distribute covers distributing the software with changes and new works permitted by [Changes and New Works License](#changes-and-new-works-license).
|
||||
|
||||
## Notices
|
||||
You must ensure that anyone who gets a copy of any part of the software from you also gets a copy of these terms or the URL for them above, as well as copies of any plain-text lines beginning with `Required Notice:` that the licensor provided with the software. For example:
|
||||
|
||||
> Required Notice: Copyright © K-Dense Inc. (https://k-dense.ai)
|
||||
|
||||
## Changes and New Works License
|
||||
The licensor (**K-Dense Inc.**) grants you an additional copyright license to make changes and new works based on the software for any permitted purpose.
|
||||
|
||||
## Patent License
|
||||
The licensor (**K-Dense Inc.**) grants you a patent license for the software that covers patent claims the licensor can license, or becomes able to license, that you would infringe by using the software.
|
||||
|
||||
## Noncommercial Purposes
|
||||
Any noncommercial purpose is a permitted purpose.
|
||||
|
||||
## Personal Uses
|
||||
Personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, amateur pursuits, or religious observance, without any anticipated commercial application, is use for a permitted purpose.
|
||||
|
||||
## Noncommercial Organizations
|
||||
Use by any charitable organization, educational institution, public research organization, public safety or health organization, environmental protection organization, or government institution is use for a permitted purpose regardless of the source of funding or obligations resulting from the funding.
|
||||
|
||||
## Fair Use
|
||||
You may have "fair use" rights for the software under the law. These terms do not limit them.
|
||||
|
||||
## No Other Rights
|
||||
These terms do not allow you to sublicense or transfer any of your licenses to anyone else, or prevent the licensor (**K-Dense Inc.**) from granting licenses to anyone else. These terms do not imply any other licenses.
|
||||
|
||||
## Patent Defense
|
||||
If you make any written claim that the software infringes or contributes to infringement of any patent, your patent license for the software granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company.
|
||||
|
||||
## Violations
|
||||
The first time you are notified in writing that you have violated any of these terms, or done anything with the software not covered by your licenses, your licenses can nonetheless continue if you come into full compliance with these terms and take practical steps to correct past violations within 32 days of receiving notice. Otherwise, all your licenses end immediately.
|
||||
|
||||
## No Liability
|
||||
***As far as the law allows, the software comes as is, without any warranty or condition, and K-Dense Inc. will not be liable to you for any damages arising out of these terms or the use or nature of the software, under any kind of legal claim.***
|
||||
|
||||
## Definitions
|
||||
The **licensor** is **K-Dense Inc.**, the individual or entity offering these terms, and the **software** is the software K-Dense Inc. makes available under these terms.
|
||||
|
||||
**You** refers to the individual or entity agreeing to these terms.
|
||||
**Your company** is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. **Control** means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect.
|
||||
**Your licenses** are all the licenses granted to you for the software under these terms.
|
||||
**Use** means anything you do with the software requiring one of your licenses.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
712
README.md
712
README.md
@@ -1,33 +1,63 @@
|
||||
# Claude Scientific Skills
|
||||
|
||||
[](LICENSE.md)
|
||||
[](https://github.com/K-Dense-AI/claude-scientific-skills)
|
||||
[](#what-s-included)
|
||||
[](#what-s-included)
|
||||
[](LICENSE.md)
|
||||
[](#whats-included)
|
||||
|
||||
A comprehensive collection of ready-to-use scientific skills for Claude, curated by the K-Dense team.
|
||||
> 💼 For substantially more advanced capabilities, compute infrastructure, and enterprise-ready offerings, check out [k-dense.ai](https://k-dense.ai/).
|
||||
|
||||
These skills enable Claude to work with specialized scientific libraries and databases across multiple scientific domains:
|
||||
- 🧬 Bioinformatics & Genomics
|
||||
- 🧪 Cheminformatics & Drug Discovery
|
||||
- 🔬 Proteomics & Mass Spectrometry
|
||||
- 🤖 Machine Learning & AI
|
||||
- 🔮 Materials Science & Chemistry
|
||||
- 📊 Data Analysis & Visualization
|
||||
A comprehensive collection of **138 ready-to-use scientific skills** for Claude, created by the K-Dense team. Transform Claude into your AI research assistant capable of executing complex multi-step scientific workflows across biology, chemistry, medicine, and beyond.
|
||||
|
||||
These skills enable Claude to seamlessly work with specialized scientific libraries, databases, and tools across multiple scientific domains:
|
||||
- 🧬 Bioinformatics & Genomics - Sequence analysis, single-cell RNA-seq, gene regulatory networks, variant annotation, phylogenetic analysis
|
||||
- 🧪 Cheminformatics & Drug Discovery - Molecular property prediction, virtual screening, ADMET analysis, molecular docking, lead optimization
|
||||
- 🔬 Proteomics & Mass Spectrometry - LC-MS/MS processing, peptide identification, spectral matching, protein quantification
|
||||
- 🏥 Clinical Research & Precision Medicine - Clinical trials, pharmacogenomics, variant interpretation, drug safety, clinical decision support, treatment planning
|
||||
- 🧠 Healthcare AI & Clinical ML - EHR analysis, physiological signal processing, medical imaging, clinical prediction models
|
||||
- 🖼️ Medical Imaging & Digital Pathology - DICOM processing, whole slide image analysis, computational pathology, radiology workflows
|
||||
- 🤖 Machine Learning & AI - Deep learning, reinforcement learning, time series analysis, model interpretability, Bayesian methods
|
||||
- 🔮 Materials Science & Chemistry - Crystal structure analysis, phase diagrams, metabolic modeling, computational chemistry
|
||||
- 🌌 Physics & Astronomy - Astronomical data analysis, coordinate transformations, cosmological calculations, symbolic mathematics, physics computations
|
||||
- ⚙️ Engineering & Simulation - Discrete-event simulation, multi-objective optimization, metabolic engineering, systems modeling, process optimization
|
||||
- 📊 Data Analysis & Visualization - Statistical analysis, network analysis, time series, publication-quality figures, large-scale data processing, EDA
|
||||
- 🧪 Laboratory Automation - Liquid handling protocols, lab equipment control, workflow automation, LIMS integration
|
||||
- 📚 Scientific Communication - Literature review, peer review, scientific writing, document processing, posters, slides, schematics, citation management
|
||||
- 🔬 Multi-omics & Systems Biology - Multi-modal data integration, pathway analysis, network biology, systems-level insights
|
||||
- 🧬 Protein Engineering & Design - Protein language models, structure prediction, sequence design, function annotation
|
||||
- 🎓 Research Methodology - Hypothesis generation, scientific brainstorming, critical thinking, grant writing, scholar evaluation
|
||||
|
||||
**Transform Claude Code into an 'AI Scientist' on your desktop!**
|
||||
|
||||
> 💼 For substantially more advanced capabilities, compute infrastructure, and enterprise-ready offerings, check out [k-dense.ai](https://k-dense.ai/).
|
||||
> ⭐ **If you find this repository useful**, please consider giving it a star! It helps others discover these tools and encourages us to continue maintaining and expanding this collection.
|
||||
|
||||
---
|
||||
|
||||
## 📦 What's Included
|
||||
|
||||
This repository provides **138 scientific skills** organized into the following categories:
|
||||
|
||||
- **28+ Scientific Databases** - Direct API access to OpenAlex, PubMed, bioRxiv, ChEMBL, UniProt, COSMIC, ClinicalTrials.gov, and more
|
||||
- **55+ Python Packages** - RDKit, Scanpy, PyTorch Lightning, scikit-learn, BioPython, BioServices, PennyLane, Qiskit, and others
|
||||
- **15+ Scientific Integrations** - Benchling, DNAnexus, LatchBio, OMERO, Protocols.io, and more
|
||||
- **30+ Analysis & Communication Tools** - Literature review, scientific writing, peer review, document processing, posters, slides, schematics, and more
|
||||
- **10+ Research & Clinical Tools** - Hypothesis generation, grant writing, clinical decision support, treatment plans, regulatory compliance
|
||||
|
||||
Each skill includes:
|
||||
- ✅ Comprehensive documentation (`SKILL.md`)
|
||||
- ✅ Practical code examples
|
||||
- ✅ Use cases and best practices
|
||||
- ✅ Integration guides
|
||||
- ✅ Reference materials
|
||||
|
||||
---
|
||||
|
||||
## 📋 Table of Contents
|
||||
|
||||
- [What's Included](#what-s-included)
|
||||
- [What's Included](#whats-included)
|
||||
- [Why Use This?](#why-use-this)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Claude Code](#claude-code)
|
||||
- [Any MCP Client](#any-mcp-client-including-chatgpt-cursor-google-adk-openai-agent-sdk-etc)
|
||||
- [Claude Code](#claude-code-recommended)
|
||||
- [Cursor IDE](#cursor-ide)
|
||||
- [Any MCP Client](#any-mcp-client)
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Quick Examples](#quick-examples)
|
||||
- [Use Cases](#use-cases)
|
||||
@@ -36,319 +66,398 @@ These skills enable Claude to work with specialized scientific libraries and dat
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [FAQ](#faq)
|
||||
- [Support](#support)
|
||||
- [Join Our Community](#join-our-community)
|
||||
- [Citation](#citation)
|
||||
- [License](#license)
|
||||
|
||||
---
|
||||
|
||||
## 📦 What's Included
|
||||
|
||||
| Category | Count | Description |
|
||||
|----------|-------|-------------|
|
||||
| 📊 **Scientific Databases** | 25 | PubMed, PubChem, UniProt, ChEMBL, COSMIC, AlphaFold DB, bioRxiv, and more |
|
||||
| 🔬 **Scientific Packages** | 46 | BioPython, RDKit, PyTorch, Scanpy, scvi-tools, ESM, and specialized tools |
|
||||
| 🔌 **Scientific Integrations** | 6 | Benchling, DNAnexus, Opentrons, LabArchives, LatchBio, OMERO |
|
||||
| 🛠️ **Scientific Helpers** | 2 | Context initialization and resource detection utilities |
|
||||
| 📚 **Documented Workflows** | 122 | Ready-to-use examples and reference materials |
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Why Use This?
|
||||
|
||||
✅ **Save Time** - Skip days of API documentation research and integration work
|
||||
✅ **Best Practices** - Curated workflows following scientific computing standards
|
||||
✅ **Production Ready** - Tested and validated code examples
|
||||
✅ **Regular Updates** - Maintained and expanded by K-Dense team
|
||||
✅ **Comprehensive** - Coverage across major scientific domains
|
||||
✅ **Enterprise Support** - Commercial offerings available for advanced needs
|
||||
### ⚡ **Accelerate Your Research**
|
||||
- **Save Days of Work** - Skip API documentation research and integration setup
|
||||
- **Production-Ready Code** - Tested, validated examples following scientific best practices
|
||||
- **Multi-Step Workflows** - Execute complex pipelines with a single prompt
|
||||
|
||||
### 🎯 **Comprehensive Coverage**
|
||||
- **138 Skills** - Extensive coverage across all major scientific domains
|
||||
- **28+ Databases** - Direct access to OpenAlex, PubMed, bioRxiv, ChEMBL, UniProt, COSMIC, and more
|
||||
- **55+ Python Packages** - RDKit, Scanpy, PyTorch Lightning, scikit-learn, BioServices, PennyLane, Qiskit, and others
|
||||
|
||||
### 🔧 **Easy Integration**
|
||||
- **One-Click Setup** - Install via Claude Code or MCP server
|
||||
- **Automatic Discovery** - Claude automatically finds and uses relevant skills
|
||||
- **Well Documented** - Each skill includes examples, use cases, and best practices
|
||||
|
||||
### 🌟 **Maintained & Supported**
|
||||
- **Regular Updates** - Continuously maintained and expanded by K-Dense team
|
||||
- **Community Driven** - Open source with active community contributions
|
||||
- **Enterprise Ready** - Commercial support available for advanced needs
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Getting Started
|
||||
|
||||
### Claude Code
|
||||
Register this repository as a Claude Code Plugin marketplace by running:
|
||||
Choose your preferred platform to get started:
|
||||
|
||||
### 🖥️ Claude Code (Recommended)
|
||||
|
||||
> 📚 **New to Claude Code?** Check out the [Claude Code Quickstart Guide](https://docs.claude.com/en/docs/claude-code/quickstart) to get started.
|
||||
|
||||
**Step 1: Install Claude Code**
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
curl -fsSL https://claude.ai/install.sh | bash
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```powershell
|
||||
irm https://claude.ai/install.ps1 | iex
|
||||
```
|
||||
|
||||
**Step 2: Register the Marketplace**
|
||||
|
||||
```bash
|
||||
/plugin marketplace add K-Dense-AI/claude-scientific-skills
|
||||
```
|
||||
|
||||
Then, to install a specific set of skills:
|
||||
**Step 3: Install Skills**
|
||||
|
||||
1. Select **Browse and install plugins**
|
||||
2. Select **claude-scientific-skills**
|
||||
3. Choose from:
|
||||
- `scientific-databases` - Access to 25 scientific databases
|
||||
- `scientific-packages` - 46 specialized Python packages
|
||||
- `scientific-thinking` - Analysis tools and document processing
|
||||
- `scientific-integrations` - Lab automation and platform integrations
|
||||
- `scientific-context-initialization` - Ensures Claude searches for and uses existing skills
|
||||
4. Select **Install now**
|
||||
1. Open Claude Code
|
||||
2. Select **Browse and install plugins**
|
||||
3. Choose **claude-scientific-skills**
|
||||
4. Select **scientific-skills**
|
||||
5. Click **Install now**
|
||||
|
||||
After installation, simply mention the skill or describe your task - Claude Code will automatically use the appropriate skills!
|
||||
**That's it!** Claude will automatically use the appropriate skills when you describe your scientific tasks. Make sure to keep the skill up to date!
|
||||
|
||||
> 💡 **Tip**: If you find that Claude isn't utilizing the installed skills as much as you'd like, install the `scientific-context-initialization` skill. It automatically creates/updates an `AGENT.md` file in your workspace that instructs Claude to always search for and use existing skills before attempting any scientific task. This ensures Claude leverages documented patterns, authentication methods, working examples, and best practices from the repository.
|
||||
---
|
||||
|
||||
### Any MCP Client (including ChatGPT, Cursor, Google ADK, OpenAI Agent SDK, etc.)
|
||||
Use our newly released MCP server that allows you to use any Claude Skill in any client!
|
||||
### ⌨️ Cursor IDE
|
||||
|
||||
🔗 **[claude-skills-mcp](https://github.com/K-Dense-AI/claude-skills-mcp)**
|
||||
One-click installation via our hosted MCP server:
|
||||
|
||||
<a href="https://cursor.com/en-US/install-mcp?name=claude-scientific-skills&config=eyJ1cmwiOiJodHRwczovL21jcC5rLWRlbnNlLmFpL2NsYXVkZS1zY2llbnRpZmljLXNraWxscy9tY3AifQ%3D%3D">
|
||||
<picture>
|
||||
<source srcset="https://cursor.com/deeplink/mcp-install-light.svg" media="(prefers-color-scheme: dark)">
|
||||
<source srcset="https://cursor.com/deeplink/mcp-install-dark.svg" media="(prefers-color-scheme: light)">
|
||||
<img src="https://cursor.com/deeplink/mcp-install-dark.svg" alt="Install MCP Server" style="height:2.7em;"/>
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
### 🔌 Any MCP Client
|
||||
|
||||
Access all skills via our MCP server in any MCP-compatible client (ChatGPT, Google ADK, OpenAI Agent SDK, etc.):
|
||||
|
||||
**Option 1: Hosted MCP Server** (Easiest)
|
||||
```
|
||||
https://mcp.k-dense.ai/claude-scientific-skills/mcp
|
||||
```
|
||||
|
||||
**Option 2: Self-Hosted** (More Control)
|
||||
🔗 **[claude-skills-mcp](https://github.com/K-Dense-AI/claude-skills-mcp)** - Deploy your own MCP server
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ Prerequisites
|
||||
|
||||
- **Python**: 3.8+ (3.10+ recommended for best compatibility)
|
||||
- **Claude Code**: Latest version or any MCP-compatible client
|
||||
- **Python**: 3.9+ (3.12+ recommended for best compatibility)
|
||||
- **uv**: Python package manager (required for installing skill dependencies)
|
||||
- **Client**: Claude Code, Cursor, or any MCP-compatible client
|
||||
- **System**: macOS, Linux, or Windows with WSL2
|
||||
- **Dependencies**: Automatically handled by individual skills (check `SKILL.md` files for specific requirements)
|
||||
|
||||
### Installing uv
|
||||
|
||||
The skills use `uv` as the package manager for installing Python dependencies. Install it using the instructions for your operating system:
|
||||
|
||||
**macOS and Linux:**
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
**Windows:**
|
||||
```powershell
|
||||
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||||
```
|
||||
|
||||
**Alternative (via pip):**
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
After installation, verify it works by running:
|
||||
```bash
|
||||
uv --version
|
||||
```
|
||||
|
||||
For more installation options and details, visit the [official uv documentation](https://docs.astral.sh/uv/).
|
||||
|
||||
---
|
||||
|
||||
## 💡 Quick Examples
|
||||
|
||||
Once you've installed the skills, you can ask Claude to execute complex multi-step scientific workflows:
|
||||
Once you've installed the skills, you can ask Claude to execute complex multi-step scientific workflows. Here are some example prompts:
|
||||
|
||||
### End-to-End Drug Discovery Pipeline
|
||||
### 🧪 Drug Discovery Pipeline
|
||||
**Goal**: Find novel EGFR inhibitors for lung cancer treatment
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"I need to find novel EGFR inhibitors for lung cancer treatment. Query ChEMBL for existing
|
||||
EGFR inhibitors with IC50 < 50nM, analyze their structure-activity relationships using RDKit,
|
||||
generate similar molecules with improved properties using datamol, perform virtual screening
|
||||
with DiffDock against the AlphaFold-predicted EGFR structure, and search PubMed for recent
|
||||
papers on resistance mechanisms to prioritize scaffolds. Finally, check COSMIC for common
|
||||
EGFR mutations and assess how our candidates might interact with mutant forms."
|
||||
Use available skills you have access to whenever possible. Query ChEMBL for EGFR inhibitors (IC50 < 50nM), analyze structure-activity relationships
|
||||
with RDKit, generate improved analogs with datamol, perform virtual screening with DiffDock
|
||||
against AlphaFold EGFR structure, search PubMed for resistance mechanisms, check COSMIC for
|
||||
mutations, and create visualizations and a comprehensive report.
|
||||
```
|
||||
|
||||
### Comprehensive Single-Cell Analysis Workflow
|
||||
**Skills Used**: ChEMBL, RDKit, datamol, DiffDock, AlphaFold DB, PubMed, COSMIC, scientific visualization
|
||||
|
||||
---
|
||||
|
||||
### 🔬 Single-Cell RNA-seq Analysis
|
||||
**Goal**: Comprehensive analysis of 10X Genomics data with public data integration
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"Load this 10X Genomics dataset using Scanpy, perform quality control and doublet removal,
|
||||
integrate with public data from Cellxgene Census for the same tissue type, identify cell
|
||||
populations using known markers from NCBI Gene, perform differential expression analysis
|
||||
with PyDESeq2, run gene regulatory network inference with Arboreto, query Reactome and
|
||||
KEGG for pathway enrichment, and create publication-quality visualizations with matplotlib.
|
||||
Then cross-reference top dysregulated genes with Open Targets to identify potential
|
||||
therapeutic targets."
|
||||
Use available skills you have access to whenever possible. Load 10X dataset with Scanpy, perform QC and doublet removal, integrate with Cellxgene
|
||||
Census data, identify cell types using NCBI Gene markers, run differential expression with
|
||||
PyDESeq2, infer gene regulatory networks with Arboreto, enrich pathways via Reactome/KEGG,
|
||||
and identify therapeutic targets with Open Targets.
|
||||
```
|
||||
|
||||
### Multi-Omics Integration for Biomarker Discovery
|
||||
**Skills Used**: Scanpy, Cellxgene Census, NCBI Gene, PyDESeq2, Arboreto, Reactome, KEGG, Open Targets
|
||||
|
||||
---
|
||||
|
||||
### 🧬 Multi-Omics Biomarker Discovery
|
||||
**Goal**: Integrate RNA-seq, proteomics, and metabolomics to predict patient outcomes
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"I have RNA-seq, proteomics, and metabolomics data from cancer patients. Use PyDESeq2 for
|
||||
differential expression, pyOpenMS to analyze mass spec data, and integrate metabolite
|
||||
information from HMDB and Metabolomics Workbench. Map proteins to pathways using UniProt
|
||||
and KEGG, identify protein-protein interactions via STRING, correlate multi-omics layers
|
||||
using statsmodels, and build a machine learning model with scikit-learn to predict patient
|
||||
outcomes. Search ClinicalTrials.gov for ongoing trials targeting the top candidates."
|
||||
Use available skills you have access to whenever possible. Analyze RNA-seq with PyDESeq2, process mass spec with pyOpenMS, integrate metabolites from
|
||||
HMDB/Metabolomics Workbench, map proteins to pathways (UniProt/KEGG), find interactions via
|
||||
STRING, correlate omics layers with statsmodels, build predictive model with scikit-learn,
|
||||
and search ClinicalTrials.gov for relevant trials.
|
||||
```
|
||||
|
||||
### Structure-Based Virtual Screening Campaign
|
||||
**Skills Used**: PyDESeq2, pyOpenMS, HMDB, Metabolomics Workbench, UniProt, KEGG, STRING, statsmodels, scikit-learn, ClinicalTrials.gov
|
||||
|
||||
---
|
||||
|
||||
### 🎯 Virtual Screening Campaign
|
||||
**Goal**: Discover allosteric modulators for protein-protein interactions
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"I want to discover allosteric modulators for a protein-protein interaction. Retrieve the
|
||||
AlphaFold structure for both proteins, identify the interaction interface using BioPython,
|
||||
search ZINC15 for molecules with suitable properties for allosteric binding (MW 300-500,
|
||||
logP 2-4), filter for drug-likeness using RDKit, perform molecular docking with DiffDock
|
||||
to identify potential allosteric sites, rank candidates using DeepChem's property prediction
|
||||
models, check PubChem for suppliers, and search USPTO patents to assess freedom to operate.
|
||||
Finally, generate analogs with MedChem and molfeat for lead optimization."
|
||||
Use available skills you have access to whenever possible. Retrieve AlphaFold structures, identify interaction interface with BioPython, search ZINC
|
||||
for allosteric candidates (MW 300-500, logP 2-4), filter with RDKit, dock with DiffDock,
|
||||
rank with DeepChem, check PubChem suppliers, search USPTO patents, and optimize leads with
|
||||
MedChem/molfeat.
|
||||
```
|
||||
|
||||
### Clinical Genomics Variant Interpretation Pipeline
|
||||
**Skills Used**: AlphaFold DB, BioPython, ZINC, RDKit, DiffDock, DeepChem, PubChem, USPTO, MedChem, molfeat
|
||||
|
||||
---
|
||||
|
||||
### 🏥 Clinical Variant Interpretation
|
||||
**Goal**: Analyze VCF file for hereditary cancer risk assessment
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"Analyze this VCF file from a patient with suspected hereditary cancer. Use pysam to parse
|
||||
variants, annotate with Ensembl for functional consequences, query ClinVar for known
|
||||
pathogenic variants, check COSMIC for somatic mutations in cancer, retrieve gene information
|
||||
from NCBI Gene, analyze protein impact using UniProt, search PubMed for case reports of
|
||||
similar variants, query ClinPGx for pharmacogenomic implications, and generate a clinical
|
||||
report with ReportLab. Then search ClinicalTrials.gov for precision medicine trials matching
|
||||
the patient's profile."
|
||||
Use available skills you have access to whenever possible. Parse VCF with pysam, annotate variants with Ensembl VEP, query ClinVar for pathogenicity,
|
||||
check COSMIC for cancer mutations, retrieve gene info from NCBI Gene, analyze protein impact
|
||||
with UniProt, search PubMed for case reports, check ClinPGx for pharmacogenomics, generate
|
||||
clinical report with ReportLab, and find matching trials on ClinicalTrials.gov.
|
||||
```
|
||||
|
||||
### Systems Biology Network Analysis
|
||||
**Skills Used**: pysam, Ensembl, ClinVar, COSMIC, NCBI Gene, UniProt, PubMed, ClinPGx, ReportLab, ClinicalTrials.gov
|
||||
|
||||
---
|
||||
|
||||
### 🌐 Systems Biology Network Analysis
|
||||
**Goal**: Analyze gene regulatory networks from RNA-seq data
|
||||
|
||||
**Prompt**:
|
||||
```
|
||||
"Starting with a list of differentially expressed genes from my RNA-seq experiment, query
|
||||
NCBI Gene for detailed annotations, retrieve protein sequences from UniProt, identify
|
||||
protein-protein interactions using STRING, map to biological pathways in Reactome and KEGG,
|
||||
analyze network topology with Torch Geometric, identify hub genes and bottleneck proteins,
|
||||
perform gene regulatory network reconstruction with Arboreto, integrate with Open Targets
|
||||
for druggability assessment, use PyMC for Bayesian network modeling, and create interactive
|
||||
network visualizations. Finally, search GEO for similar expression patterns across diseases."
|
||||
Use available skills you have access to whenever possible. Query NCBI Gene for annotations, retrieve sequences from UniProt, identify interactions via
|
||||
STRING, map to Reactome/KEGG pathways, analyze topology with Torch Geometric, reconstruct
|
||||
GRNs with Arboreto, assess druggability with Open Targets, model with PyMC, visualize
|
||||
networks, and search GEO for similar patterns.
|
||||
```
|
||||
|
||||
**Skills Used**: NCBI Gene, UniProt, STRING, Reactome, KEGG, Torch Geometric, Arboreto, Open Targets, PyMC, GEO
|
||||
|
||||
> 📖 **Want more examples?** Check out [docs/examples.md](docs/examples.md) for comprehensive workflow examples and detailed use cases across all scientific domains.
|
||||
|
||||
---
|
||||
|
||||
## 🔬 Use Cases
|
||||
|
||||
### Drug Discovery Research
|
||||
- Screen compound libraries from PubChem and ZINC
|
||||
- Analyze bioactivity data from ChEMBL
|
||||
- Predict molecular properties with RDKit and DeepChem
|
||||
- Perform molecular docking with DiffDock
|
||||
### 🧪 Drug Discovery & Medicinal Chemistry
|
||||
- **Virtual Screening**: Screen millions of compounds from PubChem/ZINC against protein targets
|
||||
- **Lead Optimization**: Analyze structure-activity relationships with RDKit, generate analogs with datamol
|
||||
- **ADMET Prediction**: Predict absorption, distribution, metabolism, excretion, and toxicity with DeepChem
|
||||
- **Molecular Docking**: Predict binding poses and affinities with DiffDock
|
||||
- **Bioactivity Mining**: Query ChEMBL for known inhibitors and analyze SAR patterns
|
||||
|
||||
### Bioinformatics Analysis
|
||||
- Process genomic sequences with BioPython
|
||||
- Analyze single-cell RNA-seq data with Scanpy
|
||||
- Query gene information from Ensembl and NCBI Gene
|
||||
- Identify protein-protein interactions via STRING
|
||||
### 🧬 Bioinformatics & Genomics
|
||||
- **Sequence Analysis**: Process DNA/RNA/protein sequences with BioPython and pysam
|
||||
- **Single-Cell Analysis**: Analyze 10X Genomics data with Scanpy, identify cell types, infer GRNs with Arboreto
|
||||
- **Variant Annotation**: Annotate VCF files with Ensembl VEP, query ClinVar for pathogenicity
|
||||
- **Gene Discovery**: Query NCBI Gene, UniProt, and Ensembl for comprehensive gene information
|
||||
- **Network Analysis**: Identify protein-protein interactions via STRING, map to pathways (KEGG, Reactome)
|
||||
|
||||
### Materials Science
|
||||
- Analyze crystal structures with Pymatgen
|
||||
- Predict material properties
|
||||
- Design novel compounds and materials
|
||||
### 🏥 Clinical Research & Precision Medicine
|
||||
- **Clinical Trials**: Search ClinicalTrials.gov for relevant studies, analyze eligibility criteria
|
||||
- **Variant Interpretation**: Annotate variants with ClinVar, COSMIC, and ClinPGx for pharmacogenomics
|
||||
- **Drug Safety**: Query FDA databases for adverse events, drug interactions, and recalls
|
||||
- **Precision Therapeutics**: Match patient variants to targeted therapies and clinical trials
|
||||
|
||||
### Clinical Research
|
||||
- Search clinical trials on ClinicalTrials.gov
|
||||
- Analyze genetic variants in ClinVar
|
||||
- Review pharmacogenomic data from ClinPGx
|
||||
- Access cancer mutations from COSMIC
|
||||
### 🔬 Multi-Omics & Systems Biology
|
||||
- **Multi-Omics Integration**: Combine RNA-seq, proteomics, and metabolomics data
|
||||
- **Pathway Analysis**: Enrich differentially expressed genes in KEGG/Reactome pathways
|
||||
- **Network Biology**: Reconstruct gene regulatory networks, identify hub genes
|
||||
- **Biomarker Discovery**: Integrate multi-omics layers to predict patient outcomes
|
||||
|
||||
### Academic Research
|
||||
- Literature searches via PubMed
|
||||
- Patent landscape analysis using USPTO
|
||||
- Data visualization for publications
|
||||
- Statistical analysis and hypothesis testing
|
||||
### 📊 Data Analysis & Visualization
|
||||
- **Statistical Analysis**: Perform hypothesis testing, power analysis, and experimental design
|
||||
- **Publication Figures**: Create publication-quality visualizations with matplotlib and seaborn
|
||||
- **Network Visualization**: Visualize biological networks with NetworkX
|
||||
- **Report Generation**: Generate comprehensive PDF reports with ReportLab
|
||||
|
||||
### 🧪 Laboratory Automation
|
||||
- **Protocol Design**: Create Opentrons protocols for automated liquid handling
|
||||
- **LIMS Integration**: Integrate with Benchling and LabArchives for data management
|
||||
- **Workflow Automation**: Automate multi-step laboratory workflows
|
||||
|
||||
---
|
||||
|
||||
## 📚 Available Skills
|
||||
|
||||
### 🗄️ Scientific Databases
|
||||
**25 comprehensive databases** including PubMed, PubChem, UniProt, ChEMBL, AlphaFold DB, bioRxiv, COSMIC, Ensembl, KEGG, and more.
|
||||
This repository contains **138 scientific skills** organized across multiple domains. Each skill provides comprehensive documentation, code examples, and best practices for working with scientific libraries, databases, and tools.
|
||||
|
||||
📖 **[Full Database Documentation →](docs/scientific-databases.md)**
|
||||
### Skill Categories
|
||||
|
||||
<details>
|
||||
<summary><strong>View all databases</strong></summary>
|
||||
#### 🧬 **Bioinformatics & Genomics** (16+ skills)
|
||||
- Sequence analysis: BioPython, pysam, scikit-bio, BioServices
|
||||
- Single-cell analysis: Scanpy, AnnData, scvi-tools, Arboreto, Cellxgene Census
|
||||
- Genomic tools: gget, geniml, gtars, deepTools, FlowIO, Zarr
|
||||
- Phylogenetics: ETE Toolkit
|
||||
|
||||
- **AlphaFold DB** - AI-predicted protein structures (200M+ predictions)
|
||||
- **bioRxiv** - Life sciences preprint server with medRxiv integration
|
||||
- **ChEMBL** - Bioactive molecules and drug-like properties
|
||||
- **ClinPGx** - Clinical pharmacogenomics and gene-drug interactions
|
||||
- **ClinVar** - Genomic variants and clinical significance
|
||||
- **ClinicalTrials.gov** - Global clinical studies registry
|
||||
- **COSMIC** - Somatic cancer mutations database
|
||||
- **ENA** - European Nucleotide Archive
|
||||
- **Ensembl** - Genome browser and annotations
|
||||
- **FDA Databases** - Drug approvals, adverse events, recalls
|
||||
- **GEO** - Gene expression and functional genomics
|
||||
- **GWAS Catalog** - Genome-wide association studies
|
||||
- **HMDB** - Human metabolome database
|
||||
- **KEGG** - Biological pathways and molecular interactions
|
||||
- **Metabolomics Workbench** - NIH metabolomics data
|
||||
- **NCBI Gene** - Gene information and annotations
|
||||
- **Open Targets** - Therapeutic target identification
|
||||
- **PDB** - Protein structure database
|
||||
- **PubChem** - Chemical compound data (110M+ compounds)
|
||||
- **PubMed** - Biomedical literature database
|
||||
- **Reactome** - Curated biological pathways
|
||||
- **STRING** - Protein-protein interaction networks
|
||||
- **UniProt** - Protein sequences and annotations
|
||||
- **USPTO** - Patent and trademark data
|
||||
- **ZINC** - Commercially-available compounds for screening
|
||||
#### 🧪 **Cheminformatics & Drug Discovery** (10+ skills)
|
||||
- Molecular manipulation: RDKit, Datamol, Molfeat
|
||||
- Deep learning: DeepChem, TorchDrug
|
||||
- Docking & screening: DiffDock
|
||||
- Drug-likeness: MedChem
|
||||
- Benchmarks: PyTDC
|
||||
|
||||
</details>
|
||||
#### 🔬 **Proteomics & Mass Spectrometry** (2 skills)
|
||||
- Spectral processing: matchms, pyOpenMS
|
||||
|
||||
---
|
||||
#### 🏥 **Clinical Research & Precision Medicine** (12+ skills)
|
||||
- Clinical databases: ClinicalTrials.gov, ClinVar, ClinPGx, COSMIC, FDA Databases
|
||||
- Healthcare AI: PyHealth, NeuroKit2, Clinical Decision Support
|
||||
- Clinical documentation: Clinical Reports, Treatment Plans
|
||||
- Variant analysis: Ensembl, NCBI Gene
|
||||
|
||||
### 🔬 Scientific Packages
|
||||
**44 specialized Python packages** organized by domain.
|
||||
#### 🖼️ **Medical Imaging & Digital Pathology** (3 skills)
|
||||
- DICOM processing: pydicom
|
||||
- Whole slide imaging: histolab, PathML
|
||||
|
||||
📖 **[Full Package Documentation →](docs/scientific-packages.md)**
|
||||
#### 🧠 **Neuroscience & Electrophysiology** (1 skill)
|
||||
- Neural recordings: Neuropixels-Analysis (extracellular spikes, silicon probes, spike sorting)
|
||||
|
||||
<details>
|
||||
<summary><strong>Bioinformatics & Genomics (12 packages)</strong></summary>
|
||||
#### 🤖 **Machine Learning & AI** (15+ skills)
|
||||
- Deep learning: PyTorch Lightning, Transformers, Stable Baselines3, PufferLib
|
||||
- Classical ML: scikit-learn, scikit-survival, SHAP
|
||||
- Time series: aeon
|
||||
- Bayesian methods: PyMC
|
||||
- Optimization: PyMOO
|
||||
- Graph ML: Torch Geometric
|
||||
- Dimensionality reduction: UMAP-learn
|
||||
- Statistical modeling: statsmodels
|
||||
|
||||
- AnnData, Arboreto, BioPython, BioServices, Cellxgene Census
|
||||
- deepTools, FlowIO, gget, pysam, PyDESeq2, Scanpy, scvi-tools
|
||||
#### 🔮 **Materials Science, Chemistry & Physics** (7 skills)
|
||||
- Materials: Pymatgen
|
||||
- Metabolic modeling: COBRApy
|
||||
- Astronomy: Astropy
|
||||
- Quantum computing: Cirq, PennyLane, Qiskit, QuTiP
|
||||
|
||||
</details>
|
||||
#### ⚙️ **Engineering & Simulation** (3 skills)
|
||||
- Computational fluid dynamics: FluidSim
|
||||
- Discrete-event simulation: SimPy
|
||||
- Data processing: Dask, Polars, Vaex
|
||||
|
||||
<details>
|
||||
<summary><strong>Cheminformatics & Drug Discovery (8 packages)</strong></summary>
|
||||
#### 📊 **Data Analysis & Visualization** (14+ skills)
|
||||
- Visualization: Matplotlib, Seaborn, Plotly, Scientific Visualization
|
||||
- Geospatial analysis: GeoPandas
|
||||
- Network analysis: NetworkX
|
||||
- Symbolic math: SymPy
|
||||
- PDF generation: ReportLab
|
||||
- Data access: Data Commons
|
||||
- Exploratory data analysis: EDA workflows
|
||||
- Statistical analysis: Statistical Analysis workflows
|
||||
|
||||
- Datamol, DeepChem, DiffDock, MedChem, Molfeat, PyTDC, RDKit, TorchDrug
|
||||
#### 🧪 **Laboratory Automation** (3 skills)
|
||||
- Liquid handling: PyLabRobot
|
||||
- Protocol management: Protocols.io
|
||||
- LIMS integration: Benchling, LabArchives
|
||||
|
||||
</details>
|
||||
#### 🔬 **Multi-omics & Systems Biology** (5+ skills)
|
||||
- Pathway analysis: KEGG, Reactome, STRING
|
||||
- Multi-omics: BIOMNI, Denario, HypoGeniC
|
||||
- Data management: LaminDB
|
||||
|
||||
<details>
|
||||
<summary><strong>Proteomics & Mass Spectrometry (2 packages)</strong></summary>
|
||||
#### 🧬 **Protein Engineering & Design** (2 skills)
|
||||
- Protein language models: ESM
|
||||
- Cloud laboratory platform: Adaptyv (automated protein testing and validation)
|
||||
|
||||
- matchms, pyOpenMS
|
||||
#### 📚 **Scientific Communication** (20+ skills)
|
||||
- Literature: OpenAlex, PubMed, bioRxiv, Literature Review
|
||||
- Web search: Perplexity Search (AI-powered search with real-time information)
|
||||
- Writing: Scientific Writing, Peer Review
|
||||
- Document processing: XLSX, MarkItDown, Document Skills
|
||||
- Publishing: Paper-2-Web, Venue Templates
|
||||
- Presentations: Scientific Slides, LaTeX Posters, PPTX Posters
|
||||
- Diagrams: Scientific Schematics
|
||||
- Citations: Citation Management
|
||||
- Illustration: Generate Image (AI image generation with FLUX.2 Pro and Gemini 3 Pro (Nano Banana Pro))
|
||||
|
||||
</details>
|
||||
#### 🔬 **Scientific Databases** (28+ skills)
|
||||
- Protein: UniProt, PDB, AlphaFold DB
|
||||
- Chemical: PubChem, ChEMBL, DrugBank, ZINC, HMDB
|
||||
- Genomic: Ensembl, NCBI Gene, GEO, ENA, GWAS Catalog
|
||||
- Literature: bioRxiv (preprints)
|
||||
- Clinical: ClinVar, COSMIC, ClinicalTrials.gov, ClinPGx, FDA Databases
|
||||
- Pathways: KEGG, Reactome, STRING
|
||||
- Targets: Open Targets
|
||||
- Metabolomics: Metabolomics Workbench
|
||||
- Enzymes: BRENDA
|
||||
- Patents: USPTO
|
||||
|
||||
<details>
|
||||
<summary><strong>Machine Learning & Deep Learning (9 packages)</strong></summary>
|
||||
#### 🔧 **Infrastructure & Platforms** (6+ skills)
|
||||
- Cloud compute: Modal
|
||||
- Genomics platforms: DNAnexus, LatchBio
|
||||
- Microscopy: OMERO
|
||||
- Automation: Opentrons
|
||||
- Tool discovery: ToolUniverse, Get Available Resources
|
||||
|
||||
- PyMC, PyMOO, PyTorch Lightning, scikit-learn, SHAP, statsmodels
|
||||
- Torch Geometric, Transformers, UMAP-learn
|
||||
#### 🎓 **Research Methodology & Planning** (8+ skills)
|
||||
- Ideation: Scientific Brainstorming, Hypothesis Generation
|
||||
- Critical analysis: Scientific Critical Thinking, Scholar Evaluation
|
||||
- Funding: Research Grants
|
||||
- Discovery: Research Lookup
|
||||
- Market analysis: Market Research Reports
|
||||
|
||||
</details>
|
||||
#### ⚖️ **Regulatory & Standards** (1 skill)
|
||||
- Medical device standards: ISO 13485 Certification
|
||||
|
||||
<details>
|
||||
<summary><strong>Materials Science & Chemistry (3 packages)</strong></summary>
|
||||
> 📖 **For complete details on all skills**, see [docs/scientific-skills.md](docs/scientific-skills.md)
|
||||
|
||||
- Astropy, COBRApy, Pymatgen
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>Data Analysis & Visualization (5 packages)</strong></summary>
|
||||
|
||||
- Dask, Matplotlib, Polars, ReportLab, Seaborn
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>Additional Packages (6 packages)</strong></summary>
|
||||
|
||||
- BIOMNI (Multi-omics), ETE Toolkit (Phylogenetics)
|
||||
- Paper-2-Web (Academic paper dissemination and presentation)
|
||||
- scikit-bio (Sequence analysis), ToolUniverse (600+ scientific tool ecosystem)
|
||||
- Zarr (Array storage)
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
### 🧠 Scientific Thinking & Analysis
|
||||
**Comprehensive analysis tools** and document processing capabilities.
|
||||
|
||||
📖 **[Full Thinking & Analysis Documentation →](docs/scientific-thinking.md)**
|
||||
|
||||
**Analysis & Methodology:**
|
||||
- Exploratory Data Analysis (automated statistics and insights)
|
||||
- Hypothesis Generation (structured frameworks)
|
||||
- Peer Review (comprehensive evaluation toolkit)
|
||||
- Scientific Brainstorming (ideation workflows)
|
||||
- Scientific Critical Thinking (rigorous reasoning)
|
||||
- Scientific Visualization (publication-quality figures)
|
||||
- Scientific Writing (IMRAD format, citation styles)
|
||||
- Statistical Analysis (testing and experimental design)
|
||||
|
||||
**Document Processing:**
|
||||
- DOCX, PDF, PPTX, XLSX manipulation and analysis
|
||||
- Tracked changes, comments, and formatting preservation
|
||||
- Text extraction, table parsing, and data analysis
|
||||
|
||||
---
|
||||
|
||||
### 🔌 Scientific Integrations
|
||||
**6 platform integrations** for lab automation and workflow management.
|
||||
|
||||
📖 **[Full Integration Documentation →](docs/scientific-integrations.md)**
|
||||
|
||||
- **Benchling** - R&D platform and LIMS integration
|
||||
- **DNAnexus** - Cloud genomics and biomedical data analysis
|
||||
- **LabArchives** - Electronic Lab Notebook (ELN) integration
|
||||
- **LatchBio** - Workflow platform and cloud execution
|
||||
- **OMERO** - Microscopy and bio-image data management
|
||||
- **Opentrons** - Laboratory automation protocols
|
||||
|
||||
---
|
||||
|
||||
### 🛠️ Scientific Helpers
|
||||
**2 helper utilities** for enhanced scientific computing capabilities.
|
||||
|
||||
- **scientific-context-initialization** - Auto-invoked skill that creates/updates workspace AGENT.md to instruct Claude to search for and use existing skills before attempting any scientific task
|
||||
- **get-available-resources** - Detects available system resources (CPU cores, GPUs, memory, disk space) and generates strategic recommendations for computational approaches (parallel processing, out-of-core computing, GPU acceleration)
|
||||
> 💡 **Looking for practical examples?** Check out [docs/examples.md](docs/examples.md) for comprehensive workflow examples across all scientific domains.
|
||||
|
||||
---
|
||||
|
||||
@@ -402,26 +511,19 @@ Contributors are recognized in our community and may be featured in:
|
||||
|
||||
Your contributions help make scientific computing more accessible and enable researchers to leverage AI tools more effectively!
|
||||
|
||||
📖 **[Contributing Guidelines →](CONTRIBUTING.md)** *(coming soon)*
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Problem: Claude not using installed skills**
|
||||
- Solution: Install the `scientific-context-initialization` skill
|
||||
- This creates an `AGENT.md` file that instructs Claude to search for and use existing skills before attempting tasks
|
||||
- After installation, Claude will automatically leverage documented patterns, examples, and best practices
|
||||
|
||||
**Problem: Skills not loading in Claude Code**
|
||||
- Solution: Ensure you've installed the latest version of Claude Code
|
||||
- Try reinstalling the plugin: `/plugin marketplace add K-Dense-AI/claude-scientific-skills`
|
||||
|
||||
**Problem: Missing Python dependencies**
|
||||
- Solution: Check the specific `SKILL.md` file for required packages
|
||||
- Install dependencies: `pip install package-name`
|
||||
- Install dependencies: `uv pip install package-name`
|
||||
|
||||
**Problem: API rate limits**
|
||||
- Solution: Many databases have rate limits. Review the specific database documentation
|
||||
@@ -439,29 +541,41 @@ Your contributions help make scientific computing more accessible and enable res
|
||||
|
||||
## ❓ FAQ
|
||||
|
||||
### General Questions
|
||||
|
||||
**Q: Is this free to use?**
|
||||
A: Yes, for noncommercial use. See the [License](#license) section for details.
|
||||
A: Yes! This project is MIT licensed, allowing free use for any purpose including commercial projects.
|
||||
|
||||
**Q: Do I need all the Python packages installed?**
|
||||
A: No, only install the packages you need. Each skill specifies its requirements.
|
||||
|
||||
**Q: Can I use this with other AI models?**
|
||||
A: The skills are designed for Claude but can be adapted for other models with MCP support.
|
||||
|
||||
**Q: How often is this updated?**
|
||||
A: We regularly update skills to reflect the latest versions of packages and APIs.
|
||||
**Q: Why are all skills grouped into one plugin instead of separate plugins?**
|
||||
A: We believe good science in the age of AI is inherently interdisciplinary. Bundling all skills into a single plugin makes it trivial for you (and Claude) to bridge across fields—e.g., combining genomics, cheminformatics, clinical data, and machine learning in one workflow—without worrying about which individual skills to install or wire together.
|
||||
|
||||
**Q: Can I use this for commercial projects?**
|
||||
A: For commercial use, please visit [K-Dense](https://k-dense.ai/) for enterprise licensing.
|
||||
A: Absolutely! The MIT License allows both commercial and noncommercial use without restrictions.
|
||||
|
||||
**Q: How often is this updated?**
|
||||
A: We regularly update skills to reflect the latest versions of packages and APIs. Major updates are announced in release notes.
|
||||
|
||||
**Q: Can I use this with other AI models?**
|
||||
A: The skills are optimized for Claude but can be adapted for other models with MCP support. The MCP server works with any MCP-compatible client.
|
||||
|
||||
### Installation & Setup
|
||||
|
||||
**Q: Do I need all the Python packages installed?**
|
||||
A: No! Only install the packages you need. Each skill specifies its requirements in its `SKILL.md` file.
|
||||
|
||||
**Q: What if a skill doesn't work?**
|
||||
A: First check the troubleshooting section, then file an issue on GitHub with details.
|
||||
|
||||
**Q: Can I contribute my own skills?**
|
||||
A: Absolutely! See the [Contributing](#contributing) section for guidelines.
|
||||
A: First check the [Troubleshooting](#troubleshooting) section. If the issue persists, file an issue on GitHub with detailed reproduction steps.
|
||||
|
||||
**Q: Do the skills work offline?**
|
||||
A: Database skills require internet access. Package skills work offline once dependencies are installed.
|
||||
A: Database skills require internet access to query APIs. Package skills work offline once Python dependencies are installed.
|
||||
|
||||
### Contributing
|
||||
|
||||
**Q: Can I contribute my own skills?**
|
||||
A: Absolutely! We welcome contributions. See the [Contributing](#contributing) section for guidelines and best practices.
|
||||
|
||||
**Q: How do I report bugs or suggest features?**
|
||||
A: Open an issue on GitHub with a clear description. For bugs, include reproduction steps and expected vs actual behavior.
|
||||
|
||||
---
|
||||
|
||||
@@ -473,19 +587,73 @@ Need help? Here's how to get support:
|
||||
- 🐛 **Bug Reports**: [Open an issue](https://github.com/K-Dense-AI/claude-scientific-skills/issues)
|
||||
- 💡 **Feature Requests**: [Submit a feature request](https://github.com/K-Dense-AI/claude-scientific-skills/issues/new)
|
||||
- 💼 **Enterprise Support**: Contact [K-Dense](https://k-dense.ai/) for commercial support
|
||||
- 🌐 **MCP Support**: Visit the [claude-skills-mcp](https://github.com/K-Dense-AI/claude-skills-mcp) repository
|
||||
- 🌐 **MCP Support**: Visit the [claude-skills-mcp](https://github.com/K-Dense-AI/claude-skills-mcp) repository or use our hosted MCP server
|
||||
|
||||
---
|
||||
|
||||
## 🎉 Join Our Community!
|
||||
|
||||
**We'd love to have you join us!** 🚀
|
||||
|
||||
Connect with other scientists, researchers, and AI enthusiasts using Claude for scientific computing. Share your discoveries, ask questions, get help with your projects, and collaborate with the community!
|
||||
|
||||
🌟 **[Join our Slack Community](https://join.slack.com/t/k-densecommunity/shared_invite/zt-3iajtyls1-EwmkwIZk0g_o74311Tkf5g)** 🌟
|
||||
|
||||
Whether you're just getting started or you're a power user, our community is here to support you. We share tips, troubleshoot issues together, showcase cool projects, and discuss the latest developments in AI-powered scientific research.
|
||||
|
||||
**See you there!** 💬
|
||||
|
||||
---
|
||||
|
||||
## 📖 Citation
|
||||
|
||||
If you use Claude Scientific Skills in your research or project, please cite it as:
|
||||
|
||||
### BibTeX
|
||||
```bibtex
|
||||
@software{claude_scientific_skills_2025,
|
||||
author = {{K-Dense Inc.}},
|
||||
title = {Claude Scientific Skills: A Comprehensive Collection of Scientific Tools for Claude AI},
|
||||
year = {2025},
|
||||
url = {https://github.com/K-Dense-AI/claude-scientific-skills},
|
||||
note = {skills covering databases, packages, integrations, and analysis tools}
|
||||
}
|
||||
```
|
||||
|
||||
### APA
|
||||
```
|
||||
K-Dense Inc. (2025). Claude Scientific Skills: A comprehensive collection of scientific tools for Claude AI [Computer software]. https://github.com/K-Dense-AI/claude-scientific-skills
|
||||
```
|
||||
|
||||
### MLA
|
||||
```
|
||||
K-Dense Inc. Claude Scientific Skills: A Comprehensive Collection of Scientific Tools for Claude AI. 2025, github.com/K-Dense-AI/claude-scientific-skills.
|
||||
```
|
||||
|
||||
### Plain Text
|
||||
```
|
||||
Claude Scientific Skills by K-Dense Inc. (2025)
|
||||
Available at: https://github.com/K-Dense-AI/claude-scientific-skills
|
||||
```
|
||||
|
||||
We appreciate acknowledgment in publications, presentations, or projects that benefit from these skills!
|
||||
|
||||
---
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the **PolyForm Noncommercial License 1.0.0**.
|
||||
This project is licensed under the **MIT License**.
|
||||
|
||||
**Copyright © K-Dense Inc.** ([k-dense.ai](https://k-dense.ai/))
|
||||
**Copyright © 2025 K-Dense Inc.** ([k-dense.ai](https://k-dense.ai/))
|
||||
|
||||
### Key Points:
|
||||
- ✅ **Free for noncommercial use** (research, education, personal projects)
|
||||
- ✅ **Free for noncommercial organizations** (universities, research institutions)
|
||||
- ❌ **Commercial use requires separate license** (contact K-Dense)
|
||||
- ✅ **Free for any use** (commercial and noncommercial)
|
||||
- ✅ **Open source** - modify, distribute, and use freely
|
||||
- ✅ **Permissive** - minimal restrictions on reuse
|
||||
- ⚠️ **No warranty** - provided "as is" without warranty of any kind
|
||||
|
||||
See [LICENSE.md](LICENSE.md) for full terms.
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/#K-Dense-AI/claude-scientific-skills&type=date&legend=top-left)
|
||||
|
||||
2668
docs/examples.md
Normal file
2668
docs/examples.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,28 +0,0 @@
|
||||
# Scientific Databases
|
||||
|
||||
- **AlphaFold DB** - AI-predicted protein structure database with 200M+ predictions, confidence metrics (pLDDT, PAE), and Google Cloud bulk access
|
||||
- **ChEMBL** - Bioactive molecule database with drug-like properties (2M+ compounds, 19M+ activities, 13K+ targets)
|
||||
- **ClinPGx** - Clinical pharmacogenomics database (successor to PharmGKB) providing gene-drug interactions, CPIC clinical guidelines, allele functions, drug labels, and pharmacogenomic annotations for precision medicine and personalized pharmacotherapy (consolidates PharmGKB, CPIC, and PharmCAT resources)
|
||||
- **ClinVar** - NCBI's public archive of genomic variants and their clinical significance with standardized classifications (pathogenic, benign, VUS), E-utilities API access, and bulk FTP downloads for variant interpretation and precision medicine research
|
||||
- **ClinicalTrials.gov** - Comprehensive registry of clinical studies conducted worldwide (maintained by U.S. National Library of Medicine) with API v2 access for searching trials by condition, intervention, location, sponsor, study status, and phase; retrieve detailed trial information including eligibility criteria, outcomes, contacts, and locations; export to CSV/JSON formats for analysis (public API, no authentication required, ~50 req/min rate limit)
|
||||
- **COSMIC** - Catalogue of Somatic Mutations in Cancer, the world's largest database of somatic cancer mutations (millions of mutations across thousands of cancer types, Cancer Gene Census, mutational signatures, structural variants, and drug resistance data)
|
||||
- **ENA (European Nucleotide Archive)** - Comprehensive public repository for nucleotide sequence data and metadata with REST APIs for accessing sequences, assemblies, samples, studies, and reads; supports advanced search, taxonomy lookups, and bulk downloads via FTP/Aspera (rate limit: 50 req/sec)
|
||||
- **Ensembl** - Genome browser and bioinformatics database providing genomic annotations, sequences, variants, and comparative genomics data for 250+ vertebrate species (Release 115, 2025) with comprehensive REST API for gene lookups, sequence retrieval, variant effect prediction (VEP), ortholog finding, assembly mapping (GRCh37/GRCh38), and region analysis
|
||||
- **FDA Databases** - Comprehensive access to all FDA (Food and Drug Administration) regulatory databases through openFDA API covering drugs (adverse events, labeling, NDC, recalls, approvals, shortages), medical devices (adverse events, 510k clearances, PMA, UDI, classifications), foods (recalls, adverse events, allergen tracking), animal/veterinary medicines (species-specific adverse events), and substances (UNII/CAS lookup, chemical structures, molecular data) for drug safety research, pharmacovigilance, regulatory compliance, and scientific analysis
|
||||
- **GEO (Gene Expression Omnibus)** - High-throughput gene expression and functional genomics data repository (264K+ studies, 8M+ samples) with microarray, RNA-seq, and expression profile access
|
||||
- **GWAS Catalog** - NHGRI-EBI catalog of published genome-wide association studies with curated SNP-trait associations (thousands of studies, genome-wide significant associations p≤5×10⁻⁸), full summary statistics, REST API access for variant/trait/gene queries, and FTP downloads for genetic epidemiology and precision medicine research
|
||||
- **HMDB (Human Metabolome Database)** - Comprehensive metabolomics resource with 220K+ metabolite entries, detailed chemical/biological data, concentration ranges, disease associations, pathways, and spectral data for metabolite identification and biomarker discovery
|
||||
- **KEGG** - Kyoto Encyclopedia of Genes and Genomes for biological pathway analysis, gene-to-pathway mapping, compound searches, and molecular interaction networks (pathway enrichment, metabolic pathways, gene annotations, drug-drug interactions, ID conversion)
|
||||
- **Metabolomics Workbench** - NIH Common Fund metabolomics data repository with 4,200+ processed studies, standardized nomenclature (RefMet), mass spectrometry searches, and comprehensive REST API for accessing metabolite structures, study metadata, experimental results, and gene/protein-metabolite associations
|
||||
- **Open Targets** - Comprehensive therapeutic target identification and validation platform integrating genetics, omics, and chemical data (200M+ evidence strings, target-disease associations with scoring, tractability assessments, safety liabilities, known drugs from ChEMBL, GraphQL API) for drug target discovery, prioritization, evidence evaluation, drug repurposing, competitive intelligence, and mechanism research
|
||||
- **NCBI Gene** - Work with NCBI Gene database to search, retrieve, and analyze gene information including nomenclature, sequences, variations, phenotypes, and pathways using E-utilities and Datasets API
|
||||
- **Protein Data Bank (PDB)** - Access 3D structural data of proteins, nucleic acids, and biological macromolecules (200K+ structures) with search, retrieval, and analysis capabilities
|
||||
- **PubChem** - Access chemical compound data from the world's largest free chemical database (110M+ compounds, 270M+ bioactivities)
|
||||
- **PubMed** - Access to PubMed literature database with advanced search capabilities
|
||||
- **Reactome** - Curated pathway database for biological processes and molecular interactions (2,825+ human pathways, 16K+ reactions, 11K+ proteins) with pathway enrichment analysis, expression data analysis, and species comparison using Content Service and Analysis Service APIs
|
||||
- **STRING** - Protein-protein interaction network database (5000+ genomes, 59.3M proteins, 20B+ interactions) with functional enrichment analysis, interaction partner discovery, and network visualization from experimental data, computational prediction, and text-mining
|
||||
- **UniProt** - Universal Protein Resource for protein sequences, annotations, and functional information (UniProtKB/Swiss-Prot reviewed entries, TrEMBL unreviewed entries) with REST API access for search, retrieval, ID mapping, and batch operations across 200+ databases
|
||||
- **USPTO** - United States Patent and Trademark Office data access including patent searches, trademark lookups, patent examination history (PEDS), office actions, assignments, citations, and litigation records; supports PatentSearch API (ElasticSearch-based patent search), TSDR (Trademark Status & Document Retrieval), Patent/Trademark Assignment APIs, and additional specialized APIs for comprehensive IP analysis
|
||||
- **ZINC** - Free database of commercially-available compounds for virtual screening and drug discovery (230M+ purchasable compounds in ready-to-dock 3D formats)
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
# Scientific Integrations
|
||||
|
||||
## Laboratory Information Management Systems (LIMS) & R&D Platforms
|
||||
- **Benchling Integration** - Toolkit for integrating with Benchling's R&D platform, providing programmatic access to laboratory data management including registry entities (DNA sequences, proteins), inventory systems (samples, containers, locations), electronic lab notebooks (entries, protocols), workflows (tasks, automation), and data exports using Python SDK and REST API
|
||||
|
||||
## Cloud Platforms for Genomics & Biomedical Data
|
||||
- **DNAnexus Integration** - Comprehensive toolkit for working with the DNAnexus cloud platform for genomics and biomedical data analysis. Covers building and deploying apps/applets (Python/Bash), managing data objects (files, records, databases), running analyses and workflows, using the dxpy Python SDK, and configuring app metadata and dependencies (dxapp.json setup, system packages, Docker, assets). Enables processing of FASTQ/BAM/VCF files, bioinformatics pipelines, job execution, workflow orchestration, and platform operations including project management and permissions
|
||||
|
||||
## Laboratory Automation
|
||||
- **Opentrons Integration** - Toolkit for creating, editing, and debugging Opentrons Python Protocol API v2 protocols for laboratory automation using Flex and OT-2 robots. Enables automated liquid handling, pipetting workflows, hardware module control (thermocycler, temperature, magnetic, heater-shaker, absorbance plate reader), labware management, and complex protocol development for biological and chemical experiments
|
||||
|
||||
## Electronic Lab Notebooks (ELN)
|
||||
- **LabArchives Integration** - Toolkit for interacting with LabArchives Electronic Lab Notebook (ELN) REST API. Provides programmatic access to notebooks (backup, retrieval, management), entries (creation, comments, attachments), user authentication, site reports and analytics, and third-party integrations (Protocols.io, GraphPad Prism, SnapGene, Geneious, Jupyter, REDCap). Includes Python scripts for configuration setup, notebook operations, and entry management. Supports multi-regional API endpoints (US, UK, Australia) and OAuth authentication
|
||||
|
||||
## Workflow Platforms & Cloud Execution
|
||||
- **LatchBio Integration** - Integration with the Latch platform for building, deploying, and executing bioinformatics workflows. Provides comprehensive support for creating serverless bioinformatics pipelines using Python decorators, deploying Nextflow/Snakemake pipelines, managing cloud data (LatchFile, LatchDir) and structured Registry (Projects, Tables, Records), configuring computational resources (CPU, GPU, memory, storage), and using pre-built Latch Verified workflows (RNA-seq, AlphaFold, DESeq2, single-cell analysis, CRISPR editing). Enables automatic containerization, UI generation, workflow versioning, and execution on scalable cloud infrastructure with comprehensive data management
|
||||
|
||||
## Microscopy & Bio-image Data
|
||||
- **OMERO Integration** - Toolkit for interacting with OMERO microscopy data management systems using Python. Provides comprehensive access to microscopy images stored in OMERO servers, including dataset and screening data retrieval, pixel data analysis, annotation and metadata management, regions of interest (ROIs) creation and analysis, batch processing, OMERO.scripts development, and OMERO.tables for structured data storage. Essential for researchers working with high-content screening data, multi-dimensional microscopy datasets, or collaborative image repositories
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# Scientific Packages
|
||||
|
||||
## Bioinformatics & Genomics
|
||||
- **AnnData** - Annotated data matrices for single-cell genomics and h5ad files
|
||||
- **Arboreto** - Gene regulatory network inference using GRNBoost2 and GENIE3
|
||||
- **BioPython** - Sequence manipulation, NCBI database access, BLAST searches, alignments, and phylogenetics
|
||||
- **BioServices** - Programmatic access to 40+ biological web services (KEGG, UniProt, ChEBI, ChEMBL)
|
||||
- **Cellxgene Census** - Query and analyze large-scale single-cell RNA-seq data
|
||||
- **gget** - Efficient genomic database queries (Ensembl, UniProt, NCBI, PDB, COSMIC)
|
||||
- **pysam** - Read, write, and manipulate genomic data files (SAM/BAM/CRAM alignments, VCF/BCF variants, FASTA/FASTQ sequences) with pileup analysis, coverage calculations, and bioinformatics workflows
|
||||
- **PyDESeq2** - Differential gene expression analysis for bulk RNA-seq data
|
||||
- **Scanpy** - Single-cell RNA-seq analysis with clustering, marker genes, and UMAP/t-SNE visualization
|
||||
- **scvi-tools** - Probabilistic deep learning models for single-cell omics analysis. PyTorch-based framework providing variational autoencoders (VAEs) for dimensionality reduction, batch correction, differential expression, and data integration across modalities. Includes 25+ models: scVI/scANVI (RNA-seq integration and cell type annotation), totalVI (CITE-seq protein+RNA), MultiVI (multiome RNA+ATAC integration), PeakVI (ATAC-seq analysis), DestVI/Stereoscope/Tangram (spatial transcriptomics deconvolution), MethylVI (methylation), CytoVI (flow/mass cytometry), VeloVI (RNA velocity), contrastiveVI (perturbation studies), and Solo (doublet detection). Supports seamless integration with Scanpy/AnnData ecosystem, GPU acceleration, reference mapping (scArches), and probabilistic differential expression with uncertainty quantification
|
||||
|
||||
## Cheminformatics & Drug Discovery
|
||||
- **Datamol** - Molecular manipulation and featurization with enhanced RDKit workflows
|
||||
- **DeepChem** - Molecular machine learning, graph neural networks, and MoleculeNet benchmarks
|
||||
- **DiffDock** - Diffusion-based molecular docking for protein-ligand binding prediction
|
||||
- **MedChem** - Medicinal chemistry analysis, ADMET prediction, and drug-likeness assessment
|
||||
- **Molfeat** - 100+ molecular featurizers including fingerprints, descriptors, and pretrained models
|
||||
- **PyTDC** - Therapeutics Data Commons for drug discovery datasets and benchmarks
|
||||
- **RDKit** - Cheminformatics toolkit for molecular I/O, descriptors, fingerprints, and SMARTS
|
||||
- **TorchDrug** - PyTorch-based machine learning platform for drug discovery with 40+ datasets, 20+ GNN models for molecular property prediction, protein modeling, knowledge graph reasoning, molecular generation, and retrosynthesis planning
|
||||
|
||||
## Proteomics & Mass Spectrometry
|
||||
- **matchms** - Processing and similarity matching of mass spectrometry data with 40+ filters, spectral library matching (Cosine, Modified Cosine, Neutral Losses), metadata harmonization, molecular fingerprint comparison, and support for multiple file formats (MGF, MSP, mzML, JSON)
|
||||
- **pyOpenMS** - Comprehensive mass spectrometry data analysis for proteomics and metabolomics (LC-MS/MS processing, peptide identification, feature detection, quantification, chemical calculations, and integration with search engines like Comet, Mascot, MSGF+)
|
||||
|
||||
## Protein Engineering & Design
|
||||
- **ESM (Evolutionary Scale Modeling)** - State-of-the-art protein language models from EvolutionaryScale for protein design, structure prediction, and representation learning. Includes ESM3 (1.4B-98B parameter multimodal generative models for simultaneous reasoning across sequence, structure, and function with chain-of-thought generation, inverse folding, and function-conditioned design) and ESM C (300M-6B parameter efficient embedding models 3x faster than ESM2 for similarity analysis, classification, and feature extraction). Supports local inference with open weights and cloud-based Forge API for scalable batch processing. Use cases: novel protein design, structure prediction from sequence, sequence design from structure, protein embeddings, function annotation, variant generation, and directed evolution workflows
|
||||
|
||||
## Machine Learning & Deep Learning
|
||||
- **PyMC** - Bayesian statistical modeling and probabilistic programming
|
||||
- **PyMOO** - Multi-objective optimization with evolutionary algorithms
|
||||
- **PyTorch Lightning** - Deep learning framework that organizes PyTorch code to eliminate boilerplate while maintaining full flexibility. Automates training workflows (40+ tasks including epoch/batch iteration, optimizer steps, gradient management, checkpointing), supports multi-GPU/TPU training with DDP/FSDP/DeepSpeed strategies, includes LightningModule for model organization, Trainer for automation, LightningDataModule for data pipelines, callbacks for extensibility, and integrations with TensorBoard, Wandb, MLflow for experiment tracking
|
||||
- **scikit-learn** - Machine learning algorithms, preprocessing, and model selection
|
||||
- **SHAP** - Model interpretability and explainability using Shapley values from game theory. Provides unified approach to explain any ML model with TreeExplainer (fast exact explanations for XGBoost/LightGBM/Random Forest), DeepExplainer (TensorFlow/PyTorch neural networks), KernelExplainer (model-agnostic), and LinearExplainer. Includes comprehensive visualizations (waterfall plots for individual predictions, beeswarm plots for global importance, scatter plots for feature relationships, bar/force/heatmap plots), supports model debugging, fairness analysis, feature engineering guidance, and production deployment
|
||||
- **statsmodels** - Statistical modeling and econometrics (OLS, GLM, logit/probit, ARIMA, time series forecasting, hypothesis testing, diagnostics)
|
||||
- **Torch Geometric** - Graph Neural Networks for molecular and geometric data
|
||||
- **Transformers** - State-of-the-art machine learning models for NLP, computer vision, audio, and multimodal tasks. Provides 1M+ pre-trained models accessible via pipelines (text-classification, NER, QA, summarization, translation, text-generation, image-classification, object-detection, ASR, VQA), comprehensive training via Trainer API with distributed training and mixed precision, flexible text generation with multiple decoding strategies (greedy, beam search, sampling), and Auto classes for automatic architecture selection (BERT, GPT, T5, ViT, BART, etc.)
|
||||
- **UMAP-learn** - Dimensionality reduction and manifold learning
|
||||
|
||||
## Materials Science & Chemistry
|
||||
- **Astropy** - Astronomy and astrophysics (coordinates, cosmology, FITS files)
|
||||
- **COBRApy** - Constraint-based metabolic modeling and flux balance analysis
|
||||
- **Pymatgen** - Materials structure analysis, phase diagrams, and electronic structure
|
||||
|
||||
## Data Analysis & Visualization
|
||||
- **Dask** - Parallel computing for larger-than-memory datasets with distributed DataFrames, Arrays, Bags, and Futures
|
||||
- **Matplotlib** - Publication-quality plotting and visualization
|
||||
- **Polars** - High-performance DataFrame operations with lazy evaluation
|
||||
- **Seaborn** - Statistical data visualization with dataset-oriented interface, automatic confidence intervals, publication-quality themes, colorblind-safe palettes, and comprehensive support for exploratory analysis, distribution comparisons, correlation matrices, regression plots, and multi-panel figures
|
||||
- **ReportLab** - Programmatic PDF generation for reports and documents
|
||||
|
||||
## Phylogenetics & Trees
|
||||
- **ETE Toolkit** - Phylogenetic tree manipulation, visualization, and analysis
|
||||
|
||||
## Genomics Tools
|
||||
- **deepTools** - NGS data analysis (ChIP-seq, RNA-seq, ATAC-seq) with BAM/bigWig files
|
||||
- **FlowIO** - Flow Cytometry Standard (FCS) file reading and manipulation
|
||||
- **scikit-bio** - Bioinformatics sequence analysis and diversity metrics
|
||||
- **Zarr** - Chunked, compressed N-dimensional array storage
|
||||
|
||||
## Multi-omics & AI Agent Frameworks
|
||||
- **BIOMNI** - Autonomous biomedical AI agent framework from Stanford SNAP lab for executing complex research tasks across genomics, drug discovery, molecular biology, and clinical analysis. Combines LLM reasoning with code execution and ~11GB of integrated biomedical databases (Ensembl, NCBI Gene, UniProt, PDB, AlphaFold, ClinVar, OMIM, HPO, PubMed, KEGG, Reactome, GO). Supports multiple LLM providers (Claude, GPT-4, Gemini, Groq, Bedrock). Includes A1 agent class for autonomous task decomposition, BiomniEval1 benchmark framework, and MCP server integration. Use cases: CRISPR screening design, single-cell RNA-seq analysis, ADMET prediction, GWAS interpretation, rare disease diagnosis, protein structure analysis, literature synthesis, and multi-omics integration
|
||||
|
||||
## Scientific Communication & Publishing
|
||||
- **Paper-2-Web** - Autonomous pipeline for transforming academic papers into multiple promotional formats using the Paper2All system. Converts LaTeX or PDF papers into: (1) Paper2Web - interactive, layout-aware academic homepages with responsive design, interactive figures, and mobile support; (2) Paper2Video - professional presentation videos with slides, narration, cursor movements, and optional talking-head generation using Hallo2; (3) Paper2Poster - print-ready conference posters with custom dimensions, professional layouts, and institution branding. Supports GPT-4/GPT-4.1 models, batch processing, QR code generation, multi-language content, and quality assessment metrics. Use cases: conference materials, video abstracts, preprint enhancement, research promotion, poster sessions, and academic website creation
|
||||
|
||||
## Tool Discovery & Research Platforms
|
||||
- **ToolUniverse** - Unified ecosystem providing standardized access to 600+ scientific tools, models, datasets, and APIs across bioinformatics, cheminformatics, genomics, structural biology, and proteomics. Enables AI agents to function as research scientists through: (1) Tool Discovery - natural language, semantic, and keyword-based search for finding relevant scientific tools (Tool_Finder, Tool_Finder_LLM, Tool_Finder_Keyword); (2) Tool Execution - standardized AI-Tool Interaction Protocol for running tools with consistent interfaces; (3) Tool Composition - sequential and parallel workflow chaining for multi-step research pipelines; (4) Model Context Protocol (MCP) integration for Claude Desktop/Code. Supports drug discovery workflows (disease→targets→structures→screening→candidates), genomics analysis (expression→differential analysis→pathways), clinical genomics (variants→annotation→pathogenicity→disease associations), and cross-domain research. Use cases: accessing scientific databases (OpenTargets, PubChem, UniProt, PDB, ChEMBL, KEGG), protein structure prediction (AlphaFold), molecular docking, pathway enrichment, variant annotation, literature searches, and automated scientific workflows
|
||||
|
||||
|
||||
211
docs/scientific-skills.md
Normal file
211
docs/scientific-skills.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# Scientific Skills
|
||||
|
||||
## Scientific Databases
|
||||
|
||||
- **AlphaFold DB** - Comprehensive AI-predicted protein structure database from DeepMind providing 200M+ high-confidence protein structure predictions covering UniProt reference proteomes and beyond. Includes confidence metrics (pLDDT for per-residue confidence, PAE for pairwise accuracy estimates), structure quality assessment, predicted aligned error matrices, and multiple structure formats (PDB, mmCIF, AlphaFold DB format). Supports programmatic access via REST API, bulk downloads through Google Cloud Storage, and integration with structural analysis tools. Enables structure-based drug discovery, protein function prediction, structural genomics, comparative modeling, and structural bioinformatics research without experimental structure determination
|
||||
- **BRENDA** - World's most comprehensive enzyme information system containing detailed enzyme data from scientific literature. Query kinetic parameters (Km, kcat, Vmax), reaction equations, substrate specificities, organism information, and optimal conditions for 45,000+ enzymes with millions of kinetic data points via SOAP API. Supports enzyme discovery by substrate/product, cross-organism comparisons, environmental parameter analysis (pH, temperature optima), cofactor requirements, inhibition/activation data, and thermophilic homolog identification. Includes helper scripts for parsing BRENDA response formats, visualization of kinetic parameters, and enzymatic pathway construction. Use cases: metabolic engineering, enzyme engineering and optimization, kinetic modeling, retrosynthesis planning, industrial enzyme selection, and biochemical research requiring comprehensive enzyme kinetic data
|
||||
- **ChEMBL** - Comprehensive manually curated database of bioactive molecules with drug-like properties maintained by EMBL-EBI. Contains 2M+ unique compounds, 19M+ bioactivity measurements, 13K+ protein targets, and 1.1M+ assays from 90K+ publications. Provides detailed compound information including chemical structures (SMILES, InChI), bioactivity data (IC50, EC50, Ki, Kd values), target information (protein families, pathways), ADMET properties, drug indications, clinical trial data, and patent information. Features REST API access, web interface, downloadable data files, and integration with other databases (UniProt, PubChem, DrugBank). Use cases: drug discovery, target identification, lead optimization, bioactivity prediction, chemical biology research, and drug repurposing
|
||||
- **ClinPGx** - Clinical pharmacogenomics database (successor to PharmGKB) providing gene-drug interactions, CPIC clinical guidelines, allele functions, drug labels, and pharmacogenomic annotations for precision medicine and personalized pharmacotherapy (consolidates PharmGKB, CPIC, and PharmCAT resources)
|
||||
- **ClinVar** - NCBI's public archive of genomic variants and their clinical significance with standardized classifications (pathogenic, benign, VUS), E-utilities API access, and bulk FTP downloads for variant interpretation and precision medicine research
|
||||
- **ClinicalTrials.gov** - Comprehensive registry of clinical studies conducted worldwide (maintained by U.S. National Library of Medicine) with API v2 access for searching trials by condition, intervention, location, sponsor, study status, and phase; retrieve detailed trial information including eligibility criteria, outcomes, contacts, and locations; export to CSV/JSON formats for analysis (public API, no authentication required, ~50 req/min rate limit)
|
||||
- **COSMIC** - Catalogue of Somatic Mutations in Cancer, the world's largest database of somatic cancer mutations (millions of mutations across thousands of cancer types, Cancer Gene Census, mutational signatures, structural variants, and drug resistance data)
|
||||
- **DrugBank** - Comprehensive bioinformatics and cheminformatics database containing detailed drug and drug target information (9,591+ drug entries including 2,037 FDA-approved small molecules, 241 biotech drugs, 96 nutraceuticals, 6,000+ experimental compounds) with 200+ data fields per entry covering chemical structures (SMILES, InChI), pharmacology (mechanism of action, pharmacodynamics, ADME), drug-drug interactions, protein targets (enzymes, transporters, carriers), biological pathways, external identifiers (PubChem, ChEMBL, UniProt), and physicochemical properties for drug discovery, pharmacology research, interaction analysis, target identification, chemical similarity searches, and ADMET predictions
|
||||
- **ENA (European Nucleotide Archive)** - Comprehensive public repository for nucleotide sequence data and metadata with REST APIs for accessing sequences, assemblies, samples, studies, and reads; supports advanced search, taxonomy lookups, and bulk downloads via FTP/Aspera (rate limit: 50 req/sec)
|
||||
- **Ensembl** - Genome browser and bioinformatics database providing genomic annotations, sequences, variants, and comparative genomics data for 250+ vertebrate species (Release 115, 2025) with comprehensive REST API for gene lookups, sequence retrieval, variant effect prediction (VEP), ortholog finding, assembly mapping (GRCh37/GRCh38), and region analysis
|
||||
- **FDA Databases** - Comprehensive access to all FDA (Food and Drug Administration) regulatory databases through openFDA API covering drugs (adverse events, labeling, NDC, recalls, approvals, shortages), medical devices (adverse events, 510k clearances, PMA, UDI, classifications), foods (recalls, adverse events, allergen tracking), animal/veterinary medicines (species-specific adverse events), and substances (UNII/CAS lookup, chemical structures, molecular data) for drug safety research, pharmacovigilance, regulatory compliance, and scientific analysis
|
||||
- **GEO (Gene Expression Omnibus)** - NCBI's comprehensive public repository for high-throughput gene expression and functional genomics data. Contains 264K+ studies, 8M+ samples, and petabytes of data from microarray, RNA-seq, ChIP-seq, ATAC-seq, and other high-throughput experiments. Provides standardized data submission formats (MINIML, SOFT), programmatic access via Entrez Programming Utilities (E-utilities) and GEOquery R package, bulk FTP downloads, and web-based search and retrieval. Supports data mining, meta-analysis, differential expression analysis, and cross-study comparisons. Includes curated datasets, series records with experimental design, platform annotations, and sample metadata. Use cases: gene expression analysis, biomarker discovery, disease mechanism research, drug response studies, and functional genomics research
|
||||
- **GWAS Catalog** - NHGRI-EBI catalog of published genome-wide association studies with curated SNP-trait associations (thousands of studies, genome-wide significant associations p≤5×10⁻⁸), full summary statistics, REST API access for variant/trait/gene queries, and FTP downloads for genetic epidemiology and precision medicine research
|
||||
- **HMDB (Human Metabolome Database)** - Comprehensive metabolomics resource with 220K+ metabolite entries, detailed chemical/biological data, concentration ranges, disease associations, pathways, and spectral data for metabolite identification and biomarker discovery
|
||||
- **KEGG** - Kyoto Encyclopedia of Genes and Genomes, comprehensive database resource integrating genomic, chemical, and systemic functional information. Provides pathway databases (KEGG PATHWAY with 500+ reference pathways, metabolic pathways, signaling pathways, disease pathways), genome databases (KEGG GENES with gene catalogs from 5,000+ organisms), chemical databases (KEGG COMPOUND, KEGG DRUG, KEGG GLYCAN), and disease/drug databases (KEGG DISEASE, KEGG DRUG). Features pathway enrichment analysis, gene-to-pathway mapping, compound searches, molecular interaction networks, ortholog identification (KO - KEGG Orthology), ID conversion across databases, and visualization tools. Supports REST API access, KEGG Mapper for pathway mapping, and integration with bioinformatics tools. Use cases: pathway enrichment analysis, metabolic pathway reconstruction, drug target identification, comparative genomics, systems biology, and functional annotation of genes
|
||||
- **Metabolomics Workbench** - NIH Common Fund metabolomics data repository with 4,200+ processed studies, standardized nomenclature (RefMet), mass spectrometry searches, and comprehensive REST API for accessing metabolite structures, study metadata, experimental results, and gene/protein-metabolite associations
|
||||
- **OpenAlex** - Comprehensive open catalog of 240M+ scholarly works, authors, institutions, topics, sources, publishers, and funders. Provides complete bibliometric database for academic literature search, citation analysis, research trend tracking, author publication discovery, institution research output analysis, and open access paper identification. Features REST API with no authentication required (100k requests/day, 10 req/sec with email), advanced filtering (publication year, citations, open access status, topics, authors, institutions), aggregation/grouping capabilities, random sampling for research studies, batch ID lookups (DOI, ORCID, ROR, ISSN), and comprehensive metadata (titles, abstracts, citations, authorships, topics, funding). Supports literature reviews, bibliometric analysis, research output evaluation, citation network analysis, and academic database queries across all scientific domains
|
||||
- **Open Targets** - Comprehensive therapeutic target identification and validation platform integrating genetics, omics, and chemical data (200M+ evidence strings, target-disease associations with scoring, tractability assessments, safety liabilities, known drugs from ChEMBL, GraphQL API) for drug target discovery, prioritization, evidence evaluation, drug repurposing, competitive intelligence, and mechanism research
|
||||
- **NCBI Gene** - Comprehensive gene-specific database from NCBI providing curated information about genes from 500+ organisms. Contains gene nomenclature (official symbols, aliases, full names), genomic locations (chromosomal positions, exons, introns), sequences (genomic, mRNA, protein), gene function and phenotypes, pathways and interactions, orthologs and paralogs, variation data (SNPs, mutations), expression data, and cross-references to 200+ external databases (UniProt, Ensembl, HGNC, OMIM, Reactome). Supports programmatic access via E-utilities API (Entrez Programming Utilities) and NCBI Datasets API, bulk downloads, and web interface. Enables gene annotation, comparative genomics, variant interpretation, pathway analysis, and integration with other NCBI resources (PubMed, dbSNP, ClinVar). Use cases: gene information retrieval, variant annotation, functional genomics, disease gene discovery, and bioinformatics workflows
|
||||
- **Protein Data Bank (PDB)** - Worldwide repository for 3D structural data of proteins, nucleic acids, and biological macromolecules. Contains 200K+ experimentally determined structures from X-ray crystallography, NMR spectroscopy, and cryo-electron microscopy. Provides comprehensive structure information including atomic coordinates, experimental data, structure quality metrics, ligand binding sites, protein-protein interfaces, and metadata (authors, methods, citations). Features advanced search capabilities (by sequence, structure similarity, ligand, organism, resolution), REST API and FTP access, structure visualization tools, and integration with analysis software. Supports structure comparison, homology modeling, drug design, structural biology research, and educational use. Maintained by wwPDB consortium (RCSB PDB, PDBe, PDBj, BMRB). Use cases: structural biology research, drug discovery, protein engineering, molecular modeling, and structural bioinformatics
|
||||
- **PubChem** - World's largest free chemical information database maintained by NCBI. Contains 110M+ unique chemical compounds, 270M+ bioactivity test results, 300M+ chemical structures, and 1M+ patents. Provides comprehensive compound information including chemical structures (2D/3D structures, SMILES, InChI), physicochemical properties (molecular weight, logP, H-bond donors/acceptors), bioactivity data (assays, targets, pathways), safety and toxicity data, literature references, and vendor information. Features REST API (PUG REST, PUG SOAP, PUG View), web interface with advanced search, bulk downloads, and integration with other NCBI resources. Supports chemical similarity searches, substructure searches, property-based filtering, and cheminformatics analysis. Use cases: drug discovery, chemical biology, lead identification, ADMET prediction, chemical database mining, and molecular property analysis
|
||||
- **PubMed** - NCBI's comprehensive biomedical literature database containing 35M+ citations from MEDLINE, life science journals, and online books. Provides access to abstracts, full-text articles (when available), MeSH (Medical Subject Headings) terms, author information, publication dates, and citation networks. Features advanced search capabilities with Boolean operators, field tags (author, title, journal, MeSH terms, publication date), filters (article type, species, language, publication date range), and saved searches with email alerts. Supports programmatic access via E-utilities API (Entrez Programming Utilities), bulk downloads, citation export in multiple formats (RIS, BibTeX, MEDLINE), and integration with reference management software. Includes PubMed Central (PMC) for open-access full-text articles. Use cases: literature searches, systematic reviews, citation analysis, research discovery, and staying current with scientific publications
|
||||
- **Reactome** - Curated pathway database for biological processes and molecular interactions (2,825+ human pathways, 16K+ reactions, 11K+ proteins) with pathway enrichment analysis, expression data analysis, and species comparison using Content Service and Analysis Service APIs
|
||||
- **STRING** - Protein-protein interaction network database (5000+ genomes, 59.3M proteins, 20B+ interactions) with functional enrichment analysis, interaction partner discovery, and network visualization from experimental data, computational prediction, and text-mining
|
||||
- **UniProt** - Universal Protein Resource for protein sequences, annotations, and functional information (UniProtKB/Swiss-Prot reviewed entries, TrEMBL unreviewed entries) with REST API access for search, retrieval, ID mapping, and batch operations across 200+ databases
|
||||
- **USPTO** - United States Patent and Trademark Office data access including patent searches, trademark lookups, patent examination history (PEDS), office actions, assignments, citations, and litigation records; supports PatentSearch API (ElasticSearch-based patent search), TSDR (Trademark Status & Document Retrieval), Patent/Trademark Assignment APIs, and additional specialized APIs for comprehensive IP analysis
|
||||
- **ZINC** - Free database of commercially-available compounds for virtual screening and drug discovery maintained by UCSF. Contains 230M+ purchasable compounds from 100+ vendors in ready-to-dock 3D formats (SDF, MOL2) with pre-computed conformers. Provides compound information including chemical structures, vendor information and pricing, physicochemical properties (molecular weight, logP, H-bond donors/acceptors, rotatable bonds), drug-likeness filters (Lipinski's Rule of Five, Veber rules), and substructure search capabilities. Features multiple compound subsets (drug-like, lead-like, fragment-like, natural products), downloadable subsets for specific screening campaigns, and integration with molecular docking software (AutoDock, DOCK, Glide). Supports structure-based and ligand-based virtual screening workflows. Use cases: virtual screening campaigns, lead identification, compound library design, high-throughput docking, and drug discovery research
|
||||
- **bioRxiv** - Preprint server for the life sciences providing Python-based tools for searching and retrieving preprints. Supports comprehensive searches by keywords, authors, date ranges, and subject categories, returning structured JSON metadata including titles, abstracts, DOIs, and citation information. Features PDF downloads for full-text analysis, filtering by bioRxiv subject categories (neuroscience, bioinformatics, genomics, etc.), and integration with literature review workflows. Use cases: tracking recent preprints, conducting systematic literature reviews, analyzing research trends, monitoring publications by specific authors, and staying current with emerging research before formal peer review
|
||||
|
||||
## Scientific Integrations
|
||||
|
||||
### Laboratory Information Management Systems (LIMS) & R&D Platforms
|
||||
- **Benchling Integration** - Toolkit for integrating with Benchling's R&D platform, providing programmatic access to laboratory data management including registry entities (DNA sequences, proteins), inventory systems (samples, containers, locations), electronic lab notebooks (entries, protocols), workflows (tasks, automation), and data exports using Python SDK and REST API
|
||||
|
||||
### Cloud Platforms for Genomics & Biomedical Data
|
||||
- **DNAnexus Integration** - Comprehensive toolkit for working with the DNAnexus cloud platform for genomics and biomedical data analysis. Covers building and deploying apps/applets (Python/Bash), managing data objects (files, records, databases), running analyses and workflows, using the dxpy Python SDK, and configuring app metadata and dependencies (dxapp.json setup, system packages, Docker, assets). Enables processing of FASTQ/BAM/VCF files, bioinformatics pipelines, job execution, workflow orchestration, and platform operations including project management and permissions
|
||||
|
||||
### Laboratory Automation
|
||||
- **Opentrons Integration** - Toolkit for creating, editing, and debugging Opentrons Python Protocol API v2 protocols for laboratory automation using Flex and OT-2 robots. Enables automated liquid handling, pipetting workflows, hardware module control (thermocycler, temperature, magnetic, heater-shaker, absorbance plate reader), labware management, and complex protocol development for biological and chemical experiments
|
||||
|
||||
### Electronic Lab Notebooks (ELN)
|
||||
- **LabArchives Integration** - Toolkit for interacting with LabArchives Electronic Lab Notebook (ELN) REST API. Provides programmatic access to notebooks (backup, retrieval, management), entries (creation, comments, attachments), user authentication, site reports and analytics, and third-party integrations (Protocols.io, GraphPad Prism, SnapGene, Geneious, Jupyter, REDCap). Includes Python scripts for configuration setup, notebook operations, and entry management. Supports multi-regional API endpoints (US, UK, Australia) and OAuth authentication
|
||||
|
||||
### Workflow Platforms & Cloud Execution
|
||||
- **LatchBio Integration** - Integration with the Latch platform for building, deploying, and executing bioinformatics workflows. Provides comprehensive support for creating serverless bioinformatics pipelines using Python decorators, deploying Nextflow/Snakemake pipelines, managing cloud data (LatchFile, LatchDir) and structured Registry (Projects, Tables, Records), configuring computational resources (CPU, GPU, memory, storage), and using pre-built Latch Verified workflows (RNA-seq, AlphaFold, DESeq2, single-cell analysis, CRISPR editing). Enables automatic containerization, UI generation, workflow versioning, and execution on scalable cloud infrastructure with comprehensive data management
|
||||
|
||||
### Microscopy & Bio-image Data
|
||||
- **OMERO Integration** - Toolkit for interacting with OMERO microscopy data management systems using Python. Provides comprehensive access to microscopy images stored in OMERO servers, including dataset and screening data retrieval, pixel data analysis, annotation and metadata management, regions of interest (ROIs) creation and analysis, batch processing, OMERO.scripts development, and OMERO.tables for structured data storage. Essential for researchers working with high-content screening data, multi-dimensional microscopy datasets, or collaborative image repositories
|
||||
|
||||
### Protocol Management & Sharing
|
||||
- **Protocols.io Integration** - Integration with protocols.io API for managing scientific protocols. Enables programmatic access to protocol discovery (search by keywords, DOI, category), protocol lifecycle management (create, update, publish with DOI), step-by-step procedure documentation, collaborative development with workspaces and discussions, file management (upload data, images, documents), experiment tracking and documentation, and data export. Supports OAuth authentication, protocol PDF generation, materials management, threaded comments, workspace permissions, and institutional protocol repositories. Essential for protocol standardization, reproducibility, lab knowledge management, and scientific collaboration
|
||||
|
||||
## Scientific Packages
|
||||
|
||||
### Bioinformatics & Genomics
|
||||
- **AnnData** - Python package for handling annotated data matrices, specifically designed for single-cell genomics data. Provides efficient storage and manipulation of high-dimensional data with associated annotations (observations/cells and variables/genes). Key features include: HDF5-based h5ad file format for efficient I/O and compression, integration with pandas DataFrames for metadata, support for sparse matrices (scipy.sparse) for memory efficiency, layered data organization (X for main data matrix, obs for observation annotations, var for variable annotations, obsm/varm for multi-dimensional annotations, obsp/varp for pairwise matrices), and seamless integration with Scanpy, scvi-tools, and other single-cell analysis packages. Supports lazy loading, chunked operations, and conversion to/from other formats (CSV, HDF5, Zarr). Use cases: single-cell RNA-seq data management, multi-modal single-cell data (RNA+ATAC, CITE-seq), spatial transcriptomics, and any high-dimensional annotated data requiring efficient storage and manipulation
|
||||
- **Arboreto** - Python package for efficient gene regulatory network (GRN) inference from single-cell RNA-seq data using ensemble tree-based methods. Implements GRNBoost2 (gradient boosting-based network inference) and GENIE3 (random forest-based inference) algorithms optimized for large-scale single-cell datasets. Key features include: parallel processing for scalability, support for sparse matrices and large datasets (millions of cells), integration with Scanpy/AnnData workflows, customizable hyperparameters, and output formats compatible with network analysis tools. Provides ranked lists of potential regulatory interactions (transcription factor-target gene pairs) with confidence scores. Use cases: identifying transcription factor-target relationships, reconstructing gene regulatory networks from single-cell data, understanding cell-type-specific regulatory programs, and inferring causal relationships in gene expression
|
||||
- **BioPython** - Comprehensive Python library for computational biology and bioinformatics providing tools for sequence manipulation, database access, and biological data analysis. Key features include: sequence objects (Seq, SeqRecord, SeqIO) for DNA/RNA/protein sequences with biological alphabet validation, file format parsers (FASTA, FASTQ, GenBank, EMBL, Swiss-Prot, PDB, SAM/BAM, VCF, GFF), NCBI database access (Entrez Programming Utilities for PubMed, GenBank, BLAST, taxonomy), BLAST integration (running searches, parsing results), sequence alignment (pairwise and multiple sequence alignment with Bio.Align), phylogenetics (tree construction and manipulation with Bio.Phylo), population genetics (Hardy-Weinberg, F-statistics), protein structure analysis (PDB parsing, structure calculations), and statistical analysis tools. Supports integration with NumPy, pandas, and other scientific Python libraries. Use cases: sequence analysis, database queries, phylogenetic analysis, sequence alignment, file format conversion, and general bioinformatics workflows
|
||||
- **BioServices** - Python library providing unified programmatic access to 40+ biological web services and databases. Supports major bioinformatics resources including KEGG (pathway and compound data), UniProt (protein sequences and annotations), ChEBI (chemical entities), ChEMBL (bioactive molecules), Reactome (pathways), IntAct (protein interactions), BioModels (biological models), and many others. Features consistent API across different services, automatic result caching, error handling and retry logic, support for both REST and SOAP web services, and conversion of results to Python objects (dictionaries, lists, BioPython objects). Handles authentication, rate limiting, and API versioning. Use cases: automated data retrieval from multiple biological databases, building bioinformatics pipelines, database integration workflows, and programmatic access to biological web resources without manual web browsing
|
||||
- **Cellxgene Census** - Python package for querying and analyzing large-scale single-cell RNA-seq data from the CZ CELLxGENE Discover census. Provides access to 50M+ cells across 1,000+ datasets with standardized annotations and metadata. Key features include: efficient data access using TileDB-SOMA format for scalable queries, integration with AnnData and Scanpy for downstream analysis, cell metadata filtering and querying, gene expression retrieval, and support for both human and mouse data. Enables subsetting datasets by cell type, tissue, disease, or other metadata before downloading, reducing data transfer and memory requirements. Supports local caching and batch operations. Use cases: large-scale single-cell analysis, cell-type discovery, cross-dataset comparisons, reference dataset construction, and exploratory analysis of public single-cell data
|
||||
- **gget** - Command-line tool and Python package for efficient querying of genomic databases with a simple, unified interface. Provides fast access to Ensembl (gene information, sequences, orthologs, variants), UniProt (protein sequences and annotations), NCBI (BLAST searches, gene information), PDB (protein structures), COSMIC (cancer mutations), and other databases. Features include: single-command queries without complex API setup, automatic result formatting, batch query support, integration with pandas DataFrames, and support for both command-line and Python API usage. Optimized for speed and ease of use, making database queries accessible to users without extensive bioinformatics experience. Use cases: quick gene lookups, sequence retrieval, variant annotation, protein structure access, and rapid database queries in bioinformatics workflows
|
||||
- **geniml** - Genomic interval machine learning toolkit providing unsupervised methods for building ML models on BED files. Key capabilities include Region2Vec (word2vec-style embeddings of genomic regions and region sets using tokenization and neural language modeling), BEDspace (joint embeddings of regions and metadata labels using StarSpace for cross-modal queries), scEmbed (Region2Vec applied to single-cell ATAC-seq data generating cell-level embeddings for clustering and annotation with scanpy integration), consensus peak building (four statistical methods CC/CCF/ML/HMM for creating reference universes from BED collections), and comprehensive utilities (BBClient for BED caching, BEDshift for genomic randomization preserving context, evaluation metrics for embedding quality, Text2BedNN for neural search backends). Part of BEDbase ecosystem. Supports Python API and CLI workflows, pre-trained models on Hugging Face, and integration with gtars for tokenization. Use cases: region similarity searches, dimension reduction of chromatin accessibility data, scATAC-seq clustering and cell-type annotation, metadata-aware genomic queries, universe construction for standardized references, and any ML task requiring genomic region feature vectors
|
||||
- **gtars** - High-performance Rust toolkit for genomic interval analysis providing specialized tools for overlap detection using IGD (Integrated Genome Database) indexing, coverage track generation (uniwig module for WIG/BigWig formats), genomic tokenization for machine learning applications (TreeTokenizer for deep learning models), reference sequence management (refget protocol compliance), fragment processing for single-cell genomics (barcode-based splitting and cluster analysis), and fragment scoring against reference datasets. Offers Python bindings with NumPy integration, command-line tools (gtars-cli), and Rust library. Key modules include: tokenizers (convert genomic regions to ML tokens), overlaprs (efficient overlap computation), uniwig (ATAC-seq/ChIP-seq/RNA-seq coverage profiles), refget (GA4GH-compliant sequence digests), bbcache (BEDbase.org integration), scoring (fragment enrichment metrics), and fragsplit (single-cell fragment manipulation). Supports parallel processing, memory-mapped files, streaming for large datasets, and serves as foundation for geniml genomic ML package. Ideal for genomic ML preprocessing, regulatory element analysis, variant annotation, chromatin accessibility profiling, and computational genomics workflows
|
||||
- **pysam** - Read, write, and manipulate genomic data files (SAM/BAM/CRAM alignments, VCF/BCF variants, FASTA/FASTQ sequences) with pileup analysis, coverage calculations, and bioinformatics workflows
|
||||
- **PyDESeq2** - Python implementation of the DESeq2 differential gene expression analysis method for bulk RNA-seq data. Provides statistical methods for determining differential expression between experimental conditions using negative binomial generalized linear models. Key features include: size factor estimation for library size normalization, dispersion estimation and shrinkage, hypothesis testing with Wald test or likelihood ratio test, multiple testing correction (Benjamini-Hochberg FDR), results filtering and ranking, and integration with pandas DataFrames. Handles complex experimental designs, batch effects, and replicates. Produces fold-change estimates, p-values, and adjusted p-values for each gene. Use cases: identifying differentially expressed genes between conditions, RNA-seq experiment analysis, biomarker discovery, and gene expression studies requiring rigorous statistical analysis
|
||||
- **Scanpy** - Comprehensive Python toolkit for single-cell RNA-seq data analysis built on AnnData. Provides end-to-end workflows for preprocessing (quality control, normalization, log transformation), dimensionality reduction (PCA, UMAP, t-SNE, ForceAtlas2), clustering (Leiden, Louvain, hierarchical clustering), marker gene identification, trajectory inference (PAGA, diffusion maps), and visualization. Key features include: efficient handling of large datasets (millions of cells) using sparse matrices, integration with scvi-tools for advanced analysis, support for multi-modal data (RNA+ATAC, CITE-seq), batch correction methods, and publication-quality plotting functions. Includes extensive documentation, tutorials, and integration with other single-cell tools. Supports GPU acceleration for certain operations. Use cases: single-cell RNA-seq analysis, cell-type identification, trajectory analysis, batch correction, and comprehensive single-cell genomics workflows
|
||||
- **scvi-tools** - Probabilistic deep learning models for single-cell omics analysis. PyTorch-based framework providing variational autoencoders (VAEs) for dimensionality reduction, batch correction, differential expression, and data integration across modalities. Includes 25+ models: scVI/scANVI (RNA-seq integration and cell type annotation), totalVI (CITE-seq protein+RNA), MultiVI (multiome RNA+ATAC integration), PeakVI (ATAC-seq analysis), DestVI/Stereoscope/Tangram (spatial transcriptomics deconvolution), MethylVI (methylation), CytoVI (flow/mass cytometry), VeloVI (RNA velocity), contrastiveVI (perturbation studies), and Solo (doublet detection). Supports seamless integration with Scanpy/AnnData ecosystem, GPU acceleration, reference mapping (scArches), and probabilistic differential expression with uncertainty quantification
|
||||
|
||||
### Data Management & Infrastructure
|
||||
- **LaminDB** - Open-source data framework for biology that makes data queryable, traceable, reproducible, and FAIR (Findable, Accessible, Interoperable, Reusable). Provides unified platform combining lakehouse architecture, lineage tracking, feature stores, biological ontologies (via Bionty plugin with 20+ ontologies: genes, proteins, cell types, tissues, diseases, pathways), LIMS, and ELN capabilities through a single Python API. Key features include: automatic data lineage tracking (code, inputs, outputs, environment), versioned artifacts (DataFrame, AnnData, SpatialData, Parquet, Zarr), schema validation and data curation with standardization/synonym mapping, queryable metadata with feature-based filtering, cross-registry traversal, and streaming for large datasets. Supports integrations with workflow managers (Nextflow, Snakemake, Redun), MLOps platforms (Weights & Biases, MLflow, HuggingFace, scVI-tools), cloud storage (S3, GCS, S3-compatible), array stores (TileDB-SOMA, DuckDB), and visualization (Vitessce). Deployment options: local SQLite, cloud storage with SQLite, or cloud storage with PostgreSQL for production. Use cases: scRNA-seq standardization and analysis, flow cytometry/spatial data management, multi-modal dataset integration, computational workflow tracking with reproducibility, biological ontology-based annotation, data lakehouse construction for unified queries, ML pipeline integration with experiment tracking, and FAIR-compliant dataset publishing
|
||||
- **Modal** - Serverless cloud platform for running Python code with minimal configuration, specialized for AI/ML workloads and scientific computing. Execute functions on powerful GPUs (T4, L4, A10, A100, L40S, H100, H200, B200), scale automatically from zero to thousands of containers, and pay only for compute used. Key features include: declarative container image building with uv/pip/apt package management, automatic autoscaling with configurable limits and buffer containers, GPU acceleration with multi-GPU support (up to 8 GPUs per container), persistent storage via Volumes for model weights and datasets, secret management for API keys and credentials, scheduled jobs with cron expressions, web endpoints for deploying serverless APIs, parallel execution with `.map()` for batch processing, input concurrency for I/O-bound workloads, and resource configuration (CPU cores, memory, disk). Supports custom Docker images, integration with Hugging Face/Weights & Biases, FastAPI for web endpoints, and distributed training. Free tier includes $30/month credits. Use cases: ML model deployment and inference (LLMs, image generation, embeddings), GPU-accelerated training, batch processing large datasets in parallel, scheduled compute-intensive jobs, serverless API deployment with autoscaling, scientific computing requiring distributed compute or specialized hardware, and data pipeline automation
|
||||
|
||||
### Cheminformatics & Drug Discovery
|
||||
- **Datamol** - Python library for molecular manipulation and featurization built on RDKit with enhanced workflows and performance optimizations. Provides utilities for molecular I/O (reading/writing SMILES, SDF, MOL files), molecular standardization and sanitization, molecular transformations (tautomer enumeration, stereoisomer generation), molecular featurization (descriptors, fingerprints, graph representations), parallel processing for large datasets, and integration with machine learning pipelines. Features include: optimized RDKit operations, caching for repeated computations, molecular filtering and preprocessing, and seamless integration with pandas DataFrames. Designed for drug discovery and cheminformatics workflows requiring efficient processing of large compound libraries. Use cases: molecular preprocessing for ML models, compound library management, molecular similarity searches, and cheminformatics data pipelines
|
||||
- **DeepChem** - Deep learning framework for molecular machine learning and drug discovery built on TensorFlow and PyTorch. Provides implementations of graph neural networks (GCN, GAT, MPNN, AttentiveFP) for molecular property prediction, molecular featurization (molecular graphs, fingerprints, descriptors), pre-trained models, and MoleculeNet benchmark suite (50+ datasets for molecular property prediction, toxicity, ADMET). Key features include: support for both TensorFlow and PyTorch backends, distributed training, hyperparameter optimization, model interpretation tools, and integration with RDKit. Includes datasets for quantum chemistry, toxicity prediction, ADMET properties, and binding affinity prediction. Use cases: molecular property prediction, drug discovery, ADMET prediction, toxicity screening, and molecular machine learning research
|
||||
- **DiffDock** - State-of-the-art diffusion-based molecular docking method for predicting protein-ligand binding poses and binding affinities. Uses diffusion models to generate diverse, high-quality binding poses without requiring exhaustive search. Key features include: fast inference compared to traditional docking methods, generation of multiple diverse poses, confidence scoring for predictions, and support for flexible ligand docking. Provides pre-trained models and Python API for integration into drug discovery pipelines. Achieves superior performance on standard benchmarks (PDBbind, CASF) compared to traditional docking methods. Use cases: virtual screening, lead optimization, binding pose prediction, structure-based drug design, and initial pose generation for refinement with more expensive methods
|
||||
- **MedChem** - Python library for medicinal chemistry analysis and drug-likeness assessment. Provides tools for calculating molecular descriptors, ADMET (Absorption, Distribution, Metabolism, Excretion, Toxicity) property prediction, drug-likeness filters (Lipinski's Rule of Five, Veber rules, Egan rules, Muegge rules), molecular complexity metrics, and synthetic accessibility scoring. Features include: integration with RDKit, parallel processing for large datasets, and comprehensive property calculators. Supports filtering compound libraries based on drug-like properties, identifying potential ADMET issues early in drug discovery, and prioritizing compounds for further development. Use cases: lead optimization, compound library filtering, ADMET prediction, drug-likeness assessment, and medicinal chemistry analysis in drug discovery workflows
|
||||
- **Molfeat** - Comprehensive Python library providing 100+ molecular featurizers for converting molecules into numerical representations suitable for machine learning. Includes molecular fingerprints (ECFP, MACCS, RDKit, Pharmacophore), molecular descriptors (2D/3D descriptors, constitutional, topological, electronic), graph-based representations (molecular graphs, line graphs), and pre-trained models (MolBERT, ChemBERTa, Uni-Mol embeddings). Features unified API across different featurizer types, caching for performance, parallel processing, and integration with popular ML frameworks (scikit-learn, PyTorch, TensorFlow). Supports both traditional cheminformatics descriptors and modern learned representations. Use cases: molecular property prediction, virtual screening, molecular similarity searches, and preparing molecular data for machine learning models
|
||||
- **PyTDC** - Python library providing access to Therapeutics Data Commons (TDC), a collection of curated datasets and benchmarks for drug discovery and development. Includes datasets for ADMET prediction (absorption, distribution, metabolism, excretion, toxicity), drug-target interactions, drug-drug interactions, drug response prediction, molecular generation, and retrosynthesis. Features standardized data formats, data loaders with automatic preprocessing, benchmark tasks with evaluation metrics, leaderboards for model comparison, and integration with popular ML frameworks. Provides both single-molecule and drug-pair datasets, covering various stages of drug discovery from target identification to clinical outcomes. Use cases: benchmarking ML models for drug discovery, ADMET prediction model development, drug-target interaction prediction, and drug discovery research
|
||||
- **RDKit** - Open-source cheminformatics toolkit for molecular informatics and drug discovery. Provides comprehensive functionality for molecular I/O (reading/writing SMILES, SDF, MOL, PDB files), molecular descriptors (200+ 2D and 3D descriptors), molecular fingerprints (Morgan, RDKit, MACCS, topological torsions), SMARTS pattern matching for substructure searches, molecular alignment and 3D coordinate generation, pharmacophore perception, reaction handling, and molecular drawing. Features high-performance C++ core with Python bindings, support for large molecule sets, and extensive documentation. Widely used in pharmaceutical industry and academic research. Use cases: molecular property calculation, virtual screening, molecular similarity searches, substructure matching, molecular visualization, and general cheminformatics workflows
|
||||
- **TorchDrug** - PyTorch-based machine learning platform for drug discovery with 40+ datasets, 20+ GNN models for molecular property prediction, protein modeling, knowledge graph reasoning, molecular generation, and retrosynthesis planning
|
||||
|
||||
### Proteomics & Mass Spectrometry
|
||||
- **matchms** - Processing and similarity matching of mass spectrometry data with 40+ filters, spectral library matching (Cosine, Modified Cosine, Neutral Losses), metadata harmonization, molecular fingerprint comparison, and support for multiple file formats (MGF, MSP, mzML, JSON)
|
||||
- **pyOpenMS** - Comprehensive mass spectrometry data analysis for proteomics and metabolomics (LC-MS/MS processing, peptide identification, feature detection, quantification, chemical calculations, and integration with search engines like Comet, Mascot, MSGF+)
|
||||
|
||||
### Medical Imaging & Digital Pathology
|
||||
- **histolab** - Digital pathology toolkit for whole slide image (WSI) processing and analysis. Provides automated tissue detection, tile extraction for deep learning pipelines, and preprocessing for gigapixel histopathology images. Key features include: multi-format WSI support (SVS, TIFF, NDPI), three tile extraction strategies (RandomTiler for sampling, GridTiler for complete coverage, ScoreTiler for quality-driven selection), automated tissue masks with customizable filters, built-in scorers (NucleiScorer, CellularityScorer), pyramidal image handling, visualization tools (thumbnails, mask overlays, tile previews), and H&E stain decomposition. Supports multiple tissue sections, artifact removal, pen annotation exclusion, and reproducible extraction with seeding. Use cases: creating training datasets for computational pathology, extracting informative tiles for tumor classification, whole-slide tissue characterization, quality assessment of histology samples, automated nuclei density analysis, and preprocessing for digital pathology deep learning workflows
|
||||
- **PathML** - Comprehensive computational pathology toolkit for whole slide image analysis, tissue segmentation, and machine learning on pathology data. Provides end-to-end workflows for digital pathology research including data loading, preprocessing, feature extraction, and model deployment
|
||||
- **pydicom** - Pure Python package for working with DICOM (Digital Imaging and Communications in Medicine) files. Provides comprehensive support for reading, writing, and manipulating medical imaging data from CT, MRI, X-ray, ultrasound, PET scans and other modalities. Key features include: pixel data extraction and manipulation with automatic decompression (JPEG/JPEG 2000/RLE), metadata access and modification with 1000+ standardized DICOM tags, image format conversion (PNG/JPEG/TIFF), anonymization tools for removing Protected Health Information (PHI), windowing and display transformations (VOI LUT application), multi-frame and 3D volume processing, DICOM sequence handling, and support for multiple transfer syntaxes. Use cases: medical image analysis, PACS system integration, radiology workflows, research data processing, DICOM anonymization, format conversion, image preprocessing for machine learning, multi-slice volume reconstruction, and clinical imaging pipelines
|
||||
|
||||
### Healthcare AI & Clinical Machine Learning
|
||||
- **NeuroKit2** - Comprehensive biosignal processing toolkit for analyzing physiological data including ECG, EEG, EDA, RSP, PPG, EMG, and EOG signals. Use this skill when processing cardiovascular signals, brain activity, electrodermal responses, respiratory patterns, muscle activity, or eye movements. Key features include: automated signal processing pipelines (cleaning, peak detection, delineation, quality assessment), heart rate variability analysis across time/frequency/nonlinear domains (SDNN, RMSSD, LF/HF, DFA, entropy measures), EEG analysis (frequency band power, microstates, source localization), autonomic nervous system assessment (sympathetic indices, respiratory sinus arrhythmia), comprehensive complexity measures (25+ entropy types, 15+ fractal dimensions, Lyapunov exponents), event-related and interval-related analysis modes, epoch creation and averaging for stimulus-locked responses, multi-signal integration with unified workflows, and extensive signal processing utilities (filtering, decomposition, peak correction, spectral analysis). Includes modular reference documentation across 12 specialized domains. Use cases: heart rate variability for cardiovascular health assessment, EEG microstates for consciousness studies, electrodermal activity for emotion research, respiratory variability analysis, psychophysiology experiments, affective computing, stress monitoring, sleep staging, autonomic dysfunction assessment, biofeedback applications, and multi-modal physiological signal integration for comprehensive human state monitoring
|
||||
- **PyHealth** - Comprehensive healthcare AI toolkit for developing, testing, and deploying machine learning models with clinical data. Provides specialized tools for electronic health records (EHR), physiological signals, medical imaging, and clinical text analysis. Key features include: 10+ healthcare datasets (MIMIC-III/IV, eICU, OMOP, sleep EEG, COVID-19 CXR), 20+ predefined clinical prediction tasks (mortality, hospital readmission, length of stay, drug recommendation, sleep staging, EEG analysis), 33+ models (Logistic Regression, MLP, CNN, RNN, Transformer, GNN, plus healthcare-specific models like RETAIN, SafeDrug, GAMENet, StageNet), comprehensive data processing (sequence processors, signal processors, medical code translation between ICD-9/10, NDC, RxNorm, ATC systems), training/evaluation utilities (Trainer class, fairness metrics, calibration, uncertainty quantification), and interpretability tools (attention visualization, SHAP, ChEFER). 3x faster than pandas for healthcare data processing. Use cases: ICU mortality prediction, hospital readmission risk assessment, safe medication recommendation with drug-drug interaction constraints, sleep disorder diagnosis from EEG signals, medical code standardization and translation, clinical text to ICD coding, length of stay estimation, and any clinical ML application requiring interpretability, fairness assessment, and calibrated predictions for healthcare deployment
|
||||
|
||||
### Clinical Documentation & Decision Support
|
||||
- **Clinical Decision Support** - Generate professional clinical decision support (CDS) documents for pharmaceutical and clinical research settings. Includes patient cohort analyses (biomarker-stratified with outcomes) and treatment recommendation reports (evidence-based guidelines with decision algorithms). Features GRADE evidence grading, statistical analysis (hazard ratios, survival curves, waterfall plots), biomarker integration (genomic alterations, gene expression signatures, IHC markers), and regulatory compliance. Use cases: pharmaceutical cohort reporting, clinical guideline development, comparative effectiveness analyses, treatment algorithm creation, and evidence synthesis for drug development
|
||||
- **Clinical Reports** - Write comprehensive clinical reports following established guidelines and standards. Covers case reports (CARE guidelines), diagnostic reports (radiology, pathology, laboratory), clinical trial reports (ICH-E3, SAE, CSR), and patient documentation (SOAP notes, H&P, discharge summaries). Includes templates, regulatory compliance (HIPAA, FDA, ICH-GCP), and validation tools. Use cases: journal case reports, diagnostic findings documentation, clinical trial reporting, patient progress notes, and regulatory submissions
|
||||
- **Treatment Plans** - Generate concise (3-4 page), focused medical treatment plans in LaTeX/PDF format for all clinical specialties. Supports general medical treatment, rehabilitation therapy, mental health care, chronic disease management, perioperative care, and pain management. Features SMART goal frameworks, evidence-based interventions, HIPAA compliance, and professional formatting. Use cases: individualized patient care plans, rehabilitation programs, psychiatric treatment plans, surgical care pathways, and pain management protocols
|
||||
|
||||
### Neuroscience & Electrophysiology
|
||||
- **Neuropixels-Analysis** - Comprehensive toolkit for analyzing Neuropixels high-density neural recordings using SpikeInterface, Allen Institute, and International Brain Laboratory (IBL) best practices. Supports the full workflow from raw data to publication-ready curated units. Key features include: data loading from SpikeGLX, Open Ephys, and NWB formats, preprocessing pipelines (highpass filtering, phase shift correction for Neuropixels 1.0, bad channel detection, common average referencing), motion/drift estimation and correction (kilosort_like and nonrigid_accurate presets), spike sorting integration (Kilosort4 GPU, SpykingCircus2, Mountainsort5 CPU), comprehensive postprocessing (waveform extraction, template computation, spike amplitudes, correlograms, unit locations), quality metrics computation (SNR, ISI violations, presence ratio, amplitude cutoff, drift metrics), automated curation using Allen Institute and IBL criteria with configurable thresholds, AI-assisted visual curation for uncertain units using Claude API, and export to Phy for manual review or NWB for sharing. Supports Neuropixels 1.0 (960 electrodes, 384 channels) and Neuropixels 2.0 (single and 4-shank configurations). Use cases: extracellular electrophysiology analysis, spike sorting from silicon probes, neural population recordings, systems neuroscience research, unit quality assessment, publication-ready neural data processing, and integration of AI-assisted curation for borderline units
|
||||
|
||||
### Protein Engineering & Design
|
||||
- **Adaptyv** - Cloud laboratory platform for automated protein testing and validation. Submit protein sequences via API or web interface and receive experimental results in approximately 21 days. Supports multiple assay types including binding assays (biolayer interferometry for protein-target interactions, KD/kon/koff measurements), expression testing (quantify protein expression levels in E. coli, mammalian, yeast, or insect cells), thermostability measurements (DSF and CD for Tm determination and thermal stability profiling), and enzyme activity assays (kinetic parameters, substrate specificity, inhibitor testing). Includes computational optimization tools for pre-screening sequences: NetSolP/SoluProt for solubility prediction, SolubleMPNN for sequence redesign to improve expression, ESM for sequence likelihood scoring, ipTM (AlphaFold-Multimer) for interface stability assessment, and pSAE for aggregation risk quantification. Platform features automated workflows from expression through purification to assay execution with quality control, webhook notifications for experiment completion, batch submission support for high-throughput screening, and comprehensive results with kinetic parameters, confidence metrics, and raw data access. Use cases: antibody affinity maturation, therapeutic protein developability assessment, enzyme engineering and optimization, protein stability improvement, AI-driven protein design validation, library screening for expression and function, lead optimization with experimental feedback, and integration of computational design with wet-lab validation in iterative design-build-test-learn cycles
|
||||
- **ESM (Evolutionary Scale Modeling)** - State-of-the-art protein language models from EvolutionaryScale for protein design, structure prediction, and representation learning. Includes ESM3 (1.4B-98B parameter multimodal generative models for simultaneous reasoning across sequence, structure, and function with chain-of-thought generation, inverse folding, and function-conditioned design) and ESM C (300M-6B parameter efficient embedding models 3x faster than ESM2 for similarity analysis, classification, and feature extraction). Supports local inference with open weights and cloud-based Forge API for scalable batch processing. Use cases: novel protein design, structure prediction from sequence, sequence design from structure, protein embeddings, function annotation, variant generation, and directed evolution workflows
|
||||
|
||||
### Machine Learning & Deep Learning
|
||||
- **aeon** - Comprehensive scikit-learn compatible Python toolkit for time series machine learning providing state-of-the-art algorithms across 7 domains: classification (13 algorithm categories including ROCKET variants, deep learning with InceptionTime/ResNet/FCN, distance-based with DTW/ERP/LCSS, shapelet-based, dictionary methods like BOSS/WEASEL, and hybrid ensembles HIVECOTE), regression (9 categories mirroring classification approaches), clustering (k-means/k-medoids with temporal distances, deep learning autoencoders, spectral methods), forecasting (ARIMA, ETS, Theta, Threshold Autoregressive, TCN, DeepAR), anomaly detection (STOMP/MERLIN matrix profile, clustering-based CBLOF/KMeans, isolation methods, copula-based COPOD), segmentation (ClaSP, FLUSS, HMM, binary segmentation), and similarity search (MASS algorithm, STOMP motif discovery, approximate nearest neighbors). Includes 40+ distance metrics (elastic: DTW/DDTW/WDTW/Shape-DTW, edit-based: ERP/EDR/LCSS/TWE/MSM, lock-step: Euclidean/Manhattan), extensive transformations (ROCKET/MiniRocket/MultiRocket for features, Catch22/TSFresh for statistics, SAX/PAA for symbolic representation, shapelet transforms, wavelets, matrix profile), 20+ deep learning architectures (FCN, ResNet, InceptionTime, TCN, autoencoders with attention mechanisms), comprehensive benchmarking tools (UCR/UEA archives with 100+ datasets, published results repository, statistical testing), and performance-optimized implementations using numba. Features progressive model complexity from fast baselines (MiniRocket: <1 second training, 0.95+ accuracy on many benchmarks) to state-of-the-art ensembles (HIVECOTE V2), GPU acceleration support, and extensive visualization utilities. Use cases: physiological signal classification (ECG, EEG), industrial sensor monitoring, financial forecasting, change point detection, pattern discovery, activity recognition from wearables, predictive maintenance, climate time series analysis, and any sequential data requiring specialized temporal modeling beyond standard ML
|
||||
- **PufferLib** - High-performance reinforcement learning library achieving 1M-4M steps/second through optimized vectorization, native multi-agent support, and efficient PPO training (PuffeRL). Use this skill for RL training on any environment (Gymnasium, PettingZoo, Atari, Procgen), creating custom PufferEnv environments, developing policies (CNN, LSTM, multi-input architectures), optimizing parallel simulation performance, or scaling multi-agent systems. Includes Ocean suite (20+ environments), seamless framework integration with automatic space flattening, zero-copy vectorization with shared memory buffers, distributed training support, and comprehensive reference guides for training workflows, environment development, vectorization optimization, policy architectures, and third-party integrations
|
||||
- **PyMC** - Comprehensive Python library for Bayesian statistical modeling and probabilistic programming. Provides intuitive syntax for building probabilistic models, advanced MCMC sampling algorithms (NUTS, Metropolis-Hastings, Slice sampling), variational inference methods (ADVI, SVGD), Gaussian processes, time series models (ARIMA, state space models), and model comparison tools (WAIC, LOO). Features include: automatic differentiation via Aesara (formerly Theano), GPU acceleration support, parallel sampling, model diagnostics and convergence checking, and integration with ArviZ for visualization and analysis. Supports hierarchical models, mixture models, survival analysis, and custom distributions. Use cases: Bayesian data analysis, uncertainty quantification, A/B testing, time series forecasting, hierarchical modeling, and probabilistic machine learning
|
||||
- **PyMOO** - Python framework for multi-objective optimization using evolutionary algorithms. Provides implementations of state-of-the-art algorithms including NSGA-II, NSGA-III, MOEA/D, SPEA2, and reference-point based methods. Features include: support for constrained and unconstrained optimization, multiple problem types (continuous, discrete, mixed-variable), performance indicators (hypervolume, IGD, GD), visualization tools (Pareto front plots, convergence plots), and parallel evaluation support. Supports custom problem definitions, algorithm configuration, and result analysis. Designed for engineering design, parameter optimization, and any problem requiring optimization of multiple conflicting objectives simultaneously. Use cases: multi-objective optimization problems, Pareto-optimal solution finding, engineering design optimization, and research in evolutionary computation
|
||||
- **PyTorch Lightning** - Deep learning framework that organizes PyTorch code to eliminate boilerplate while maintaining full flexibility. Automates training workflows (40+ tasks including epoch/batch iteration, optimizer steps, gradient management, checkpointing), supports multi-GPU/TPU training with DDP/FSDP/DeepSpeed strategies, includes LightningModule for model organization, Trainer for automation, LightningDataModule for data pipelines, callbacks for extensibility, and integrations with TensorBoard, Wandb, MLflow for experiment tracking
|
||||
- **PennyLane** - Cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Enables building and training quantum circuits with automatic differentiation, seamless integration with PyTorch/JAX/NumPy, and device-independent execution across simulators and quantum hardware (IBM, Amazon Braket, Google, Rigetti, IonQ). Key features include: quantum circuit construction with QNodes (quantum functions with automatic differentiation), 100+ quantum gates and operations (Pauli, Hadamard, rotation, controlled gates), circuit templates and layers for common ansatze (StronglyEntanglingLayers, BasicEntanglerLayers, UCCSD for chemistry), gradient computation methods (parameter-shift rule for hardware, backpropagation for simulators, adjoint differentiation), quantum chemistry module (molecular Hamiltonian construction, VQE for ground state energy, differentiable Hartree-Fock solver), ML framework integration (TorchLayer for PyTorch models, JAX transformations, TensorFlow deprecated), built-in optimizers (Adam, GradientDescent, QNG, Rotosolve), measurement types (expectation values, probabilities, samples, state vectors), device ecosystem (default.qubit simulator, lightning.qubit for performance, hardware plugins for IBM/Braket/Cirq/Rigetti/IonQ), and Catalyst for just-in-time compilation with adaptive circuits. Supports variational quantum algorithms (VQE, QAOA), quantum neural networks, hybrid quantum-classical models, data encoding strategies (angle, amplitude, IQP embeddings), and pulse-level programming. Use cases: variational quantum eigensolver for molecular simulations, quantum circuit machine learning with gradient-based optimization, hybrid quantum-classical neural networks, quantum chemistry calculations with differentiable workflows, quantum algorithm prototyping with hardware-agnostic code, quantum machine learning research with automatic differentiation, and deploying quantum circuits across multiple quantum computing platforms
|
||||
- **Qiskit** - World's most popular open-source quantum computing framework for building, optimizing, and executing quantum circuits with 13M+ downloads and 74% developer preference. Provides comprehensive tools for quantum algorithm development including circuit construction with 100+ quantum gates (Pauli, Hadamard, CNOT, rotation gates, controlled gates), circuit transpilation with 83x faster optimization than competitors producing circuits with 29% fewer two-qubit gates, primitives for execution (Sampler for bitstring measurements and probability distributions, Estimator for expectation values and observables), visualization tools (circuit diagrams in matplotlib/LaTeX, result histograms, Bloch sphere, state visualizations), backend-agnostic execution (local simulators including StatevectorSampler and Aer, IBM Quantum hardware with 100+ qubit systems, IonQ trapped ion, Amazon Braket multi-provider), session and batch modes for iterative and parallel workloads, error mitigation with configurable resilience levels (readout error correction, ZNE, PEC reducing sampling overhead by 100x), four-step patterns workflow (Map classical problems to quantum circuits, Optimize through transpilation, Execute with primitives, Post-process results), algorithm libraries including Qiskit Nature for quantum chemistry (molecular Hamiltonians, VQE for ground states, UCCSD ansatz, multiple fermion-to-qubit mappings), Qiskit Optimization for combinatorial problems (QAOA, portfolio optimization, MaxCut), and Qiskit Machine Learning (quantum kernels, VQC, QNN), support for Python/C/Rust with modular architecture, parameterized circuits for variational algorithms, quantum Fourier transform, Grover search, Shor's algorithm, pulse-level control, IBM Quantum Runtime for cloud execution with job management and queuing, and comprehensive documentation with textbook and tutorials. Use cases: variational quantum eigensolver for molecular ground state energy, QAOA for combinatorial optimization problems, quantum chemistry simulations with multiple ansatze and mappings, quantum machine learning with kernel methods and neural networks, hybrid quantum-classical algorithms, quantum algorithm research and prototyping across multiple hardware platforms, quantum circuit optimization and benchmarking, quantum error mitigation and characterization, quantum information science experiments, and production quantum computing workflows on real quantum hardware
|
||||
- **QuTiP** - Quantum Toolbox in Python for simulating and analyzing quantum mechanical systems. Provides comprehensive tools for both closed (unitary) and open (dissipative) quantum systems including quantum states (kets, bras, density matrices, Fock states, coherent states), quantum operators (creation/annihilation operators, Pauli matrices, angular momentum operators, quantum gates), time evolution solvers (Schrödinger equation with sesolve, Lindblad master equation with mesolve, quantum trajectories with Monte Carlo mcsolve, Bloch-Redfield brmesolve, Floquet methods for periodic Hamiltonians), analysis tools (expectation values, entropy measures, fidelity, concurrence, correlation functions, steady state calculations), visualization (Bloch sphere with animations, Wigner functions, Q-functions, Fock distributions, matrix histograms), and advanced methods (Hierarchical Equations of Motion for non-Markovian dynamics, permutational invariance for identical particles, stochastic solvers, superoperators). Supports tensor products for composite systems, partial traces, time-dependent Hamiltonians, multiple dissipation channels, and parallel processing. Includes extensive documentation, tutorials, and examples. Use cases: quantum optics simulations (cavity QED, photon statistics), quantum computing (gate operations, circuit dynamics), open quantum systems (decoherence, dissipation), quantum information theory (entanglement dynamics, quantum channels), condensed matter physics (spin chains, many-body systems), and general quantum mechanics research and education
|
||||
- **scikit-learn** - Industry-standard Python library for classical machine learning providing comprehensive supervised learning (classification: Logistic Regression, SVM, Decision Trees, Random Forests with 17+ variants, Gradient Boosting with XGBoost-compatible HistGradientBoosting, Naive Bayes, KNN, Neural Networks/MLP; regression: Linear, Ridge, Lasso, ElasticNet, SVR, ensemble methods), unsupervised learning (clustering: K-Means, DBSCAN, HDBSCAN, OPTICS, Agglomerative/Hierarchical, Spectral, Gaussian Mixture Models, BIRCH, MeanShift; dimensionality reduction: PCA, Kernel PCA, t-SNE, Isomap, LLE, NMF, TruncatedSVD, FastICA, LDA; outlier detection: IsolationForest, LocalOutlierFactor, OneClassSVM), data preprocessing (scaling: StandardScaler, MinMaxScaler, RobustScaler; encoding: OneHotEncoder, OrdinalEncoder, LabelEncoder; imputation: SimpleImputer, KNNImputer, IterativeImputer; feature engineering: PolynomialFeatures, KBinsDiscretizer, text vectorization with CountVectorizer/TfidfVectorizer), model evaluation (cross-validation: KFold, StratifiedKFold, TimeSeriesSplit, GroupKFold; hyperparameter tuning: GridSearchCV, RandomizedSearchCV, HalvingGridSearchCV; metrics: 30+ evaluation metrics for classification/regression/clustering including accuracy, precision, recall, F1, ROC-AUC, MSE, R², silhouette score), and Pipeline/ColumnTransformer for production-ready workflows. Features consistent API (fit/predict/transform), extensive documentation, integration with NumPy/pandas/SciPy, joblib persistence, and scikit-learn-compatible ecosystem (XGBoost, LightGBM, CatBoost, imbalanced-learn). Optimized implementations using Cython/OpenMP for performance. Use cases: predictive modeling, customer segmentation, anomaly detection, feature engineering, model selection/validation, text classification, image classification (with feature extraction), time series forecasting (with preprocessing), medical diagnosis, fraud detection, recommendation systems, and any tabular data ML task requiring interpretable models or established algorithms
|
||||
- **scikit-survival** - Survival analysis and time-to-event modeling with censored data. Built on scikit-learn, provides Cox proportional hazards models (CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis with elastic net regularization), ensemble methods (Random Survival Forests, Gradient Boosting), Survival Support Vector Machines (linear and kernel), non-parametric estimators (Kaplan-Meier, Nelson-Aalen), competing risks analysis, and specialized evaluation metrics (concordance index, time-dependent AUC, Brier score). Handles right-censored data, integrates with scikit-learn pipelines, and supports feature selection and hyperparameter tuning via cross-validation
|
||||
- **SHAP** - Model interpretability and explainability using Shapley values from game theory. Provides unified approach to explain any ML model with TreeExplainer (fast exact explanations for XGBoost/LightGBM/Random Forest), DeepExplainer (TensorFlow/PyTorch neural networks), KernelExplainer (model-agnostic), and LinearExplainer. Includes comprehensive visualizations (waterfall plots for individual predictions, beeswarm plots for global importance, scatter plots for feature relationships, bar/force/heatmap plots), supports model debugging, fairness analysis, feature engineering guidance, and production deployment
|
||||
- **Stable Baselines3** - PyTorch-based reinforcement learning library providing reliable implementations of RL algorithms (PPO, SAC, DQN, TD3, DDPG, A2C, HER, RecurrentPPO). Use this skill for training RL agents on standard or custom Gymnasium environments, implementing callbacks for monitoring and control, using vectorized environments for parallel training, creating custom environments with proper Gymnasium API implementation, and integrating with deep RL workflows. Includes comprehensive training templates, evaluation utilities, algorithm selection guidance (on-policy vs off-policy, continuous vs discrete actions), support for multi-input policies (dict observations), goal-conditioned learning with HER, and integration with TensorBoard for experiment tracking
|
||||
- **statsmodels** - Statistical modeling and econometrics (OLS, GLM, logit/probit, ARIMA, time series forecasting, hypothesis testing, diagnostics)
|
||||
- **Torch Geometric** - Graph Neural Networks for molecular and geometric data
|
||||
- **Transformers** - State-of-the-art machine learning models for NLP, computer vision, audio, and multimodal tasks. Provides 1M+ pre-trained models accessible via pipelines (text-classification, NER, QA, summarization, translation, text-generation, image-classification, object-detection, ASR, VQA), comprehensive training via Trainer API with distributed training and mixed precision, flexible text generation with multiple decoding strategies (greedy, beam search, sampling), and Auto classes for automatic architecture selection (BERT, GPT, T5, ViT, BART, etc.)
|
||||
- **UMAP-learn** - Python implementation of Uniform Manifold Approximation and Projection (UMAP) for dimensionality reduction and manifold learning. Provides fast, scalable nonlinear dimensionality reduction that preserves both local and global structure of high-dimensional data. Key features include: support for both supervised and unsupervised dimensionality reduction, ability to handle mixed data types, integration with scikit-learn API, and efficient implementation using numba for performance. Produces low-dimensional embeddings (typically 2D or 3D) suitable for visualization and downstream analysis. Often outperforms t-SNE in preserving global structure while maintaining local neighborhoods. Use cases: data visualization, feature extraction, preprocessing for machine learning, single-cell data analysis, and exploratory data analysis of high-dimensional datasets
|
||||
|
||||
### Materials Science & Chemistry
|
||||
- **Astropy** - Comprehensive Python library for astronomy and astrophysics providing core functionality for astronomical research and data analysis. Includes coordinate system transformations (ICRS, Galactic, FK5, AltAz), physical units and quantities with automatic dimensional consistency, FITS file operations (reading, writing, manipulating headers and data), cosmological calculations (luminosity distance, lookback time, Hubble parameter, Planck/WMAP models), precise time handling across multiple time scales (UTC, TAI, TT, TDB) and formats (JD, MJD, ISO), table operations with unit support (FITS, CSV, HDF5, VOTable), WCS transformations between pixel and world coordinates, astronomical constants, modeling framework, visualization tools, and statistical functions. Use for celestial coordinate transformations, unit conversions, FITS image/table processing, cosmological distance calculations, barycentric time corrections, catalog cross-matching, and astronomical data analysis
|
||||
- **COBRApy** - Python package for constraint-based reconstruction and analysis (COBRA) of metabolic networks. Provides tools for building, manipulating, and analyzing genome-scale metabolic models (GEMs). Key features include: flux balance analysis (FBA) for predicting optimal metabolic fluxes, flux variability analysis (FVA), gene knockout simulations, pathway analysis, model validation, and integration with other COBRA Toolbox formats (SBML, JSON). Supports various optimization objectives (biomass production, ATP production, metabolite production), constraint handling (reaction bounds, gene-protein-reaction associations), and model comparison. Includes utilities for model construction, gap filling, and model refinement. Use cases: metabolic engineering, systems biology, biotechnology applications, understanding cellular metabolism, and predicting metabolic phenotypes
|
||||
- **Pymatgen** - Python Materials Genomics (pymatgen) library for materials science computation and analysis. Provides comprehensive tools for crystal structure manipulation, phase diagram construction, electronic structure analysis, and materials property calculations. Key features include: structure objects with symmetry analysis, space group determination, structure matching and comparison, phase diagram generation from formation energies, band structure and density of states analysis, defect calculations, surface and interface analysis, and integration with DFT codes (VASP, Quantum ESPRESSO, ABINIT). Supports Materials Project database integration, structure file I/O (CIF, POSCAR, VASP), and high-throughput materials screening workflows. Use cases: materials discovery, crystal structure analysis, phase stability prediction, electronic structure calculations, and computational materials science research
|
||||
|
||||
### Engineering & Simulation
|
||||
- **FluidSim** - Object-oriented Python framework for high-performance computational fluid dynamics (CFD) simulations using pseudospectral methods with FFT. Provides solvers for periodic-domain equations including 2D/3D incompressible Navier-Stokes equations (with/without stratification), shallow water equations, and Föppl-von Kármán elastic plate equations. Key features include: Pythran/Transonic compilation for performance comparable to Fortran/C++, MPI parallelization for large-scale simulations, hierarchical parameter configuration with type safety, comprehensive output management (physical fields in HDF5, spatial means, energy/enstrophy spectra, spectral energy budgets), custom forcing mechanisms (time-correlated random forcing, proportional forcing, script-defined forcing), flexible initial conditions (noise, vortex, dipole, Taylor-Green, from file, in-script), online and offline visualization, and integration with ParaView/VisIt for 3D visualization. Supports workflow features including simulation restart/continuation, parametric studies with batch execution, cluster submission integration, and adaptive CFL-based time stepping. Use cases: 2D/3D turbulence studies with energy cascade analysis, stratified oceanic and atmospheric flows with buoyancy effects, geophysical flows with rotation (Coriolis effects), vortex dynamics and fundamental fluid mechanics research, high-resolution direct numerical simulation (DNS), parametric studies exploring parameter spaces, validation studies (Taylor-Green vortex), and any periodic-domain fluid dynamics research requiring HPC-grade performance with Python flexibility
|
||||
|
||||
### Data Analysis & Visualization
|
||||
- **Dask** - Parallel computing for larger-than-memory datasets with distributed DataFrames, Arrays, Bags, and Futures
|
||||
- **Data Commons** - Programmatic access to public statistical data from global sources including census bureaus, health organizations, and environmental agencies. Provides unified Python API for querying demographic data, economic indicators, health statistics, and environmental datasets through a knowledge graph interface. Features three main endpoints: Observation (statistical time-series queries for population, GDP, unemployment rates, disease prevalence), Node (knowledge graph exploration for entity relationships and hierarchies), and Resolve (entity identification from names, coordinates, or Wikidata IDs). Seamless Pandas integration for DataFrames, relation expressions for hierarchical queries, data source filtering for consistency, and support for custom Data Commons instances
|
||||
- **GeoPandas** - Python library extending pandas for working with geospatial vector data including shapefiles, GeoJSON, and GeoPackage files. Provides GeoDataFrame and GeoSeries data structures combining geometric data with tabular attributes for spatial analysis. Key features include: reading/writing spatial file formats (Shapefile, GeoJSON, GeoPackage, PostGIS, Parquet) with Arrow acceleration for 2-4x faster I/O, geometric operations (buffer, simplify, centroid, convex hull, affine transformations) through Shapely integration, spatial analysis (spatial joins with predicates like intersects/contains/within, nearest neighbor joins, overlay operations for union/intersection/difference, dissolve for aggregation, clipping), coordinate reference system (CRS) management (setting CRS, reprojecting between coordinate systems, UTM estimation), and visualization (static choropleth maps with matplotlib, interactive maps with folium, multi-layer mapping, classification schemes with mapclassify). Supports spatial indexing for performance, filtering during read operations (bbox, mask, SQL WHERE), and integration with cartopy for cartographic projections. Use cases: spatial data manipulation, buffer analysis, spatial joins between datasets, dissolving boundaries, calculating areas/distances in projected CRS, reprojecting coordinate systems, creating choropleth maps, converting between spatial file formats, PostGIS database integration, and geospatial data analysis workflows
|
||||
- **Matplotlib** - Comprehensive Python plotting library for creating publication-quality static, animated, and interactive visualizations. Provides extensive customization options for creating figures, subplots, axes, and annotations. Key features include: support for multiple plot types (line, scatter, bar, histogram, contour, 3D, and many more), extensive customization (colors, fonts, styles, layouts), multiple backends (PNG, PDF, SVG, interactive backends), LaTeX integration for mathematical notation, and integration with NumPy and pandas. Includes specialized modules (pyplot for MATLAB-like interface, artist layer for fine-grained control, backend layer for rendering). Supports complex multi-panel figures, color maps, legends, and annotations. Use cases: scientific figure creation, data visualization, exploratory data analysis, publication graphics, and any application requiring high-quality plots
|
||||
- **NetworkX** - Comprehensive toolkit for creating, analyzing, and visualizing complex networks and graphs. Supports four graph types (Graph, DiGraph, MultiGraph, MultiDiGraph) with nodes as any hashable objects and rich edge attributes. Provides 100+ algorithms including shortest paths (Dijkstra, Bellman-Ford, A*), centrality measures (degree, betweenness, closeness, eigenvector, PageRank), clustering (coefficients, triangles, transitivity), community detection (modularity-based, label propagation, Girvan-Newman), connectivity analysis (components, cuts, flows), tree algorithms (MST, spanning trees), matching, graph coloring, isomorphism, and traversal (DFS, BFS). Includes 50+ graph generators for classic (complete, cycle, wheel), random (Erdős-Rényi, Barabási-Albert, Watts-Strogatz, stochastic block model), lattice (grid, hexagonal, hypercube), and specialized networks. Supports I/O across formats (edge lists, GraphML, GML, JSON, Pajek, GEXF, DOT) with Pandas/NumPy/SciPy integration. Visualization capabilities include 8+ layout algorithms (spring/force-directed, circular, spectral, Kamada-Kawai), customizable node/edge appearance, interactive visualizations with Plotly/PyVis, and publication-quality figure generation. Use cases: social network analysis, biological networks (protein-protein interactions, gene regulatory networks, metabolic pathways), transportation systems, citation networks, knowledge graphs, web structure analysis, infrastructure networks, and any domain involving pairwise relationships requiring structural analysis or graph-based modeling
|
||||
- **Polars** - High-performance DataFrame library written in Rust with Python bindings, designed for fast data manipulation and analysis. Provides lazy evaluation for query optimization, efficient memory usage, and parallel processing. Key features include: DataFrame operations (filtering, grouping, joining, aggregations), support for large datasets (larger than RAM), integration with pandas and NumPy, expression API for complex transformations, and support for multiple data formats (CSV, Parquet, JSON, Excel, Arrow). Features query optimization through lazy evaluation, automatic parallelization, and efficient memory management. Often 5-30x faster than pandas for many operations. Use cases: large-scale data processing, ETL pipelines, data analysis workflows, and high-performance data manipulation tasks
|
||||
- **Plotly** - Interactive scientific and statistical data visualization library for Python with 40+ chart types. Provides both high-level API (Plotly Express) for quick visualizations and low-level API (graph objects) for fine-grained control. Key features include: comprehensive chart types (scatter, line, bar, histogram, box, violin, heatmap, contour, 3D plots, geographic maps, financial charts, statistical distributions, hierarchical charts), interactive features (hover tooltips, pan/zoom, legend toggling, animations, rangesliders, buttons/dropdowns), publication-quality output (static images in PNG/PDF/SVG via Kaleido, interactive HTML with embeddable figures), extensive customization (templates, themes, color scales, fonts, layouts, annotations, shapes), subplot support (multi-plot figures with shared axes), and Dash integration for building analytical web applications. Plotly Express offers one-line creation of complex visualizations with automatic color encoding, faceting, and trendlines. Graph objects provide precise control for specialized visualizations (candlestick charts, 3D surfaces, sankey diagrams, gauge charts). Supports pandas DataFrames, NumPy arrays, and various data formats. Use cases: scientific data visualization, statistical analysis, financial charting, interactive dashboards, publication figures, exploratory data analysis, and any application requiring interactive or publication-quality visualizations
|
||||
- **Seaborn** - Statistical data visualization with dataset-oriented interface, automatic confidence intervals, publication-quality themes, colorblind-safe palettes, and comprehensive support for exploratory analysis, distribution comparisons, correlation matrices, regression plots, and multi-panel figures
|
||||
- **SimPy** - Process-based discrete-event simulation framework for modeling systems with processes, queues, and resource contention (manufacturing, service operations, network traffic, logistics). Supports generator-based process definition, multiple resource types (Resource, PriorityResource, PreemptiveResource, Container, Store), event-driven scheduling, process interaction mechanisms (signaling, interruption, parallel/sequential execution), real-time simulation synchronized with wall-clock time, and comprehensive monitoring capabilities for utilization, wait times, and queue statistics
|
||||
- **SymPy** - Symbolic mathematics in Python for exact computation using mathematical symbols rather than numerical approximations. Provides comprehensive support for symbolic algebra (simplification, expansion, factorization), calculus (derivatives, integrals, limits, series), equation solving (algebraic, differential, systems of equations), matrices and linear algebra (eigenvalues, decompositions, solving linear systems), physics (classical mechanics with Lagrangian/Hamiltonian formulations, quantum mechanics, vector analysis, units), number theory (primes, factorization, modular arithmetic, Diophantine equations), geometry (2D/3D analytic geometry), combinatorics (permutations, combinations, partitions, group theory), logic and sets, statistics (probability distributions, random variables), special functions (gamma, Bessel, orthogonal polynomials), and code generation (lambdify to NumPy/SciPy functions, C/Fortran code generation, LaTeX output for documentation). Emphasizes exact arithmetic using rational numbers and symbolic representations, supports assumptions for improved simplification (positive, real, integer), integrates seamlessly with NumPy/SciPy through lambdify for fast numerical evaluation, and enables symbolic-to-numeric pipelines for scientific computing workflows
|
||||
- **Vaex** - High-performance Python library for lazy, out-of-core DataFrames to process and visualize tabular datasets larger than available RAM. Processes over a billion rows per second through memory-mapped files (HDF5, Apache Arrow), lazy evaluation, and virtual columns (zero memory overhead). Provides instant file opening, efficient aggregations across billions of rows, interactive visualizations without sampling, machine learning pipelines with transformers (scalers, encoders, PCA), and seamless integration with pandas/NumPy/Arrow. Includes comprehensive ML framework (vaex.ml) with feature scaling, categorical encoding, dimensionality reduction, and integration with scikit-learn/XGBoost/LightGBM/CatBoost. Supports distributed computing via Dask, asynchronous operations, and state management for production deployment. Use cases: processing gigabyte to terabyte datasets, fast statistical aggregations on massive data, visualizing billion-row datasets, ML pipelines on big data, converting between data formats, and working with astronomical, financial, or scientific large-scale datasets
|
||||
- **ReportLab** - Python library for programmatic PDF generation and document creation. Provides comprehensive tools for creating PDFs from scratch including text formatting, tables, graphics, images, charts, and complex layouts. Key features include: high-level Platypus framework for document layout, low-level canvas API for precise control, support for fonts (TrueType, Type 1), vector graphics, image embedding, page templates, headers/footers, and multi-page documents. Supports barcodes, forms, encryption, and digital signatures. Can generate reports, invoices, certificates, and complex documents programmatically. Use cases: automated report generation, document creation, invoice generation, certificate printing, and any application requiring programmatic PDF creation
|
||||
|
||||
### Phylogenetics & Trees
|
||||
- **ETE Toolkit** - Python library for phylogenetic tree manipulation, visualization, and analysis. Provides comprehensive tools for working with phylogenetic trees including tree construction, manipulation (pruning, collapsing, rooting), tree comparison (Robinson-Foulds distance, tree reconciliation), annotation (node colors, labels, branch styles), and publication-quality visualization. Key features include: support for multiple tree formats (Newick, Nexus, PhyloXML), integration with phylogenetic software (PhyML, RAxML, FastTree), tree annotation with metadata, interactive tree visualization, and export to various image formats (PNG, PDF, SVG). Supports species trees, gene trees, and reconciliation analysis. Use cases: phylogenetic analysis, tree visualization, evolutionary biology research, comparative genomics, and teaching phylogenetics
|
||||
|
||||
### Genomics Tools
|
||||
- **deepTools** - Comprehensive suite of Python tools for exploring and visualizing next-generation sequencing (NGS) data, particularly ChIP-seq, RNA-seq, and ATAC-seq experiments. Provides command-line tools and Python API for processing BAM and bigWig files. Key features include: quality control metrics (plotFingerprint, plotCorrelation), coverage track generation (bamCoverage for creating bigWig files), matrix generation for heatmaps (computeMatrix, plotHeatmap, plotProfile), comparative analysis (multiBigwigSummary, plotPCA), and efficient handling of large files. Supports normalization methods, binning options, and various visualization outputs. Designed for high-throughput analysis workflows and publication-quality figure generation. Use cases: ChIP-seq peak visualization, RNA-seq coverage analysis, ATAC-seq signal tracks, comparative genomics, and NGS data exploration
|
||||
- **FlowIO** - Python library for reading and manipulating Flow Cytometry Standard (FCS) files, the standard format for flow cytometry data. Provides efficient parsing of FCS files (versions 2.0, 3.0, 3.1), access to event data (fluorescence intensities, scatter parameters), metadata extraction (keywords, parameters, acquisition settings), and conversion to pandas DataFrames or NumPy arrays. Features include: support for large FCS files, handling of multiple data segments, access to text segments and analysis segments, and integration with flow cytometry analysis workflows. Enables programmatic access to flow cytometry data for downstream analysis, visualization, and machine learning applications. Use cases: flow cytometry data analysis, high-throughput screening, immune cell profiling, and automated processing of FCS files
|
||||
- **scikit-bio** - Python library for bioinformatics providing data structures, algorithms, and parsers for biological sequence analysis. Built on NumPy, SciPy, and pandas. Key features include: sequence objects (DNA, RNA, protein sequences) with biological alphabet validation, sequence alignment algorithms (local, global, semiglobal), phylogenetic tree manipulation, diversity metrics (alpha diversity, beta diversity, phylogenetic diversity), distance metrics for sequences and communities, file format parsers (FASTA, FASTQ, QIIME formats, Newick), and statistical analysis tools. Provides scikit-learn compatible transformers for machine learning workflows. Supports efficient processing of large sequence datasets. Use cases: sequence analysis, microbial ecology (16S rRNA analysis), metagenomics, phylogenetic analysis, and bioinformatics research requiring sequence manipulation and diversity calculations
|
||||
- **Zarr** - Python library implementing the Zarr chunked, compressed N-dimensional array storage format. Provides efficient storage and access to large multi-dimensional arrays with chunking and compression. Key features include: support for NumPy-like arrays with chunked storage, multiple compression codecs (zlib, blosc, lz4, zstd), support for various data types, efficient partial array reading (only load needed chunks), support for both local filesystem and cloud storage (S3, GCS, Azure), and integration with NumPy, Dask, and Xarray. Enables working with arrays larger than available RAM through lazy loading and efficient chunk access. Supports parallel read/write operations and is optimized for cloud storage backends. Use cases: large-scale scientific data storage, cloud-based array storage, out-of-core array operations, and efficient storage of multi-dimensional datasets (genomics, imaging, climate data)
|
||||
|
||||
### Multi-omics & AI Agent Frameworks
|
||||
- **BIOMNI** - Autonomous biomedical AI agent framework from Stanford SNAP lab for executing complex research tasks across genomics, drug discovery, molecular biology, and clinical analysis. Combines LLM reasoning with code execution and ~11GB of integrated biomedical databases (Ensembl, NCBI Gene, UniProt, PDB, AlphaFold, ClinVar, OMIM, HPO, PubMed, KEGG, Reactome, GO). Supports multiple LLM providers (Claude, GPT-4, Gemini, Groq, Bedrock). Includes A1 agent class for autonomous task decomposition, BiomniEval1 benchmark framework, and MCP server integration. Use cases: CRISPR screening design, single-cell RNA-seq analysis, ADMET prediction, GWAS interpretation, rare disease diagnosis, protein structure analysis, literature synthesis, and multi-omics integration
|
||||
- **Denario** - Multiagent AI system for scientific research assistance that automates complete research workflows from data analysis through publication. Built on AG2 and LangGraph frameworks, orchestrates specialized agents for hypothesis generation, methodology development, computational analysis, and LaTeX paper writing. Supports multiple LLM providers (Google Vertex AI, OpenAI) with flexible pipeline stages allowing manual or automated inputs. Key features include: end-to-end research automation (data description → idea generation → methodology → results → paper), journal-specific formatting (APS and others), GUI interface via Streamlit, Docker deployment with LaTeX environment, reproducible research with version-controlled outputs, literature search integration, and integration with scientific Python stack (pandas, sklearn, scipy). Provides both programmatic Python API and web-based interface. Use cases: automated hypothesis generation from datasets, research methodology development, computational experiment execution with visualization, publication-ready manuscript generation, time-series analysis research, machine learning experiment automation, and accelerating the complete scientific research lifecycle from ideation to publication
|
||||
- **HypoGeniC** - Automated hypothesis generation and testing using large language models to accelerate scientific discovery. Provides three frameworks: HypoGeniC (data-driven hypothesis generation from observational data), HypoRefine (synergistic approach combining literature insights with empirical patterns through an agentic system), and Union methods (mechanistic combination of literature and data-driven hypotheses). Features iterative refinement that improves hypotheses by learning from challenging examples, Redis caching for API cost reduction, and customizable YAML-based prompt templates. Includes command-line tools for generation (hypogenic_generation) and testing (hypogenic_inference). Research applications have demonstrated 14.19% accuracy improvement in AI-content detection and 7.44% in deception detection. Use cases: deception detection in reviews, AI-generated content identification, mental stress detection, exploratory research without existing literature, hypothesis-driven analysis in novel domains, and systematic exploration of competing explanations
|
||||
|
||||
### Scientific Communication & Publishing
|
||||
- **Citation Management** - Comprehensive citation management for academic research. Search Google Scholar and PubMed for papers, extract accurate metadata from multiple sources (CrossRef, PubMed, arXiv), validate citations, and generate properly formatted BibTeX entries. Features include converting DOIs, PMIDs, or arXiv IDs to BibTeX, cleaning and formatting bibliography files, finding highly cited papers, checking for duplicates, and ensuring consistent citation formatting. Use cases: building bibliographies for manuscripts, verifying citation accuracy, citation deduplication, and maintaining reference databases
|
||||
- **Generate Image** - AI-powered image generation and editing for scientific illustrations, schematics, and visualizations using OpenRouter's image generation models. Supports multiple models including google/gemini-3-pro-image-preview (high quality, recommended default) and black-forest-labs/flux.2-pro (fast, high quality). Key features include: text-to-image generation from detailed prompts, image editing capabilities (modify existing images with natural language instructions), automatic base64 encoding/decoding, PNG output with configurable paths, and comprehensive error handling. Requires OpenRouter API key (via .env file or environment variable). Use cases: generating scientific diagrams and illustrations, creating publication-quality figures, editing existing images (changing colors, adding elements, removing backgrounds), producing schematics for papers and presentations, visualizing experimental setups, creating graphical abstracts, and generating conceptual illustrations for scientific communication
|
||||
- **LaTeX Posters** - Create professional research posters in LaTeX using beamerposter, tikzposter, or baposter. Support for conference presentations, academic posters, and scientific communication with layout design, color schemes, multi-column formats, figure integration, and poster-specific best practices. Features compliance with conference size requirements (A0, A1, 36×48"), complex multi-column layouts, and integration of figures, tables, equations, and citations. Use cases: conference poster sessions, thesis defenses, symposia presentations, and research group templates
|
||||
- **Market Research Reports** - Generate comprehensive market research reports (50+ pages) in the style of top consulting firms (McKinsey, BCG, Gartner). Features professional LaTeX formatting, extensive visual generation, deep integration with research-lookup for data gathering, and multi-framework strategic analysis including Porter's Five Forces, PESTLE, SWOT, TAM/SAM/SOM, and BCG Matrix. Use cases: investment decisions, strategic planning, competitive landscape analysis, market sizing, and market entry evaluation
|
||||
- **Paper-2-Web** - Autonomous pipeline for transforming academic papers into multiple promotional formats using the Paper2All system. Converts LaTeX or PDF papers into: (1) Paper2Web - interactive, layout-aware academic homepages with responsive design, interactive figures, and mobile support; (2) Paper2Video - professional presentation videos with slides, narration, cursor movements, and optional talking-head generation using Hallo2; (3) Paper2Poster - print-ready conference posters with custom dimensions, professional layouts, and institution branding. Supports GPT-4/GPT-4.1 models, batch processing, QR code generation, multi-language content, and quality assessment metrics. Use cases: conference materials, video abstracts, preprint enhancement, research promotion, poster sessions, and academic website creation
|
||||
- **Perplexity Search** - AI-powered web search using Perplexity models via LiteLLM and OpenRouter for real-time, web-grounded answers with source citations. Provides access to multiple Perplexity models: Sonar Pro (general-purpose, best cost-quality balance), Sonar Pro Search (most advanced agentic search with multi-step reasoning), Sonar (cost-effective for simple queries), Sonar Reasoning Pro (advanced step-by-step analysis), and Sonar Reasoning (basic reasoning). Key features include: single OpenRouter API key setup (no separate Perplexity account), real-time access to current information beyond training data cutoff, comprehensive query design guidance (domain-specific patterns, time constraints, source preferences), cost optimization strategies with usage monitoring, programmatic and CLI interfaces, batch processing support, and integration with other scientific skills. Installation uses uv pip for LiteLLM, with detailed setup, troubleshooting, and security documentation. Use cases: finding recent scientific publications and research, conducting literature searches across domains, verifying facts with source citations, accessing current developments in any field, comparing technologies and approaches, performing domain-specific research (biomedical, clinical, technical), supplementing PubMed searches with real-time web results, and discovering latest developments post-database indexing
|
||||
- **PPTX Posters** - Create professional research posters using PowerPoint/HTML formats for researchers who prefer WYSIWYG tools over LaTeX. Features design principles, layout templates, quality checklists, and export guidance for poster sessions. Use cases: conference posters when LaTeX is not preferred, quick poster creation, and collaborative poster design
|
||||
- **Scientific Schematics** - Create publication-quality scientific diagrams using Nano Banana Pro AI with smart iterative refinement. Uses Gemini 3 Pro for quality review with document-type-specific thresholds (journal: 8.5/10, conference: 8.0/10, poster: 7.0/10). Specializes in neural network architectures, system diagrams, flowcharts, biological pathways, and complex scientific visualizations. Features natural language input, automatic quality assessment, and publication-ready output. Use cases: creating figures for papers, generating workflow diagrams, visualizing experimental designs, and producing graphical abstracts
|
||||
- **Scientific Slides** - Build slide decks and presentations for research talks using PowerPoint and LaTeX Beamer. Features slide structure, design templates, timing guidance, and visual validation. Emphasizes visual engagement with minimal text, research-backed content with proper citations, and story-driven narrative. Use cases: conference presentations, academic seminars, thesis defenses, grant pitches, and professional talks
|
||||
- **Venue Templates** - Access comprehensive LaTeX templates, formatting requirements, and submission guidelines for major scientific publication venues (Nature, Science, PLOS, IEEE, ACM), academic conferences (NeurIPS, ICML, CVPR, CHI), research posters, and grant proposals (NSF, NIH, DOE, DARPA). Provides ready-to-use templates and detailed specifications for successful academic submissions. Use cases: manuscript preparation, conference papers, research posters, and grant proposals with venue-specific formatting
|
||||
|
||||
### Document Processing & Conversion
|
||||
- **MarkItDown** - Python utility for converting 20+ file formats to Markdown optimized for LLM processing. Converts Office documents (PDF, DOCX, PPTX, XLSX), images with OCR, audio with transcription, web content (HTML, YouTube transcripts, EPUB), and structured data (CSV, JSON, XML) while preserving document structure (headings, lists, tables, hyperlinks). Key features include: Azure Document Intelligence integration for enhanced PDF table extraction, LLM-powered image descriptions using GPT-4o, batch processing with ZIP archive support, modular installation for specific formats, streaming approach without temporary files, and plugin system for custom converters. Supports Python 3.10+. Use cases: preparing documents for RAG systems, extracting text from PDFs and Office files, transcribing audio to text, performing OCR on images and scanned documents, converting YouTube videos to searchable text, processing HTML and EPUB books, converting structured data to readable format, document analysis pipelines, and LLM training data preparation
|
||||
|
||||
### Laboratory Automation & Equipment Control
|
||||
- **PyLabRobot** - Hardware-agnostic, pure Python SDK for automated and autonomous laboratories. Provides unified interface for controlling liquid handling robots (Hamilton STAR/STARlet, Opentrons OT-2, Tecan EVO), plate readers (BMG CLARIOstar), heater shakers, incubators, centrifuges, pumps, and scales. Key features include: modular resource management system for plates, tips, and containers with hierarchical deck layouts and JSON serialization; comprehensive liquid handling operations (aspirate, dispense, transfer, serial dilutions, plate replication) with automatic tip and volume tracking; backend abstraction enabling hardware-agnostic protocols that work across different robots; ChatterboxBackend for protocol simulation and testing without hardware; browser-based visualizer for real-time 3D deck state visualization; cross-platform support (Windows, macOS, Linux, Raspberry Pi); and integration capabilities for multi-device workflows combining liquid handlers, analytical equipment, and material handling devices. Use cases: automated sample preparation, high-throughput screening, serial dilution protocols, plate reading workflows, laboratory protocol development and validation, robotic liquid handling automation, and reproducible laboratory automation with state tracking and persistence
|
||||
|
||||
### Tool Discovery & Research Platforms
|
||||
- **Get Available Resources** - Detect available computational resources and generate strategic recommendations for scientific computing tasks at the start of any computationally intensive scientific task. Automatically identifies CPU capabilities, GPU availability (NVIDIA CUDA, AMD ROCm, Apple Silicon Metal), memory constraints, and disk space. Creates JSON file with resource information and recommendations for parallel processing (joblib, multiprocessing), out-of-core computing (Dask, Zarr), GPU acceleration (PyTorch, JAX), or memory-efficient strategies. Use cases: determining optimal computational approaches before data analysis, model training, or large file operations
|
||||
- **ToolUniverse** - Unified ecosystem providing standardized access to 600+ scientific tools, models, datasets, and APIs across bioinformatics, cheminformatics, genomics, structural biology, and proteomics. Enables AI agents to function as research scientists through: (1) Tool Discovery - natural language, semantic, and keyword-based search for finding relevant scientific tools (Tool_Finder, Tool_Finder_LLM, Tool_Finder_Keyword); (2) Tool Execution - standardized AI-Tool Interaction Protocol for running tools with consistent interfaces; (3) Tool Composition - sequential and parallel workflow chaining for multi-step research pipelines; (4) Model Context Protocol (MCP) integration for Claude Desktop/Code. Supports drug discovery workflows (disease→targets→structures→screening→candidates), genomics analysis (expression→differential analysis→pathways), clinical genomics (variants→annotation→pathogenicity→disease associations), and cross-domain research. Use cases: accessing scientific databases (OpenTargets, PubChem, UniProt, PDB, ChEMBL, KEGG), protein structure prediction (AlphaFold), molecular docking, pathway enrichment, variant annotation, literature searches, and automated scientific workflows
|
||||
|
||||
### Research Methodology & Proposal Writing
|
||||
- **Research Grants** - Write competitive research proposals for NSF, NIH, DOE, and DARPA. Features agency-specific formatting, review criteria understanding, budget preparation, broader impacts statements, significance narratives, innovation sections, and compliance with submission requirements. Covers project descriptions, specific aims, technical narratives, milestone plans, budget justifications, and biosketches. Use cases: federal grant applications, resubmissions with reviewer response, multi-institutional collaborations, and preliminary data sections
|
||||
- **Research Lookup** - Look up current research information using Perplexity's Sonar Pro Search or Sonar Reasoning Pro models through OpenRouter. Intelligently selects models based on query complexity. Provides access to current academic literature, recent studies, technical documentation, and general research information with proper citations. Use cases: finding latest research, literature verification, gathering background research, finding citation sources, and staying current with emerging trends
|
||||
- **Scholar Evaluation** - Apply the ScholarEval framework to systematically evaluate scholarly and research work. Provides structured evaluation methodology based on peer-reviewed research assessment criteria for analyzing academic papers, research proposals, literature reviews, and scholarly writing across multiple quality dimensions. Use cases: evaluating research papers for quality and rigor, assessing methodology design, scoring data analysis approaches, benchmarking research quality, and assessing publication readiness
|
||||
|
||||
### Regulatory & Standards Compliance
|
||||
- **ISO 13485 Certification** - Comprehensive toolkit for preparing ISO 13485:2016 certification documentation for medical device Quality Management Systems. Provides gap analysis of existing documentation, templates for all mandatory documents, compliance checklists, and step-by-step documentation creation. Covers 31 required procedures including Quality Manuals, Medical Device Files, and work instructions. Use cases: starting ISO 13485 certification process, conducting gap analysis, creating or updating QMS documentation, preparing for certification audits, transitioning from FDA QSR to QMSR, and harmonizing with EU MDR requirements
|
||||
|
||||
## Scientific Thinking & Analysis
|
||||
|
||||
### Analysis & Methodology
|
||||
- **Exploratory Data Analysis** - Comprehensive EDA toolkit with automated statistics, visualizations, and insights for any tabular dataset
|
||||
- **Hypothesis Generation** - Structured frameworks for generating and evaluating scientific hypotheses
|
||||
- **Literature Review** - Systematic literature search and review toolkit with support for multiple scientific databases (PubMed, bioRxiv, Google Scholar), citation management with multiple citation styles (APA, AMA, Vancouver, Chicago, IEEE, Nature, Science), citation verification and deduplication, search strategies (Boolean operators, MeSH terms, field tags), PDF report generation with formatted references, and comprehensive templates for conducting systematic reviews following PRISMA guidelines
|
||||
- **Peer Review** - Comprehensive toolkit for conducting high-quality scientific peer review with structured evaluation of methodology, statistics, reproducibility, ethics, and presentation across all scientific disciplines
|
||||
- **Scientific Brainstorming** - Conversational brainstorming partner for generating novel research ideas, exploring connections, challenging assumptions, and developing creative approaches through structured ideation workflows
|
||||
- **Scientific Critical Thinking** - Tools and approaches for rigorous scientific reasoning and evaluation
|
||||
- **Scientific Visualization** - Best practices and templates for creating publication-quality scientific figures with matplotlib and seaborn, including statistical plots with automatic confidence intervals, colorblind-safe palettes, multi-panel figures, heatmaps, and journal-specific formatting
|
||||
- **Scientific Writing** - Comprehensive toolkit for writing, structuring, and formatting scientific research papers using IMRAD format, multiple citation styles (APA, AMA, Vancouver, Chicago, IEEE), reporting guidelines (CONSORT, STROBE, PRISMA), effective figures and tables, field-specific terminology, venue-specific structure expectations, and core writing principles for clarity, conciseness, and accuracy across all scientific disciplines
|
||||
- **Statistical Analysis** - Comprehensive statistical testing, power analysis, and experimental design
|
||||
|
||||
### Document Processing
|
||||
- **XLSX** - Spreadsheet creation, editing, and analysis with support for formulas, formatting, data analysis, and visualization
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Scientific Thinking & Analysis
|
||||
|
||||
## Analysis & Methodology
|
||||
- **Exploratory Data Analysis** - Comprehensive EDA toolkit with automated statistics, visualizations, and insights for any tabular dataset
|
||||
- **Hypothesis Generation** - Structured frameworks for generating and evaluating scientific hypotheses
|
||||
- **Peer Review** - Comprehensive toolkit for conducting high-quality scientific peer review with structured evaluation of methodology, statistics, reproducibility, ethics, and presentation across all scientific disciplines
|
||||
- **Scientific Brainstorming** - Conversational brainstorming partner for generating novel research ideas, exploring connections, challenging assumptions, and developing creative approaches through structured ideation workflows
|
||||
- **Scientific Critical Thinking** - Tools and approaches for rigorous scientific reasoning and evaluation
|
||||
- **Scientific Visualization** - Best practices and templates for creating publication-quality scientific figures with matplotlib and seaborn, including statistical plots with automatic confidence intervals, colorblind-safe palettes, multi-panel figures, heatmaps, and journal-specific formatting
|
||||
- **Scientific Writing** - Comprehensive toolkit for writing, structuring, and formatting scientific research papers using IMRAD format, multiple citation styles (APA, AMA, Vancouver, Chicago, IEEE), reporting guidelines (CONSORT, STROBE, PRISMA), effective figures and tables, field-specific terminology, venue-specific structure expectations, and core writing principles for clarity, conciseness, and accuracy across all scientific disciplines
|
||||
- **Statistical Analysis** - Comprehensive statistical testing, power analysis, and experimental design
|
||||
|
||||
## Document Processing
|
||||
- **DOCX** - Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction
|
||||
- **PDF** - PDF manipulation toolkit for extracting text and tables, creating new PDFs, merging/splitting documents, and handling forms
|
||||
- **PPTX** - Presentation creation, editing, and analysis with support for layouts, comments, and speaker notes
|
||||
- **XLSX** - Spreadsheet creation, editing, and analysis with support for formulas, formatting, data analysis, and visualization
|
||||
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
---
|
||||
name: context-initialization
|
||||
description: "Always Auto-invoked skill that creates/updates workspace AGENT.md to instruct the agent to always search for existing skills before attempting any scientific task."
|
||||
---
|
||||
|
||||
# Context Initialization
|
||||
|
||||
## Overview
|
||||
|
||||
This skill automatically creates or updates an `AGENT.md` file in the workspace root that instructs the agent to use existing skills before attempting to solve scientific tasks. This ensures the agent uses documented patterns, authentication methods, working examples, and best practices from the repository's skills rather than inventing solutions from scratch.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill is automatically invoked when:
|
||||
- The agent initializes in this workspace
|
||||
- User begins any scientific task (database access, package usage, platform integration, or methodology)
|
||||
- User mentions specific databases, packages, platforms, or research methods
|
||||
- Any scientific data retrieval, analysis, or research task is started
|
||||
|
||||
**No manual invocation required** - this skill runs automatically.
|
||||
|
||||
## What This Skill Does
|
||||
|
||||
Creates or updates `AGENT.md` in the workspace root with instructions for the agent to:
|
||||
|
||||
1. **Search first**: Look for relevant skills across all skill categories before writing code
|
||||
2. **Use existing patterns**: Apply documented API access patterns, workflows, and examples
|
||||
3. **Follow best practices**: Use rate limits, authentication, configurations, and established methodologies
|
||||
4. **Adapt examples**: Leverage working code examples from `scripts/` folders and reference documentation
|
||||
|
||||
**Important**: If `AGENT.md` already exists in the workspace, this skill will update it intelligently rather than overwriting it. This preserves any custom instructions or modifications while ensuring the essential skill-search directives are present.
|
||||
|
||||
## Skill Categories
|
||||
|
||||
This unified context initialization covers four major skill categories:
|
||||
|
||||
### Database Access Tasks
|
||||
- Search `scientific-databases/` for 24+ database skills
|
||||
- Use documented API endpoints and authentication patterns
|
||||
- Apply working code examples and best practices
|
||||
- Follow rate limits and error handling patterns
|
||||
|
||||
### Scientific Package Usage
|
||||
- Search `scientific-packages/` for 40+ Python package skills
|
||||
- Use installation instructions and API usage examples
|
||||
- Apply best practices and common patterns
|
||||
- Leverage working scripts and reference documentation
|
||||
|
||||
### Laboratory Platform Integration
|
||||
- Search `scientific-integrations/` for 6+ platform integration skills
|
||||
- Use authentication and setup instructions
|
||||
- Apply API access patterns and platform-specific best practices
|
||||
- Leverage working integration examples
|
||||
|
||||
### Scientific Analysis & Research Methods
|
||||
- Search `scientific-thinking/` for methodology skills
|
||||
- Use established data analysis frameworks (EDA, statistical analysis)
|
||||
- Apply research methodologies (hypothesis generation, brainstorming, critical thinking)
|
||||
- Leverage communication skills (scientific writing, visualization, peer review)
|
||||
- Use document processing skills (DOCX, PDF, PPTX, XLSX)
|
||||
|
||||
## Implementation
|
||||
|
||||
When invoked, this skill manages the workspace `AGENT.md` file as follows:
|
||||
|
||||
- **If `AGENT.md` does not exist**: Creates a new file using the complete template from `references/AGENT.md`
|
||||
- **If `AGENT.md` already exists**: Updates the file to ensure the essential skill-search directives are present, while preserving any existing custom content or modifications
|
||||
|
||||
The file includes sections instructing the agent to search for and use existing skills across all scientific task categories.
|
||||
|
||||
The complete reference template is available in `references/AGENT.md`.
|
||||
|
||||
## Benefits
|
||||
|
||||
By centralizing context initialization, this skill ensures:
|
||||
- **Consistency**: The agent always uses the same approach across all skill types
|
||||
- **Efficiency**: One initialization covers all scientific tasks
|
||||
- **Maintainability**: Updates to the initialization strategy occur in one place
|
||||
- **Completeness**: The agent is reminded to search across all available skill categories
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
# Reference: Complete Context Initialization Template
|
||||
|
||||
This is the complete reference template for what gets added to the workspace root `AGENT.md` file.
|
||||
|
||||
---
|
||||
|
||||
# Agent Scientific Skills - Working Instructions
|
||||
|
||||
## IMPORTANT: Use Available Skills First
|
||||
|
||||
Before attempting any scientific task, use available skills.
|
||||
|
||||
---
|
||||
|
||||
## Database Access Tasks
|
||||
|
||||
**Before writing any database access code, use available skills in this repository.**
|
||||
|
||||
This repository contains skills for 24+ scientific databases. Each skill includes:
|
||||
- API endpoints and authentication patterns
|
||||
- Working code examples
|
||||
- Best practices and rate limits
|
||||
- Example scripts
|
||||
|
||||
Always use available database skills before writing custom database access code.
|
||||
|
||||
---
|
||||
|
||||
## Scientific Package Usage
|
||||
|
||||
**Before writing analysis code with scientific packages, use available skills in this repository.**
|
||||
|
||||
This repository contains skills for 40+ scientific Python packages. Each skill includes:
|
||||
- Installation instructions
|
||||
- Complete API usage examples
|
||||
- Best practices and common patterns
|
||||
- Working scripts and reference documentation
|
||||
|
||||
Always use available package skills before writing custom analysis code.
|
||||
|
||||
---
|
||||
|
||||
## Laboratory Platform Integration
|
||||
|
||||
**Before writing any platform integration code, use available skills in this repository.**
|
||||
|
||||
This repository contains skills for 6+ laboratory platforms and cloud services. Each skill includes:
|
||||
- Authentication and setup instructions
|
||||
- API access patterns
|
||||
- Working integration examples
|
||||
- Platform-specific best practices
|
||||
|
||||
Always use available integration skills before writing custom platform code.
|
||||
|
||||
---
|
||||
|
||||
## Scientific Analysis & Research Methods
|
||||
|
||||
**Before attempting any analysis, writing, or research task, use available methodology skills in this repository.**
|
||||
|
||||
This repository contains skills for scientific methodologies including:
|
||||
- Data analysis frameworks (EDA, statistical analysis)
|
||||
- Research methodologies (hypothesis generation, brainstorming, critical thinking)
|
||||
- Communication skills (scientific writing, visualization, peer review)
|
||||
- Document processing (DOCX, PDF, PPTX, XLSX)
|
||||
|
||||
Always use available methodology skills before attempting scientific analysis or writing tasks.
|
||||
|
||||
---
|
||||
|
||||
*This file is auto-generated by context-initialization skills. It ensures the agent uses available skills before attempting to solve scientific tasks from scratch.*
|
||||
|
||||
@@ -1,527 +0,0 @@
|
||||
---
|
||||
name: anndata
|
||||
description: "Manipulate AnnData objects for single-cell genomics. Load/save .h5ad files, manage obs/var metadata, layers, embeddings (PCA/UMAP), concatenate datasets, for scRNA-seq workflows."
|
||||
---
|
||||
|
||||
# AnnData
|
||||
|
||||
## Overview
|
||||
|
||||
AnnData (Annotated Data) is Python's standard for storing and manipulating annotated data matrices, particularly in single-cell genomics. Work with AnnData objects for data creation, manipulation, file I/O, concatenation, and memory-efficient workflows.
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Creating and Structuring AnnData Objects
|
||||
|
||||
Create AnnData objects from various data sources and organize multi-dimensional annotations.
|
||||
|
||||
**Basic creation:**
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# From dense or sparse arrays
|
||||
counts = np.random.poisson(1, size=(100, 2000))
|
||||
adata = ad.AnnData(counts)
|
||||
|
||||
# With sparse matrix (memory-efficient)
|
||||
counts = csr_matrix(np.random.poisson(1, size=(100, 2000)), dtype=np.float32)
|
||||
adata = ad.AnnData(counts)
|
||||
```
|
||||
|
||||
**With metadata:**
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
obs_meta = pd.DataFrame({
|
||||
'cell_type': pd.Categorical(['B', 'T', 'Monocyte'] * 33 + ['B']),
|
||||
'batch': ['batch1'] * 50 + ['batch2'] * 50
|
||||
})
|
||||
var_meta = pd.DataFrame({
|
||||
'gene_name': [f'Gene_{i}' for i in range(2000)],
|
||||
'highly_variable': np.random.choice([True, False], 2000)
|
||||
})
|
||||
|
||||
adata = ad.AnnData(counts, obs=obs_meta, var=var_meta)
|
||||
```
|
||||
|
||||
**Understanding the structure:**
|
||||
- **X**: Primary data matrix (observations × variables)
|
||||
- **obs**: Row (observation) annotations as DataFrame
|
||||
- **var**: Column (variable) annotations as DataFrame
|
||||
- **obsm**: Multi-dimensional observation annotations (e.g., PCA, UMAP coordinates)
|
||||
- **varm**: Multi-dimensional variable annotations (e.g., gene loadings)
|
||||
- **layers**: Alternative data matrices with same dimensions as X
|
||||
- **uns**: Unstructured metadata dictionary
|
||||
- **obsp/varp**: Pairwise relationship matrices (graphs)
|
||||
|
||||
### 2. Adding Annotations and Layers
|
||||
|
||||
Organize different data representations and metadata within a single object.
|
||||
|
||||
**Cell-level metadata (obs):**
|
||||
```python
|
||||
adata.obs['n_genes'] = (adata.X > 0).sum(axis=1)
|
||||
adata.obs['total_counts'] = adata.X.sum(axis=1)
|
||||
adata.obs['condition'] = pd.Categorical(['control', 'treated'] * 50)
|
||||
```
|
||||
|
||||
**Gene-level metadata (var):**
|
||||
```python
|
||||
adata.var['highly_variable'] = gene_variance > threshold
|
||||
adata.var['chromosome'] = pd.Categorical(['chr1', 'chr2', ...])
|
||||
```
|
||||
|
||||
**Embeddings (obsm/varm):**
|
||||
```python
|
||||
# Dimensionality reduction results
|
||||
adata.obsm['X_pca'] = pca_coordinates # Shape: (n_obs, n_components)
|
||||
adata.obsm['X_umap'] = umap_coordinates # Shape: (n_obs, 2)
|
||||
adata.obsm['X_tsne'] = tsne_coordinates
|
||||
|
||||
# Gene loadings
|
||||
adata.varm['PCs'] = principal_components # Shape: (n_vars, n_components)
|
||||
```
|
||||
|
||||
**Alternative data representations (layers):**
|
||||
```python
|
||||
# Store multiple versions
|
||||
adata.layers['counts'] = raw_counts
|
||||
adata.layers['log1p'] = np.log1p(adata.X)
|
||||
adata.layers['scaled'] = (adata.X - mean) / std
|
||||
```
|
||||
|
||||
**Unstructured metadata (uns):**
|
||||
```python
|
||||
# Analysis parameters
|
||||
adata.uns['preprocessing'] = {
|
||||
'normalization': 'TPM',
|
||||
'min_genes': 200,
|
||||
'date': '2024-01-15'
|
||||
}
|
||||
|
||||
# Results
|
||||
adata.uns['pca'] = {'variance_ratio': variance_explained}
|
||||
```
|
||||
|
||||
### 3. Subsetting and Views
|
||||
|
||||
Efficiently subset data while managing memory through views and copies.
|
||||
|
||||
**Subsetting operations:**
|
||||
```python
|
||||
# By observation/variable names
|
||||
subset = adata[['Cell_1', 'Cell_10'], ['Gene_5', 'Gene_1900']]
|
||||
|
||||
# By boolean masks
|
||||
b_cells = adata[adata.obs.cell_type == 'B']
|
||||
high_quality = adata[adata.obs.n_genes > 200]
|
||||
|
||||
# By position
|
||||
first_cells = adata[:100, :]
|
||||
top_genes = adata[:, :500]
|
||||
|
||||
# Combined conditions
|
||||
filtered = adata[
|
||||
(adata.obs.batch == 'batch1') & (adata.obs.n_genes > 200),
|
||||
adata.var.highly_variable
|
||||
]
|
||||
```
|
||||
|
||||
**Understanding views:**
|
||||
- Subsetting returns **views** by default (memory-efficient, shares data with original)
|
||||
- Modifying a view affects the original object
|
||||
- Check with `adata.is_view`
|
||||
- Convert to independent copy with `.copy()`
|
||||
|
||||
```python
|
||||
# View (memory-efficient)
|
||||
subset = adata[adata.obs.condition == 'treated']
|
||||
print(subset.is_view) # True
|
||||
|
||||
# Independent copy
|
||||
subset_copy = adata[adata.obs.condition == 'treated'].copy()
|
||||
print(subset_copy.is_view) # False
|
||||
```
|
||||
|
||||
### 4. File I/O and Backed Mode
|
||||
|
||||
Read and write data efficiently, with options for memory-limited environments.
|
||||
|
||||
**Writing data:**
|
||||
```python
|
||||
# Standard format with compression
|
||||
adata.write('results.h5ad', compression='gzip')
|
||||
|
||||
# Alternative formats
|
||||
adata.write_zarr('results.zarr') # For cloud storage
|
||||
adata.write_loom('results.loom') # For compatibility
|
||||
adata.write_csvs('results/') # As CSV files
|
||||
```
|
||||
|
||||
**Reading data:**
|
||||
```python
|
||||
# Load into memory
|
||||
adata = ad.read_h5ad('results.h5ad')
|
||||
|
||||
# Backed mode (disk-backed, memory-efficient)
|
||||
adata = ad.read_h5ad('large_file.h5ad', backed='r')
|
||||
print(adata.isbacked) # True
|
||||
print(adata.filename) # Path to file
|
||||
|
||||
# Close file connection when done
|
||||
adata.file.close()
|
||||
```
|
||||
|
||||
**Reading from other formats:**
|
||||
```python
|
||||
# 10X format
|
||||
adata = ad.read_mtx('matrix.mtx')
|
||||
|
||||
# CSV
|
||||
adata = ad.read_csv('data.csv')
|
||||
|
||||
# Loom
|
||||
adata = ad.read_loom('data.loom')
|
||||
```
|
||||
|
||||
**Working with backed mode:**
|
||||
```python
|
||||
# Read in backed mode for large files
|
||||
adata = ad.read_h5ad('large_dataset.h5ad', backed='r')
|
||||
|
||||
# Process in chunks
|
||||
for chunk in adata.chunk_X(chunk_size=1000):
|
||||
result = process_chunk(chunk)
|
||||
|
||||
# Load to memory if needed
|
||||
adata_memory = adata.to_memory()
|
||||
```
|
||||
|
||||
### 5. Concatenating Multiple Datasets
|
||||
|
||||
Combine multiple AnnData objects with control over how data is merged.
|
||||
|
||||
**Basic concatenation:**
|
||||
```python
|
||||
# Concatenate observations (most common)
|
||||
combined = ad.concat([adata1, adata2, adata3], axis=0)
|
||||
|
||||
# Concatenate variables (rare)
|
||||
combined = ad.concat([adata1, adata2], axis=1)
|
||||
```
|
||||
|
||||
**Join strategies:**
|
||||
```python
|
||||
# Inner join: only shared variables (no missing data)
|
||||
combined = ad.concat([adata1, adata2], join='inner')
|
||||
|
||||
# Outer join: all variables (fills missing with 0)
|
||||
combined = ad.concat([adata1, adata2], join='outer')
|
||||
```
|
||||
|
||||
**Tracking data sources:**
|
||||
```python
|
||||
# Add source labels
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='dataset',
|
||||
keys=['exp1', 'exp2', 'exp3']
|
||||
)
|
||||
# Creates combined.obs['dataset'] with values 'exp1', 'exp2', 'exp3'
|
||||
|
||||
# Make duplicate indices unique
|
||||
combined = ad.concat(
|
||||
[adata1, adata2],
|
||||
keys=['batch1', 'batch2'],
|
||||
index_unique='-'
|
||||
)
|
||||
# Cell names become: Cell_0-batch1, Cell_0-batch2, etc.
|
||||
```
|
||||
|
||||
**Merge strategies for metadata:**
|
||||
```python
|
||||
# merge=None: exclude variable annotations (default)
|
||||
combined = ad.concat([adata1, adata2], merge=None)
|
||||
|
||||
# merge='same': keep only identical annotations
|
||||
combined = ad.concat([adata1, adata2], merge='same')
|
||||
|
||||
# merge='first': use first occurrence
|
||||
combined = ad.concat([adata1, adata2], merge='first')
|
||||
|
||||
# merge='unique': keep annotations with single value
|
||||
combined = ad.concat([adata1, adata2], merge='unique')
|
||||
```
|
||||
|
||||
**Complete example:**
|
||||
```python
|
||||
# Load batches
|
||||
batch1 = ad.read_h5ad('batch1.h5ad')
|
||||
batch2 = ad.read_h5ad('batch2.h5ad')
|
||||
batch3 = ad.read_h5ad('batch3.h5ad')
|
||||
|
||||
# Concatenate with full tracking
|
||||
combined = ad.concat(
|
||||
[batch1, batch2, batch3],
|
||||
axis=0,
|
||||
join='outer', # Keep all genes
|
||||
merge='first', # Use first batch's annotations
|
||||
label='batch_id', # Track source
|
||||
keys=['b1', 'b2', 'b3'], # Custom labels
|
||||
index_unique='-' # Make cell names unique
|
||||
)
|
||||
```
|
||||
|
||||
### 6. Data Conversion and Extraction
|
||||
|
||||
Convert between AnnData and other formats for interoperability.
|
||||
|
||||
**To DataFrame:**
|
||||
```python
|
||||
# Convert X to DataFrame
|
||||
df = adata.to_df()
|
||||
|
||||
# Convert specific layer
|
||||
df = adata.to_df(layer='log1p')
|
||||
```
|
||||
|
||||
**Extract vectors:**
|
||||
```python
|
||||
# Get 1D arrays from data or annotations
|
||||
gene_expression = adata.obs_vector('Gene_100')
|
||||
cell_metadata = adata.obs_vector('n_genes')
|
||||
```
|
||||
|
||||
**Transpose:**
|
||||
```python
|
||||
# Swap observations and variables
|
||||
transposed = adata.T
|
||||
```
|
||||
|
||||
### 7. Memory Optimization
|
||||
|
||||
Strategies for working with large datasets efficiently.
|
||||
|
||||
**Use sparse matrices:**
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Check sparsity
|
||||
density = (adata.X != 0).sum() / adata.X.size
|
||||
if density < 0.3: # Less than 30% non-zero
|
||||
adata.X = csr_matrix(adata.X)
|
||||
```
|
||||
|
||||
**Convert strings to categoricals:**
|
||||
```python
|
||||
# Automatic conversion
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# Manual conversion (more control)
|
||||
adata.obs['cell_type'] = pd.Categorical(adata.obs['cell_type'])
|
||||
```
|
||||
|
||||
**Use backed mode:**
|
||||
```python
|
||||
# Read without loading into memory
|
||||
adata = ad.read_h5ad('large_file.h5ad', backed='r')
|
||||
|
||||
# Work with subsets
|
||||
subset = adata[:1000, :500].copy() # Only this subset in memory
|
||||
```
|
||||
|
||||
**Chunked processing:**
|
||||
```python
|
||||
# Process data in chunks
|
||||
results = []
|
||||
for chunk in adata.chunk_X(chunk_size=1000):
|
||||
result = expensive_computation(chunk)
|
||||
results.append(result)
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Single-Cell RNA-seq Analysis
|
||||
|
||||
Complete workflow from loading to analysis:
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# 1. Load data
|
||||
adata = ad.read_mtx('matrix.mtx')
|
||||
adata.obs_names = pd.read_csv('barcodes.tsv', header=None)[0]
|
||||
adata.var_names = pd.read_csv('genes.tsv', header=None)[0]
|
||||
|
||||
# 2. Quality control
|
||||
adata.obs['n_genes'] = (adata.X > 0).sum(axis=1)
|
||||
adata.obs['total_counts'] = adata.X.sum(axis=1)
|
||||
adata = adata[adata.obs.n_genes > 200]
|
||||
adata = adata[adata.obs.total_counts < 10000]
|
||||
|
||||
# 3. Filter genes
|
||||
min_cells = 3
|
||||
adata = adata[:, (adata.X > 0).sum(axis=0) >= min_cells]
|
||||
|
||||
# 4. Store raw counts
|
||||
adata.layers['counts'] = adata.X.copy()
|
||||
|
||||
# 5. Normalize
|
||||
adata.X = adata.X / adata.obs.total_counts.values[:, None] * 1e4
|
||||
adata.X = np.log1p(adata.X)
|
||||
|
||||
# 6. Feature selection
|
||||
gene_var = adata.X.var(axis=0)
|
||||
adata.var['highly_variable'] = gene_var > np.percentile(gene_var, 90)
|
||||
|
||||
# 7. Dimensionality reduction (example with external tools)
|
||||
# adata.obsm['X_pca'] = compute_pca(adata.X)
|
||||
# adata.obsm['X_umap'] = compute_umap(adata.obsm['X_pca'])
|
||||
|
||||
# 8. Save results
|
||||
adata.write('analyzed.h5ad', compression='gzip')
|
||||
```
|
||||
|
||||
### Batch Integration
|
||||
|
||||
Combining multiple experimental batches:
|
||||
|
||||
```python
|
||||
# Load batches
|
||||
batches = [ad.read_h5ad(f'batch_{i}.h5ad') for i in range(3)]
|
||||
|
||||
# Concatenate with tracking
|
||||
combined = ad.concat(
|
||||
batches,
|
||||
axis=0,
|
||||
join='outer',
|
||||
label='batch',
|
||||
keys=['batch_0', 'batch_1', 'batch_2'],
|
||||
index_unique='-'
|
||||
)
|
||||
|
||||
# Add batch as numeric for correction algorithms
|
||||
combined.obs['batch_numeric'] = combined.obs['batch'].cat.codes
|
||||
|
||||
# Perform batch correction (with external tools)
|
||||
# corrected_pca = run_harmony(combined.obsm['X_pca'], combined.obs['batch'])
|
||||
# combined.obsm['X_pca_corrected'] = corrected_pca
|
||||
|
||||
# Save integrated data
|
||||
combined.write('integrated.h5ad', compression='gzip')
|
||||
```
|
||||
|
||||
### Memory-Efficient Large Dataset Processing
|
||||
|
||||
Working with datasets too large for memory:
|
||||
|
||||
```python
|
||||
# Read in backed mode
|
||||
adata = ad.read_h5ad('huge_dataset.h5ad', backed='r')
|
||||
|
||||
# Compute statistics in chunks
|
||||
total = 0
|
||||
for chunk in adata.chunk_X(chunk_size=1000):
|
||||
total += chunk.sum()
|
||||
|
||||
mean_expression = total / (adata.n_obs * adata.n_vars)
|
||||
|
||||
# Work with subset
|
||||
high_quality_cells = adata.obs.n_genes > 1000
|
||||
subset = adata[high_quality_cells, :].copy()
|
||||
|
||||
# Close file
|
||||
adata.file.close()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Data Organization
|
||||
|
||||
1. **Use layers for different representations**: Store raw counts, normalized, log-transformed, and scaled data in separate layers
|
||||
2. **Use obsm/varm for multi-dimensional data**: Embeddings, loadings, and other matrix-like annotations
|
||||
3. **Use uns for metadata**: Analysis parameters, dates, version information
|
||||
4. **Use categoricals for efficiency**: Convert repeated strings to categorical types
|
||||
|
||||
### Subsetting
|
||||
|
||||
1. **Understand views vs copies**: Subsetting returns views by default; use `.copy()` when you need independence
|
||||
2. **Chain conditions efficiently**: Combine boolean masks in a single subsetting operation
|
||||
3. **Validate after subsetting**: Check dimensions and data integrity
|
||||
|
||||
### File I/O
|
||||
|
||||
1. **Use compression**: Always use `compression='gzip'` when writing h5ad files
|
||||
2. **Choose the right format**: H5AD for general use, Zarr for cloud storage, Loom for compatibility
|
||||
3. **Close backed files**: Always close file connections when done
|
||||
4. **Use backed mode for large files**: Don't load everything into memory if not needed
|
||||
|
||||
### Concatenation
|
||||
|
||||
1. **Choose appropriate join**: Inner join for complete cases, outer join to preserve all features
|
||||
2. **Track sources**: Use `label` and `keys` to track data origin
|
||||
3. **Handle duplicates**: Use `index_unique` to make observation names unique
|
||||
4. **Select merge strategy**: Choose appropriate merge strategy for variable annotations
|
||||
|
||||
### Memory Management
|
||||
|
||||
1. **Use sparse matrices**: For data with <30% non-zero values
|
||||
2. **Convert to categoricals**: For repeated string values
|
||||
3. **Process in chunks**: For operations on very large matrices
|
||||
4. **Use backed mode**: Read large files with `backed='r'`
|
||||
|
||||
### Naming Conventions
|
||||
|
||||
Follow these conventions for consistency:
|
||||
|
||||
- **Embeddings**: `X_pca`, `X_umap`, `X_tsne`
|
||||
- **Layers**: Descriptive names like `counts`, `log1p`, `scaled`
|
||||
- **Observations**: Use snake_case like `cell_type`, `n_genes`, `total_counts`
|
||||
- **Variables**: Use snake_case like `highly_variable`, `gene_name`
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
For detailed API information, usage patterns, and troubleshooting, refer to the comprehensive reference files in the `references/` directory:
|
||||
|
||||
1. **api_reference.md**: Complete API documentation including all classes, methods, and functions with usage examples. Use `grep -r "pattern" references/api_reference.md` to search for specific functions or parameters.
|
||||
|
||||
2. **workflows_best_practices.md**: Detailed workflows for common tasks (single-cell analysis, batch integration, large datasets), best practices for memory management, data organization, and common pitfalls to avoid. Use `grep -r "pattern" references/workflows_best_practices.md` to search for specific workflows.
|
||||
|
||||
3. **concatenation_guide.md**: Comprehensive guide to concatenation strategies, join types, merge strategies, source tracking, and troubleshooting concatenation issues. Use `grep -r "pattern" references/concatenation_guide.md` to search for concatenation patterns.
|
||||
|
||||
## When to Load References
|
||||
|
||||
Load reference files into context when:
|
||||
- Implementing complex concatenation with specific merge strategies
|
||||
- Troubleshooting errors or unexpected behavior
|
||||
- Optimizing memory usage for large datasets
|
||||
- Implementing complete analysis workflows
|
||||
- Understanding nuances of specific API methods
|
||||
|
||||
To search within references without loading them:
|
||||
```python
|
||||
# Example: Search for information about backed mode
|
||||
grep -r "backed mode" references/
|
||||
```
|
||||
|
||||
## Common Error Patterns
|
||||
|
||||
### Memory Errors
|
||||
**Problem**: "MemoryError: Unable to allocate array"
|
||||
**Solution**: Use backed mode, sparse matrices, or process in chunks
|
||||
|
||||
### Dimension Mismatch
|
||||
**Problem**: "ValueError: operands could not be broadcast together"
|
||||
**Solution**: Use outer join in concatenation or align indices before operations
|
||||
|
||||
### View Modification
|
||||
**Problem**: "ValueError: assignment destination is read-only"
|
||||
**Solution**: Convert view to copy with `.copy()` before modification
|
||||
|
||||
### File Already Open
|
||||
**Problem**: "OSError: Unable to open file (file is already open)"
|
||||
**Solution**: Close previous file connection with `adata.file.close()`
|
||||
@@ -1,218 +0,0 @@
|
||||
# AnnData API Reference
|
||||
|
||||
## Core AnnData Class
|
||||
|
||||
The `AnnData` class is the central data structure for storing and manipulating annotated datasets in single-cell genomics and other domains.
|
||||
|
||||
### Core Attributes
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| **X** | array-like | Primary data matrix (#observations × #variables). Supports NumPy arrays, sparse matrices (CSR/CSC), HDF5 datasets, Zarr arrays, and Dask arrays |
|
||||
| **obs** | DataFrame | One-dimensional annotation of observations (rows). Length equals observation count |
|
||||
| **var** | DataFrame | One-dimensional annotation of variables/features (columns). Length equals variable count |
|
||||
| **uns** | OrderedDict | Unstructured annotation for miscellaneous metadata |
|
||||
| **obsm** | dict-like | Multi-dimensional observation annotations (structured arrays aligned to observation axis) |
|
||||
| **varm** | dict-like | Multi-dimensional variable annotations (structured arrays aligned to variable axis) |
|
||||
| **obsp** | dict-like | Pairwise observation annotations (square matrices representing graphs) |
|
||||
| **varp** | dict-like | Pairwise variable annotations (graphs between features) |
|
||||
| **layers** | dict-like | Additional data matrices matching X's dimensions |
|
||||
| **raw** | AnnData | Stores original versions of X and var before transformations |
|
||||
|
||||
### Dimensional Properties
|
||||
|
||||
- **n_obs**: Number of observations (sample count)
|
||||
- **n_vars**: Number of variables/features
|
||||
- **shape**: Tuple returning (n_obs, n_vars)
|
||||
- **T**: Transposed view of the entire object
|
||||
|
||||
### State Properties
|
||||
|
||||
- **isbacked**: Boolean indicating disk-backed storage status
|
||||
- **is_view**: Boolean identifying whether object is a view of another AnnData
|
||||
- **filename**: Path to backing .h5ad file; setting this enables disk-backed mode
|
||||
|
||||
### Key Methods
|
||||
|
||||
#### Construction and Copying
|
||||
- **`AnnData(X=None, obs=None, var=None, ...)`**: Create new AnnData object
|
||||
- **`copy(filename=None)`**: Create full copy, optionally stored on disk
|
||||
|
||||
#### Subsetting and Views
|
||||
- **`adata[obs_subset, var_subset]`**: Subset observations and variables (returns view by default)
|
||||
- **`.copy()`**: Convert view to independent object
|
||||
|
||||
#### Data Access
|
||||
- **`to_df(layer=None)`**: Generate pandas DataFrame representation
|
||||
- **`obs_vector(k, layer=None)`**: Extract 1D array from X, layers, or annotations
|
||||
- **`var_vector(k, layer=None)`**: Extract 1D array for a variable
|
||||
- **`chunk_X(chunk_size)`**: Iterate over data matrix in chunks
|
||||
- **`chunked_X(chunk_size)`**: Context manager for chunked iteration
|
||||
|
||||
#### Transformation
|
||||
- **`transpose()`**: Return transposed object
|
||||
- **`concatenate(*adatas, ...)`**: Combine multiple AnnData objects along observation axis
|
||||
- **`to_memory(copy=False)`**: Load all backed arrays into RAM
|
||||
|
||||
#### File I/O
|
||||
- **`write_h5ad(filename, compression='gzip')`**: Save as .h5ad HDF5 format
|
||||
- **`write_zarr(store, ...)`**: Export hierarchical Zarr store
|
||||
- **`write_loom(filename, ...)`**: Output .loom format file
|
||||
- **`write_csvs(dirname, ...)`**: Write annotations as separate CSV files
|
||||
|
||||
#### Data Management
|
||||
- **`strings_to_categoricals()`**: Convert string annotations to categorical types
|
||||
- **`rename_categories(key, categories)`**: Update category labels in annotations
|
||||
- **`obs_names_make_unique(sep='-')`**: Append numeric suffixes to duplicate observation names
|
||||
- **`var_names_make_unique(sep='-')`**: Append numeric suffixes to duplicate variable names
|
||||
|
||||
## Module-Level Functions
|
||||
|
||||
### Reading Functions
|
||||
|
||||
#### Native Formats
|
||||
- **`read_h5ad(filename, backed=None, as_sparse=None)`**: Load HDF5-based .h5ad files
|
||||
- **`read_zarr(store)`**: Access hierarchical Zarr array stores
|
||||
|
||||
#### Alternative Formats
|
||||
- **`read_csv(filename, ...)`**: Import from CSV files
|
||||
- **`read_excel(filename, ...)`**: Import from Excel files
|
||||
- **`read_hdf(filename, key)`**: Read from HDF5 files
|
||||
- **`read_loom(filename, ...)`**: Import from .loom files
|
||||
- **`read_mtx(filename, ...)`**: Import from Matrix Market format
|
||||
- **`read_text(filename, ...)`**: Import from text files
|
||||
- **`read_umi_tools(filename, ...)`**: Import from UMI-tools format
|
||||
|
||||
#### Element-Level Access
|
||||
- **`read_elem(elem)`**: Retrieve specific components from storage
|
||||
- **`sparse_dataset(group)`**: Generate backed sparse matrix classes
|
||||
|
||||
### Combining Operations
|
||||
- **`concat(adatas, axis=0, join='inner', merge=None, ...)`**: Merge multiple AnnData objects
|
||||
- **axis**: 0 (observations) or 1 (variables)
|
||||
- **join**: 'inner' (intersection) or 'outer' (union)
|
||||
- **merge**: Strategy for non-concatenation axis ('same', 'unique', 'first', 'only', or None)
|
||||
- **label**: Column name for source tracking
|
||||
- **keys**: Dataset identifiers for source annotation
|
||||
- **index_unique**: Separator for making duplicate indices unique
|
||||
|
||||
### Writing Functions
|
||||
- **`write_h5ad(filename, adata, compression='gzip')`**: Export to HDF5 format
|
||||
- **`write_zarr(store, adata, ...)`**: Save as Zarr hierarchical arrays
|
||||
- **`write_elem(elem, ...)`**: Write individual components
|
||||
|
||||
### Experimental Features
|
||||
- **`AnnCollection`**: Batch processing for large collections
|
||||
- **`AnnLoader`**: PyTorch DataLoader integration
|
||||
- **`concat_on_disk(*adatas, filename, ...)`**: Memory-efficient out-of-core concatenation
|
||||
- **`read_lazy(filename)`**: Lazy loading with deferred computation
|
||||
- **`read_dispatched(filename, ...)`**: Custom I/O with callbacks
|
||||
- **`write_dispatched(filename, ...)`**: Custom writing with callbacks
|
||||
|
||||
### Configuration
|
||||
- **`settings`**: Package-wide configuration object
|
||||
- **`settings.override(**kwargs)`**: Context manager for temporary settings changes
|
||||
|
||||
## Common Usage Patterns
|
||||
|
||||
### Creating AnnData Objects
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# From dense array
|
||||
counts = np.random.poisson(1, size=(100, 2000))
|
||||
adata = ad.AnnData(counts)
|
||||
|
||||
# From sparse matrix
|
||||
counts = csr_matrix(np.random.poisson(1, size=(100, 2000)), dtype=np.float32)
|
||||
adata = ad.AnnData(counts)
|
||||
|
||||
# With metadata
|
||||
import pandas as pd
|
||||
obs_meta = pd.DataFrame({'cell_type': ['B', 'T', 'Monocyte'] * 33 + ['B']})
|
||||
var_meta = pd.DataFrame({'gene_name': [f'Gene_{i}' for i in range(2000)]})
|
||||
adata = ad.AnnData(counts, obs=obs_meta, var=var_meta)
|
||||
```
|
||||
|
||||
### Subsetting
|
||||
|
||||
```python
|
||||
# By names
|
||||
subset = adata[['Cell_1', 'Cell_10'], ['Gene_5', 'Gene_1900']]
|
||||
|
||||
# By boolean mask
|
||||
b_cells = adata[adata.obs.cell_type == 'B']
|
||||
|
||||
# By position
|
||||
first_five = adata[:5, :100]
|
||||
|
||||
# Convert view to copy
|
||||
adata_copy = adata[:5].copy()
|
||||
```
|
||||
|
||||
### Adding Annotations
|
||||
|
||||
```python
|
||||
# Cell-level metadata
|
||||
adata.obs['batch'] = pd.Categorical(['batch1', 'batch2'] * 50)
|
||||
|
||||
# Gene-level metadata
|
||||
adata.var['highly_variable'] = np.random.choice([True, False], size=adata.n_vars)
|
||||
|
||||
# Embeddings
|
||||
adata.obsm['X_pca'] = np.random.normal(size=(adata.n_obs, 50))
|
||||
adata.obsm['X_umap'] = np.random.normal(size=(adata.n_obs, 2))
|
||||
|
||||
# Alternative data representations
|
||||
adata.layers['log_transformed'] = np.log1p(adata.X)
|
||||
adata.layers['scaled'] = (adata.X - adata.X.mean(axis=0)) / adata.X.std(axis=0)
|
||||
|
||||
# Unstructured metadata
|
||||
adata.uns['experiment_date'] = '2024-01-15'
|
||||
adata.uns['parameters'] = {'min_genes': 200, 'min_cells': 3}
|
||||
```
|
||||
|
||||
### File I/O
|
||||
|
||||
```python
|
||||
# Write to disk
|
||||
adata.write('my_results.h5ad', compression='gzip')
|
||||
|
||||
# Read into memory
|
||||
adata = ad.read_h5ad('my_results.h5ad')
|
||||
|
||||
# Read in backed mode (memory-efficient)
|
||||
adata = ad.read_h5ad('my_results.h5ad', backed='r')
|
||||
|
||||
# Close file connection
|
||||
adata.file.close()
|
||||
```
|
||||
|
||||
### Concatenation
|
||||
|
||||
```python
|
||||
# Combine multiple datasets
|
||||
adata1 = ad.AnnData(np.random.poisson(1, size=(100, 2000)))
|
||||
adata2 = ad.AnnData(np.random.poisson(1, size=(150, 2000)))
|
||||
adata3 = ad.AnnData(np.random.poisson(1, size=(80, 2000)))
|
||||
|
||||
# Simple concatenation
|
||||
combined = ad.concat([adata1, adata2, adata3], axis=0)
|
||||
|
||||
# With source labels
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
axis=0,
|
||||
label='dataset',
|
||||
keys=['exp1', 'exp2', 'exp3']
|
||||
)
|
||||
|
||||
# Inner join (only shared variables)
|
||||
combined = ad.concat([adata1, adata2, adata3], axis=0, join='inner')
|
||||
|
||||
# Outer join (all variables, pad with zeros)
|
||||
combined = ad.concat([adata1, adata2, adata3], axis=0, join='outer')
|
||||
```
|
||||
@@ -1,478 +0,0 @@
|
||||
# AnnData Concatenation Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The `concat()` function combines multiple AnnData objects through two fundamental operations:
|
||||
1. **Concatenation**: Stacking sub-elements in order
|
||||
2. **Merging**: Combining collections into one result
|
||||
|
||||
## Basic Concatenation
|
||||
|
||||
### Syntax
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
combined = ad.concat(
|
||||
adatas, # List of AnnData objects
|
||||
axis=0, # 0=observations, 1=variables
|
||||
join='inner', # 'inner' or 'outer'
|
||||
merge=None, # Merge strategy for non-concat axis
|
||||
label=None, # Column name for source tracking
|
||||
keys=None, # Dataset identifiers
|
||||
index_unique=None, # Separator for unique indices
|
||||
fill_value=None, # Fill value for missing data
|
||||
pairwise=False # Include pairwise matrices
|
||||
)
|
||||
```
|
||||
|
||||
### Concatenating Observations (Cells)
|
||||
```python
|
||||
# Most common: combining multiple samples/batches
|
||||
adata1 = ad.AnnData(np.random.rand(100, 2000))
|
||||
adata2 = ad.AnnData(np.random.rand(150, 2000))
|
||||
adata3 = ad.AnnData(np.random.rand(80, 2000))
|
||||
|
||||
combined = ad.concat([adata1, adata2, adata3], axis=0)
|
||||
# Result: (330 observations, 2000 variables)
|
||||
```
|
||||
|
||||
### Concatenating Variables (Genes)
|
||||
```python
|
||||
# Less common: combining different feature sets
|
||||
adata1 = ad.AnnData(np.random.rand(100, 1000))
|
||||
adata2 = ad.AnnData(np.random.rand(100, 500))
|
||||
|
||||
combined = ad.concat([adata1, adata2], axis=1)
|
||||
# Result: (100 observations, 1500 variables)
|
||||
```
|
||||
|
||||
## Join Strategies
|
||||
|
||||
### Inner Join (Intersection)
|
||||
|
||||
Keeps only shared features across all objects.
|
||||
|
||||
```python
|
||||
# Datasets with different genes
|
||||
adata1 = ad.AnnData(
|
||||
np.random.rand(100, 2000),
|
||||
var=pd.DataFrame(index=[f'Gene_{i}' for i in range(2000)])
|
||||
)
|
||||
adata2 = ad.AnnData(
|
||||
np.random.rand(150, 1800),
|
||||
var=pd.DataFrame(index=[f'Gene_{i}' for i in range(200, 2000)])
|
||||
)
|
||||
|
||||
# Inner join: only genes present in both
|
||||
combined = ad.concat([adata1, adata2], join='inner')
|
||||
# Result: (250 observations, 1800 variables)
|
||||
# Only Gene_200 through Gene_1999
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- You want to analyze only features measured in all datasets
|
||||
- Missing features would compromise analysis
|
||||
- You need a complete case analysis
|
||||
|
||||
**Trade-offs:**
|
||||
- May lose many features
|
||||
- Ensures no missing data
|
||||
- Smaller result size
|
||||
|
||||
### Outer Join (Union)
|
||||
|
||||
Keeps all features from all objects, padding with fill values (default 0).
|
||||
|
||||
```python
|
||||
# Outer join: all genes from both datasets
|
||||
combined = ad.concat([adata1, adata2], join='outer')
|
||||
# Result: (250 observations, 2000 variables)
|
||||
# Missing values filled with 0
|
||||
|
||||
# Custom fill value
|
||||
combined = ad.concat([adata1, adata2], join='outer', fill_value=np.nan)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- You want to preserve all features
|
||||
- Sparse data is acceptable
|
||||
- Features are independent
|
||||
|
||||
**Trade-offs:**
|
||||
- Introduces zeros/missing values
|
||||
- Larger result size
|
||||
- May need imputation
|
||||
|
||||
## Merge Strategies
|
||||
|
||||
Merge strategies control how elements on the non-concatenation axis are combined.
|
||||
|
||||
### merge=None (Default)
|
||||
|
||||
Excludes all non-concatenation axis elements.
|
||||
|
||||
```python
|
||||
# Both datasets have var annotations
|
||||
adata1.var['gene_type'] = ['protein_coding'] * 2000
|
||||
adata2.var['gene_type'] = ['protein_coding'] * 1800
|
||||
|
||||
# merge=None: var annotations excluded
|
||||
combined = ad.concat([adata1, adata2], merge=None)
|
||||
assert 'gene_type' not in combined.var.columns
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Annotations are dataset-specific
|
||||
- You'll add new annotations after merging
|
||||
|
||||
### merge='same'
|
||||
|
||||
Keeps only annotations with identical values across datasets.
|
||||
|
||||
```python
|
||||
# Same annotation values
|
||||
adata1.var['chromosome'] = ['chr1'] * 1000 + ['chr2'] * 1000
|
||||
adata2.var['chromosome'] = ['chr1'] * 900 + ['chr2'] * 900
|
||||
|
||||
# merge='same': keeps chromosome annotation
|
||||
combined = ad.concat([adata1, adata2], merge='same')
|
||||
assert 'chromosome' in combined.var.columns
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Annotations should be consistent
|
||||
- You want to validate consistency
|
||||
- Shared metadata is important
|
||||
|
||||
**Note:** Comparison occurs after index alignment - only shared indices need to match.
|
||||
|
||||
### merge='unique'
|
||||
|
||||
Includes annotations with a single possible value.
|
||||
|
||||
```python
|
||||
# Unique values per gene
|
||||
adata1.var['ensembl_id'] = [f'ENSG{i:08d}' for i in range(2000)]
|
||||
adata2.var['ensembl_id'] = [f'ENSG{i:08d}' for i in range(2000)]
|
||||
|
||||
# merge='unique': keeps ensembl_id
|
||||
combined = ad.concat([adata1, adata2], merge='unique')
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Each feature has a unique identifier
|
||||
- Annotations are feature-specific
|
||||
|
||||
### merge='first'
|
||||
|
||||
Takes the first occurrence of each annotation.
|
||||
|
||||
```python
|
||||
# Different annotation versions
|
||||
adata1.var['description'] = ['desc1'] * 2000
|
||||
adata2.var['description'] = ['desc2'] * 2000
|
||||
|
||||
# merge='first': uses adata1's descriptions
|
||||
combined = ad.concat([adata1, adata2], merge='first')
|
||||
# Uses descriptions from adata1
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- One dataset has authoritative annotations
|
||||
- Order matters
|
||||
- You need a simple resolution strategy
|
||||
|
||||
### merge='only'
|
||||
|
||||
Retains annotations appearing in exactly one object.
|
||||
|
||||
```python
|
||||
# Dataset-specific annotations
|
||||
adata1.var['dataset1_specific'] = ['value'] * 2000
|
||||
adata2.var['dataset2_specific'] = ['value'] * 2000
|
||||
|
||||
# merge='only': keeps both (no conflicts)
|
||||
combined = ad.concat([adata1, adata2], merge='only')
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Datasets have non-overlapping annotations
|
||||
- You want to preserve all unique metadata
|
||||
|
||||
## Source Tracking
|
||||
|
||||
### Using label
|
||||
|
||||
Add a categorical column to track data origin.
|
||||
|
||||
```python
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='batch'
|
||||
)
|
||||
|
||||
# Creates obs['batch'] with values 0, 1, 2
|
||||
print(combined.obs['batch'].cat.categories) # ['0', '1', '2']
|
||||
```
|
||||
|
||||
### Using keys
|
||||
|
||||
Provide custom names for source tracking.
|
||||
|
||||
```python
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='study',
|
||||
keys=['control', 'treatment_a', 'treatment_b']
|
||||
)
|
||||
|
||||
# Creates obs['study'] with custom names
|
||||
print(combined.obs['study'].unique()) # ['control', 'treatment_a', 'treatment_b']
|
||||
```
|
||||
|
||||
### Making Indices Unique
|
||||
|
||||
Append source identifiers to duplicate observation names.
|
||||
|
||||
```python
|
||||
# Both datasets have cells named "Cell_0", "Cell_1", etc.
|
||||
adata1.obs_names = [f'Cell_{i}' for i in range(100)]
|
||||
adata2.obs_names = [f'Cell_{i}' for i in range(150)]
|
||||
|
||||
# index_unique adds suffix
|
||||
combined = ad.concat(
|
||||
[adata1, adata2],
|
||||
keys=['batch1', 'batch2'],
|
||||
index_unique='-'
|
||||
)
|
||||
|
||||
# Results in: Cell_0-batch1, Cell_0-batch2, etc.
|
||||
print(combined.obs_names[:5])
|
||||
```
|
||||
|
||||
## Handling Different Attributes
|
||||
|
||||
### X Matrix and Layers
|
||||
|
||||
Follows join strategy. Missing values filled according to `fill_value`.
|
||||
|
||||
```python
|
||||
# Both have layers
|
||||
adata1.layers['counts'] = adata1.X.copy()
|
||||
adata2.layers['counts'] = adata2.X.copy()
|
||||
|
||||
# Concatenates both X and layers
|
||||
combined = ad.concat([adata1, adata2])
|
||||
assert 'counts' in combined.layers
|
||||
```
|
||||
|
||||
### obs and var DataFrames
|
||||
|
||||
- **obs**: Concatenated along concatenation axis
|
||||
- **var**: Handled by merge strategy
|
||||
|
||||
```python
|
||||
adata1.obs['cell_type'] = ['B cell'] * 100
|
||||
adata2.obs['cell_type'] = ['T cell'] * 150
|
||||
|
||||
combined = ad.concat([adata1, adata2])
|
||||
# obs['cell_type'] preserved for all cells
|
||||
```
|
||||
|
||||
### obsm and varm
|
||||
|
||||
Multi-dimensional annotations follow same rules as layers.
|
||||
|
||||
```python
|
||||
adata1.obsm['X_pca'] = np.random.rand(100, 50)
|
||||
adata2.obsm['X_pca'] = np.random.rand(150, 50)
|
||||
|
||||
combined = ad.concat([adata1, adata2])
|
||||
# obsm['X_pca'] concatenated: shape (250, 50)
|
||||
```
|
||||
|
||||
### obsp and varp
|
||||
|
||||
Pairwise matrices excluded by default. Enable with `pairwise=True`.
|
||||
|
||||
```python
|
||||
# Distance matrices
|
||||
adata1.obsp['distances'] = np.random.rand(100, 100)
|
||||
adata2.obsp['distances'] = np.random.rand(150, 150)
|
||||
|
||||
# Excluded by default
|
||||
combined = ad.concat([adata1, adata2])
|
||||
assert 'distances' not in combined.obsp
|
||||
|
||||
# Include if needed
|
||||
combined = ad.concat([adata1, adata2], pairwise=True)
|
||||
# Results in padded block diagonal matrix
|
||||
```
|
||||
|
||||
### uns Dictionary
|
||||
|
||||
Merged recursively, applying merge strategy at any nesting depth.
|
||||
|
||||
```python
|
||||
adata1.uns['experiment'] = {'date': '2024-01', 'lab': 'A'}
|
||||
adata2.uns['experiment'] = {'date': '2024-02', 'lab': 'A'}
|
||||
|
||||
# merge='same' keeps 'lab', excludes 'date'
|
||||
combined = ad.concat([adata1, adata2], merge='same')
|
||||
# combined.uns['experiment'] = {'lab': 'A'}
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Batch Integration Pipeline
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Load batches
|
||||
batches = [
|
||||
ad.read_h5ad(f'batch_{i}.h5ad')
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
# Concatenate with tracking
|
||||
combined = ad.concat(
|
||||
batches,
|
||||
axis=0,
|
||||
join='outer',
|
||||
merge='first',
|
||||
label='batch_id',
|
||||
keys=[f'batch_{i}' for i in range(5)],
|
||||
index_unique='-'
|
||||
)
|
||||
|
||||
# Add batch effects
|
||||
combined.obs['batch_numeric'] = combined.obs['batch_id'].cat.codes
|
||||
```
|
||||
|
||||
### Multi-Study Meta-Analysis
|
||||
|
||||
```python
|
||||
# Different studies with varying gene coverage
|
||||
studies = {
|
||||
'study_a': ad.read_h5ad('study_a.h5ad'),
|
||||
'study_b': ad.read_h5ad('study_b.h5ad'),
|
||||
'study_c': ad.read_h5ad('study_c.h5ad')
|
||||
}
|
||||
|
||||
# Outer join to keep all genes
|
||||
combined = ad.concat(
|
||||
list(studies.values()),
|
||||
axis=0,
|
||||
join='outer',
|
||||
label='study',
|
||||
keys=list(studies.keys()),
|
||||
merge='unique',
|
||||
fill_value=0
|
||||
)
|
||||
|
||||
# Track coverage
|
||||
for study in studies:
|
||||
n_genes = studies[study].n_vars
|
||||
combined.uns[f'{study}_n_genes'] = n_genes
|
||||
```
|
||||
|
||||
### Incremental Concatenation
|
||||
|
||||
```python
|
||||
# For many datasets, concatenate in batches
|
||||
chunk_size = 10
|
||||
all_files = [f'dataset_{i}.h5ad' for i in range(100)]
|
||||
|
||||
# Process in chunks
|
||||
result = None
|
||||
for i in range(0, len(all_files), chunk_size):
|
||||
chunk_files = all_files[i:i+chunk_size]
|
||||
chunk_adatas = [ad.read_h5ad(f) for f in chunk_files]
|
||||
chunk_combined = ad.concat(chunk_adatas)
|
||||
|
||||
if result is None:
|
||||
result = chunk_combined
|
||||
else:
|
||||
result = ad.concat([result, chunk_combined])
|
||||
```
|
||||
|
||||
### Memory-Efficient On-Disk Concatenation
|
||||
|
||||
```python
|
||||
# Experimental feature for large datasets
|
||||
from anndata.experimental import concat_on_disk
|
||||
|
||||
files = ['dataset1.h5ad', 'dataset2.h5ad', 'dataset3.h5ad']
|
||||
concat_on_disk(
|
||||
files,
|
||||
'combined.h5ad',
|
||||
join='outer'
|
||||
)
|
||||
|
||||
# Read result in backed mode
|
||||
combined = ad.read_h5ad('combined.h5ad', backed='r')
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Dimension Mismatch
|
||||
|
||||
```python
|
||||
# Error: shapes don't match
|
||||
adata1 = ad.AnnData(np.random.rand(100, 2000))
|
||||
adata2 = ad.AnnData(np.random.rand(150, 1500))
|
||||
|
||||
# Solution: use outer join
|
||||
combined = ad.concat([adata1, adata2], join='outer')
|
||||
```
|
||||
|
||||
### Issue: Memory Error
|
||||
|
||||
```python
|
||||
# Problem: too many large objects in memory
|
||||
large_adatas = [ad.read_h5ad(f) for f in many_files]
|
||||
|
||||
# Solution: read and concatenate incrementally
|
||||
result = None
|
||||
for file in many_files:
|
||||
adata = ad.read_h5ad(file)
|
||||
if result is None:
|
||||
result = adata
|
||||
else:
|
||||
result = ad.concat([result, adata])
|
||||
del adata # Free memory
|
||||
```
|
||||
|
||||
### Issue: Duplicate Indices
|
||||
|
||||
```python
|
||||
# Problem: same cell names in different batches
|
||||
# Solution: use index_unique
|
||||
combined = ad.concat(
|
||||
[adata1, adata2],
|
||||
keys=['batch1', 'batch2'],
|
||||
index_unique='-'
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Lost Annotations
|
||||
|
||||
```python
|
||||
# Problem: annotations disappear
|
||||
adata1.var['important'] = values1
|
||||
adata2.var['important'] = values2
|
||||
|
||||
combined = ad.concat([adata1, adata2]) # merge=None by default
|
||||
# Solution: use appropriate merge strategy
|
||||
combined = ad.concat([adata1, adata2], merge='first')
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Pre-align indices**: Ensure consistent naming before concatenation
|
||||
2. **Use sparse matrices**: Convert to sparse before concatenating
|
||||
3. **Batch operations**: Concatenate in groups for many datasets
|
||||
4. **Choose inner join**: When possible, to reduce result size
|
||||
5. **Use categoricals**: Convert string annotations before concatenating
|
||||
6. **Consider on-disk**: For very large datasets, use `concat_on_disk`
|
||||
@@ -1,438 +0,0 @@
|
||||
# AnnData Workflows and Best Practices
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### 1. Single-Cell RNA-seq Analysis Workflow
|
||||
|
||||
#### Loading Data
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Load from 10X format
|
||||
adata = ad.read_mtx('matrix.mtx')
|
||||
adata.var_names = pd.read_csv('genes.tsv', sep='\t', header=None)[0]
|
||||
adata.obs_names = pd.read_csv('barcodes.tsv', sep='\t', header=None)[0]
|
||||
|
||||
# Or load from pre-processed h5ad
|
||||
adata = ad.read_h5ad('preprocessed_data.h5ad')
|
||||
```
|
||||
|
||||
#### Quality Control
|
||||
```python
|
||||
# Calculate QC metrics
|
||||
adata.obs['n_genes'] = (adata.X > 0).sum(axis=1)
|
||||
adata.obs['total_counts'] = adata.X.sum(axis=1)
|
||||
|
||||
# Filter cells
|
||||
adata = adata[adata.obs.n_genes > 200]
|
||||
adata = adata[adata.obs.total_counts < 10000]
|
||||
|
||||
# Filter genes
|
||||
min_cells = 3
|
||||
adata = adata[:, (adata.X > 0).sum(axis=0) >= min_cells]
|
||||
```
|
||||
|
||||
#### Normalization and Preprocessing
|
||||
```python
|
||||
# Store raw counts
|
||||
adata.layers['counts'] = adata.X.copy()
|
||||
|
||||
# Normalize
|
||||
adata.X = adata.X / adata.obs.total_counts.values[:, None] * 1e4
|
||||
|
||||
# Log transform
|
||||
adata.layers['log1p'] = np.log1p(adata.X)
|
||||
adata.X = adata.layers['log1p']
|
||||
|
||||
# Identify highly variable genes
|
||||
gene_variance = adata.X.var(axis=0)
|
||||
adata.var['highly_variable'] = gene_variance > np.percentile(gene_variance, 90)
|
||||
```
|
||||
|
||||
#### Dimensionality Reduction
|
||||
```python
|
||||
# PCA
|
||||
from sklearn.decomposition import PCA
|
||||
pca = PCA(n_components=50)
|
||||
adata.obsm['X_pca'] = pca.fit_transform(adata.X)
|
||||
|
||||
# Store PCA variance
|
||||
adata.uns['pca'] = {'variance_ratio': pca.explained_variance_ratio_}
|
||||
|
||||
# UMAP
|
||||
from umap import UMAP
|
||||
umap = UMAP(n_components=2)
|
||||
adata.obsm['X_umap'] = umap.fit_transform(adata.obsm['X_pca'])
|
||||
```
|
||||
|
||||
#### Clustering
|
||||
```python
|
||||
# Store cluster assignments
|
||||
adata.obs['clusters'] = pd.Categorical(['cluster_0', 'cluster_1', ...])
|
||||
|
||||
# Store cluster centroids
|
||||
centroids = np.array([...])
|
||||
adata.varm['cluster_centroids'] = centroids
|
||||
```
|
||||
|
||||
#### Save Results
|
||||
```python
|
||||
# Save complete analysis
|
||||
adata.write('analyzed_data.h5ad', compression='gzip')
|
||||
```
|
||||
|
||||
### 2. Batch Integration Workflow
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Load multiple batches
|
||||
batch1 = ad.read_h5ad('batch1.h5ad')
|
||||
batch2 = ad.read_h5ad('batch2.h5ad')
|
||||
batch3 = ad.read_h5ad('batch3.h5ad')
|
||||
|
||||
# Concatenate with batch labels
|
||||
adata = ad.concat(
|
||||
[batch1, batch2, batch3],
|
||||
axis=0,
|
||||
label='batch',
|
||||
keys=['batch1', 'batch2', 'batch3'],
|
||||
index_unique='-'
|
||||
)
|
||||
|
||||
# Batch effect correction would go here
|
||||
# (using external tools like Harmony, Scanorama, etc.)
|
||||
|
||||
# Store corrected embeddings
|
||||
adata.obsm['X_pca_corrected'] = corrected_pca
|
||||
adata.obsm['X_umap_corrected'] = corrected_umap
|
||||
```
|
||||
|
||||
### 3. Memory-Efficient Large Dataset Workflow
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Read in backed mode
|
||||
adata = ad.read_h5ad('large_dataset.h5ad', backed='r')
|
||||
|
||||
# Check backing status
|
||||
print(f"Is backed: {adata.isbacked}")
|
||||
print(f"File: {adata.filename}")
|
||||
|
||||
# Work with chunks
|
||||
for chunk in adata.chunk_X(chunk_size=1000):
|
||||
# Process chunk
|
||||
result = process_chunk(chunk)
|
||||
|
||||
# Close file when done
|
||||
adata.file.close()
|
||||
```
|
||||
|
||||
### 4. Multi-Dataset Comparison Workflow
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Load datasets
|
||||
datasets = {
|
||||
'study1': ad.read_h5ad('study1.h5ad'),
|
||||
'study2': ad.read_h5ad('study2.h5ad'),
|
||||
'study3': ad.read_h5ad('study3.h5ad')
|
||||
}
|
||||
|
||||
# Outer join to keep all genes
|
||||
combined = ad.concat(
|
||||
list(datasets.values()),
|
||||
axis=0,
|
||||
join='outer',
|
||||
label='study',
|
||||
keys=list(datasets.keys()),
|
||||
merge='first'
|
||||
)
|
||||
|
||||
# Handle missing data
|
||||
combined.X[np.isnan(combined.X)] = 0
|
||||
|
||||
# Add dataset-specific metadata
|
||||
combined.uns['datasets'] = {
|
||||
'study1': {'date': '2023-01', 'n_samples': datasets['study1'].n_obs},
|
||||
'study2': {'date': '2023-06', 'n_samples': datasets['study2'].n_obs},
|
||||
'study3': {'date': '2024-01', 'n_samples': datasets['study3'].n_obs}
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Memory Management
|
||||
|
||||
#### Use Sparse Matrices
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Convert to sparse if data is sparse
|
||||
if density < 0.3: # Less than 30% non-zero
|
||||
adata.X = csr_matrix(adata.X)
|
||||
```
|
||||
|
||||
#### Use Backed Mode for Large Files
|
||||
```python
|
||||
# Read with backing
|
||||
adata = ad.read_h5ad('large_file.h5ad', backed='r')
|
||||
|
||||
# Only load what you need
|
||||
subset = adata[:1000, :500].copy() # Now in memory
|
||||
```
|
||||
|
||||
#### Convert Strings to Categoricals
|
||||
```python
|
||||
# Efficient storage for repeated strings
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# Or manually
|
||||
adata.obs['cell_type'] = pd.Categorical(adata.obs['cell_type'])
|
||||
```
|
||||
|
||||
### Data Organization
|
||||
|
||||
#### Use Layers for Different Representations
|
||||
```python
|
||||
# Store multiple versions of the data
|
||||
adata.layers['counts'] = raw_counts
|
||||
adata.layers['normalized'] = normalized_data
|
||||
adata.layers['log1p'] = log_transformed_data
|
||||
adata.layers['scaled'] = scaled_data
|
||||
```
|
||||
|
||||
#### Use obsm/varm for Multi-Dimensional Annotations
|
||||
```python
|
||||
# Embeddings
|
||||
adata.obsm['X_pca'] = pca_coordinates
|
||||
adata.obsm['X_umap'] = umap_coordinates
|
||||
adata.obsm['X_tsne'] = tsne_coordinates
|
||||
|
||||
# Gene loadings
|
||||
adata.varm['PCs'] = principal_components
|
||||
```
|
||||
|
||||
#### Use uns for Analysis Metadata
|
||||
```python
|
||||
# Store parameters
|
||||
adata.uns['preprocessing'] = {
|
||||
'normalization': 'TPM',
|
||||
'min_genes': 200,
|
||||
'min_cells': 3,
|
||||
'date': '2024-01-15'
|
||||
}
|
||||
|
||||
# Store analysis results
|
||||
adata.uns['differential_expression'] = {
|
||||
'method': 't-test',
|
||||
'p_value_threshold': 0.05
|
||||
}
|
||||
```
|
||||
|
||||
### Subsetting and Views
|
||||
|
||||
#### Understand View vs Copy
|
||||
```python
|
||||
# Subsetting returns a view
|
||||
subset = adata[adata.obs.cell_type == 'B cell'] # View
|
||||
print(subset.is_view) # True
|
||||
|
||||
# Views are memory efficient but modifications affect original
|
||||
subset.obs['new_column'] = value # Modifies original adata
|
||||
|
||||
# Create independent copy when needed
|
||||
subset_copy = adata[adata.obs.cell_type == 'B cell'].copy()
|
||||
```
|
||||
|
||||
#### Chain Operations Efficiently
|
||||
```python
|
||||
# Bad - creates multiple intermediate views
|
||||
temp1 = adata[adata.obs.batch == 'batch1']
|
||||
temp2 = temp1[temp1.obs.n_genes > 200]
|
||||
result = temp2[:, temp2.var.highly_variable].copy()
|
||||
|
||||
# Good - chain operations
|
||||
result = adata[
|
||||
(adata.obs.batch == 'batch1') & (adata.obs.n_genes > 200),
|
||||
adata.var.highly_variable
|
||||
].copy()
|
||||
```
|
||||
|
||||
### File I/O
|
||||
|
||||
#### Use Compression
|
||||
```python
|
||||
# Save with compression
|
||||
adata.write('data.h5ad', compression='gzip')
|
||||
```
|
||||
|
||||
#### Choose the Right Format
|
||||
```python
|
||||
# H5AD for general use (good compression, fast)
|
||||
adata.write_h5ad('data.h5ad')
|
||||
|
||||
# Zarr for cloud storage and parallel access
|
||||
adata.write_zarr('data.zarr')
|
||||
|
||||
# Loom for compatibility with other tools
|
||||
adata.write_loom('data.loom')
|
||||
```
|
||||
|
||||
#### Close File Connections
|
||||
```python
|
||||
# Use context manager pattern
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r')
|
||||
try:
|
||||
# Work with data
|
||||
process(adata)
|
||||
finally:
|
||||
adata.file.close()
|
||||
```
|
||||
|
||||
### Concatenation
|
||||
|
||||
#### Choose Appropriate Join Strategy
|
||||
```python
|
||||
# Inner join - only common features (safe, may lose data)
|
||||
combined = ad.concat([adata1, adata2], join='inner')
|
||||
|
||||
# Outer join - all features (keeps all data, may introduce zeros)
|
||||
combined = ad.concat([adata1, adata2], join='outer')
|
||||
```
|
||||
|
||||
#### Track Data Sources
|
||||
```python
|
||||
# Add source labels
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='dataset',
|
||||
keys=['exp1', 'exp2', 'exp3']
|
||||
)
|
||||
|
||||
# Make indices unique
|
||||
combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
index_unique='-'
|
||||
)
|
||||
```
|
||||
|
||||
#### Handle Variable-Specific Metadata
|
||||
```python
|
||||
# Use merge strategy for var annotations
|
||||
combined = ad.concat(
|
||||
[adata1, adata2],
|
||||
merge='same', # Keep only identical annotations
|
||||
join='outer'
|
||||
)
|
||||
```
|
||||
|
||||
### Naming Conventions
|
||||
|
||||
#### Use Consistent Naming
|
||||
```python
|
||||
# Embeddings: X_<method>
|
||||
adata.obsm['X_pca']
|
||||
adata.obsm['X_umap']
|
||||
adata.obsm['X_tsne']
|
||||
|
||||
# Layers: descriptive names
|
||||
adata.layers['counts']
|
||||
adata.layers['log1p']
|
||||
adata.layers['scaled']
|
||||
|
||||
# Observations: snake_case
|
||||
adata.obs['cell_type']
|
||||
adata.obs['n_genes']
|
||||
adata.obs['total_counts']
|
||||
```
|
||||
|
||||
#### Make Indices Unique
|
||||
```python
|
||||
# Ensure unique names
|
||||
adata.obs_names_make_unique()
|
||||
adata.var_names_make_unique()
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
#### Validate Data Structure
|
||||
```python
|
||||
# Check dimensions
|
||||
assert adata.n_obs > 0, "No observations in data"
|
||||
assert adata.n_vars > 0, "No variables in data"
|
||||
|
||||
# Check for NaN values
|
||||
if np.isnan(adata.X).any():
|
||||
print("Warning: NaN values detected")
|
||||
|
||||
# Check for negative values in count data
|
||||
if (adata.X < 0).any():
|
||||
print("Warning: Negative values in count data")
|
||||
```
|
||||
|
||||
#### Handle Missing Data
|
||||
```python
|
||||
# Check for missing annotations
|
||||
if adata.obs['cell_type'].isna().any():
|
||||
print("Warning: Missing cell type annotations")
|
||||
# Fill or remove
|
||||
adata = adata[~adata.obs['cell_type'].isna()]
|
||||
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### 1. Forgetting to Copy Views
|
||||
```python
|
||||
# BAD - modifies original
|
||||
subset = adata[adata.obs.condition == 'treated']
|
||||
subset.X = transformed_data # Changes original adata!
|
||||
|
||||
# GOOD
|
||||
subset = adata[adata.obs.condition == 'treated'].copy()
|
||||
subset.X = transformed_data # Only changes subset
|
||||
```
|
||||
|
||||
### 2. Mixing Backed and In-Memory Operations
|
||||
```python
|
||||
# BAD - trying to modify backed data
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r')
|
||||
adata.X[0, 0] = 100 # Error: can't modify backed data
|
||||
|
||||
# GOOD - load to memory first
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r')
|
||||
adata = adata.to_memory()
|
||||
adata.X[0, 0] = 100 # Works
|
||||
```
|
||||
|
||||
### 3. Not Using Categoricals for Metadata
|
||||
```python
|
||||
# BAD - stores as strings (memory inefficient)
|
||||
adata.obs['cell_type'] = ['B cell', 'T cell', ...] * 1000
|
||||
|
||||
# GOOD - use categorical
|
||||
adata.obs['cell_type'] = pd.Categorical(['B cell', 'T cell', ...] * 1000)
|
||||
```
|
||||
|
||||
### 4. Incorrect Concatenation Axis
|
||||
```python
|
||||
# Concatenating observations (cells)
|
||||
combined = ad.concat([adata1, adata2], axis=0) # Correct
|
||||
|
||||
# Concatenating variables (genes) - rare
|
||||
combined = ad.concat([adata1, adata2], axis=1) # Less common
|
||||
```
|
||||
|
||||
### 5. Not Preserving Raw Data
|
||||
```python
|
||||
# BAD - loses original data
|
||||
adata.X = normalized_data
|
||||
|
||||
# GOOD - preserve original
|
||||
adata.layers['counts'] = adata.X.copy()
|
||||
adata.X = normalized_data
|
||||
```
|
||||
@@ -1,415 +0,0 @@
|
||||
---
|
||||
name: arboreto
|
||||
description: "Gene regulatory network inference with GRNBoost2/GENIE3 algorithms. Infer TF-target relationships from expression data, scalable with Dask, for scRNA-seq and GRN analysis."
|
||||
---
|
||||
|
||||
# Arboreto - Gene Regulatory Network Inference
|
||||
|
||||
## Overview
|
||||
|
||||
Arboreto is a Python library for inferring gene regulatory networks (GRNs) from gene expression data using machine learning algorithms. It enables scalable GRN inference from single machines to multi-node clusters using Dask for distributed computing. The skill provides comprehensive support for both GRNBoost2 (fast gradient boosting) and GENIE3 (Random Forest) algorithms.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Inferring regulatory relationships between genes from expression data
|
||||
- Analyzing single-cell or bulk RNA-seq data to identify transcription factor targets
|
||||
- Building the GRN inference component of a pySCENIC pipeline
|
||||
- Comparing GRNBoost2 and GENIE3 algorithm performance
|
||||
- Setting up distributed computing for large-scale genomic analyses
|
||||
- Troubleshooting arboreto installation or runtime issues
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Basic GRN Inference
|
||||
|
||||
For standard gene regulatory network inference tasks:
|
||||
|
||||
**Key considerations:**
|
||||
- Expression data format: Rows = observations (cells/samples), Columns = genes
|
||||
- If data has genes as rows, transpose it first: `expression_df.T`
|
||||
- Always include `seed` parameter for reproducible results
|
||||
- Transcription factor list is optional but recommended for focused analysis
|
||||
|
||||
**Typical workflow:**
|
||||
```python
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
# Load expression data (ensure correct orientation)
|
||||
expression_data = pd.read_csv('expression_data.tsv', sep='\t', index_col=0)
|
||||
|
||||
# Optional: Load TF names
|
||||
tf_names = load_tf_names('transcription_factors.txt')
|
||||
|
||||
# Run inference
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=42 # For reproducibility
|
||||
)
|
||||
|
||||
# Save results
|
||||
network.to_csv('network_output.tsv', sep='\t', index=False)
|
||||
```
|
||||
|
||||
**Output format:**
|
||||
- DataFrame with columns: `['TF', 'target', 'importance']`
|
||||
- Higher importance scores indicate stronger predicted regulatory relationships
|
||||
- Typically sorted by importance (descending)
|
||||
|
||||
**Multiprocessing requirement:**
|
||||
All arboreto code must include `if __name__ == '__main__':` protection due to Dask's multiprocessing requirements:
|
||||
|
||||
```python
|
||||
if __name__ == '__main__':
|
||||
# Arboreto code goes here
|
||||
network = grnboost2(expression_data=expr_data, seed=42)
|
||||
```
|
||||
|
||||
### 2. Algorithm Selection
|
||||
|
||||
**GRNBoost2 (Recommended for most cases):**
|
||||
- ~10-100x faster than GENIE3
|
||||
- Uses stochastic gradient boosting with early-stopping
|
||||
- Best for: Large datasets (>10k observations), time-sensitive analyses
|
||||
- Function: `arboreto.algo.grnboost2()`
|
||||
|
||||
**GENIE3:**
|
||||
- Uses Random Forest regression
|
||||
- More established, classical approach
|
||||
- Best for: Small datasets, methodological comparisons, reproducing published results
|
||||
- Function: `arboreto.algo.genie3()`
|
||||
|
||||
**When to compare both algorithms:**
|
||||
Use the provided `compare_algorithms.py` script when:
|
||||
- Validating results for critical analyses
|
||||
- Benchmarking performance on new datasets
|
||||
- Publishing research requiring methodological comparisons
|
||||
|
||||
### 3. Distributed Computing
|
||||
|
||||
**Local execution (default):**
|
||||
Arboreto automatically creates a local Dask client. No configuration needed:
|
||||
```python
|
||||
network = grnboost2(expression_data=expr_data)
|
||||
```
|
||||
|
||||
**Custom local cluster (recommended for better control):**
|
||||
```python
|
||||
from dask.distributed import Client, LocalCluster
|
||||
|
||||
# Configure cluster
|
||||
cluster = LocalCluster(
|
||||
n_workers=4,
|
||||
threads_per_worker=2,
|
||||
memory_limit='4GB',
|
||||
diagnostics_port=8787 # Dashboard at http://localhost:8787
|
||||
)
|
||||
client = Client(cluster)
|
||||
|
||||
# Run inference
|
||||
network = grnboost2(
|
||||
expression_data=expr_data,
|
||||
client_or_address=client
|
||||
)
|
||||
|
||||
# Clean up
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
**Distributed cluster (multi-node):**
|
||||
On scheduler node:
|
||||
```bash
|
||||
dask-scheduler --no-bokeh
|
||||
```
|
||||
|
||||
On worker nodes:
|
||||
```bash
|
||||
dask-worker scheduler-address:8786 --local-dir /tmp
|
||||
```
|
||||
|
||||
In Python:
|
||||
```python
|
||||
from dask.distributed import Client
|
||||
|
||||
client = Client('scheduler-address:8786')
|
||||
network = grnboost2(expression_data=expr_data, client_or_address=client)
|
||||
```
|
||||
|
||||
### 4. Data Preparation
|
||||
|
||||
**Common data format issues:**
|
||||
|
||||
1. **Transposed data** (genes as rows instead of columns):
|
||||
```python
|
||||
# If genes are rows, transpose
|
||||
expression_data = pd.read_csv('data.tsv', sep='\t', index_col=0).T
|
||||
```
|
||||
|
||||
2. **Missing gene names:**
|
||||
```python
|
||||
# Provide gene names if using numpy array
|
||||
network = grnboost2(
|
||||
expression_data=expr_array,
|
||||
gene_names=['Gene1', 'Gene2', 'Gene3', ...],
|
||||
seed=42
|
||||
)
|
||||
```
|
||||
|
||||
3. **Transcription factor specification:**
|
||||
```python
|
||||
# Option 1: Python list
|
||||
tf_names = ['Sox2', 'Oct4', 'Nanog', 'Klf4']
|
||||
|
||||
# Option 2: Load from file (one TF per line)
|
||||
from arboreto.utils import load_tf_names
|
||||
tf_names = load_tf_names('tf_names.txt')
|
||||
```
|
||||
|
||||
### 5. Reproducibility
|
||||
|
||||
Always specify a seed for consistent results:
|
||||
```python
|
||||
network = grnboost2(expression_data=expr_data, seed=42)
|
||||
```
|
||||
|
||||
Without a seed, results will vary between runs due to algorithm randomness.
|
||||
|
||||
### 6. Result Interpretation
|
||||
|
||||
**Understanding the output:**
|
||||
- `TF`: Transcription factor (regulator) gene
|
||||
- `target`: Target gene being regulated
|
||||
- `importance`: Strength of predicted regulatory relationship
|
||||
|
||||
**Typical post-processing:**
|
||||
```python
|
||||
# Filter by importance threshold
|
||||
high_confidence = network[network['importance'] > 10]
|
||||
|
||||
# Get top N predictions
|
||||
top_predictions = network.head(1000)
|
||||
|
||||
# Find all targets of a specific TF
|
||||
sox2_targets = network[network['TF'] == 'Sox2']
|
||||
|
||||
# Count regulations per TF
|
||||
tf_counts = network['TF'].value_counts()
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
**Recommended (via conda):**
|
||||
```bash
|
||||
conda install -c bioconda arboreto
|
||||
```
|
||||
|
||||
**Via pip:**
|
||||
```bash
|
||||
pip install arboreto
|
||||
```
|
||||
|
||||
**From source:**
|
||||
```bash
|
||||
git clone https://github.com/tmoerman/arboreto.git
|
||||
cd arboreto
|
||||
pip install .
|
||||
```
|
||||
|
||||
**Dependencies:**
|
||||
- pandas
|
||||
- numpy
|
||||
- scikit-learn
|
||||
- scipy
|
||||
- dask
|
||||
- distributed
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Bokeh error when launching Dask scheduler
|
||||
|
||||
**Error:** `TypeError: got an unexpected keyword argument 'host'`
|
||||
|
||||
**Solutions:**
|
||||
- Use `dask-scheduler --no-bokeh` to disable Bokeh
|
||||
- Upgrade to Dask distributed >= 0.20.0
|
||||
|
||||
### Issue: Workers not connecting to scheduler
|
||||
|
||||
**Symptoms:** Worker processes start but fail to establish connections
|
||||
|
||||
**Solutions:**
|
||||
- Remove `dask-worker-space` directory before restarting workers
|
||||
- Specify adequate `local_dir` when creating cluster:
|
||||
```python
|
||||
cluster = LocalCluster(
|
||||
worker_kwargs={'local_dir': '/tmp'}
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Memory errors with large datasets
|
||||
|
||||
**Solutions:**
|
||||
- Increase worker memory limits: `memory_limit='8GB'`
|
||||
- Distribute across more nodes
|
||||
- Reduce dataset size through preprocessing (e.g., feature selection)
|
||||
- Ensure expression matrix fits in available RAM
|
||||
|
||||
### Issue: Inconsistent results across runs
|
||||
|
||||
**Solution:** Always specify a `seed` parameter:
|
||||
```python
|
||||
network = grnboost2(expression_data=expr_data, seed=42)
|
||||
```
|
||||
|
||||
### Issue: Import errors or missing dependencies
|
||||
|
||||
**Solution:** Use conda installation to handle numerical library dependencies:
|
||||
```bash
|
||||
conda create --name arboreto-env
|
||||
conda activate arboreto-env
|
||||
conda install -c bioconda arboreto
|
||||
```
|
||||
|
||||
## Provided Scripts
|
||||
|
||||
This skill includes ready-to-use scripts for common workflows:
|
||||
|
||||
### scripts/basic_grn_inference.py
|
||||
|
||||
Command-line tool for standard GRN inference workflow.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python scripts/basic_grn_inference.py expression_data.tsv \
|
||||
-t tf_names.txt \
|
||||
-o network.tsv \
|
||||
-s 42 \
|
||||
--transpose # if genes are rows
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Automatic data loading and validation
|
||||
- Optional TF list specification
|
||||
- Configurable output format
|
||||
- Data transposition support
|
||||
- Summary statistics
|
||||
|
||||
### scripts/distributed_inference.py
|
||||
|
||||
GRN inference with custom Dask cluster configuration.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python scripts/distributed_inference.py expression_data.tsv \
|
||||
-t tf_names.txt \
|
||||
-w 8 \
|
||||
-m 4GB \
|
||||
--threads 2 \
|
||||
--dashboard-port 8787
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Configurable worker count and memory limits
|
||||
- Dask dashboard integration
|
||||
- Thread configuration
|
||||
- Resource monitoring
|
||||
|
||||
### scripts/compare_algorithms.py
|
||||
|
||||
Compare GRNBoost2 and GENIE3 side-by-side.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python scripts/compare_algorithms.py expression_data.tsv \
|
||||
-t tf_names.txt \
|
||||
--top-n 100
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Runtime comparison
|
||||
- Network statistics
|
||||
- Prediction overlap analysis
|
||||
- Top prediction comparison
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
Detailed API documentation is available in [references/api_reference.md](references/api_reference.md), including:
|
||||
- Complete parameter descriptions for all functions
|
||||
- Data format specifications
|
||||
- Distributed computing configuration
|
||||
- Performance optimization tips
|
||||
- Integration with pySCENIC
|
||||
- Comprehensive examples
|
||||
|
||||
Load this reference when:
|
||||
- Working with advanced Dask configurations
|
||||
- Troubleshooting complex deployment scenarios
|
||||
- Understanding algorithm internals
|
||||
- Optimizing performance for specific use cases
|
||||
|
||||
## Integration with pySCENIC
|
||||
|
||||
Arboreto is the first step in the pySCENIC single-cell analysis pipeline:
|
||||
|
||||
1. **GRN Inference (arboreto)** ← This skill
|
||||
- Input: Expression matrix
|
||||
- Output: Regulatory network
|
||||
|
||||
2. **Regulon Prediction (pySCENIC)**
|
||||
- Input: Network from arboreto
|
||||
- Output: Refined regulons
|
||||
|
||||
3. **Cell Type Identification (pySCENIC)**
|
||||
- Input: Regulons
|
||||
- Output: Cell type scores
|
||||
|
||||
When working with pySCENIC, use arboreto to generate the initial network, then pass results to the pySCENIC pipeline.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always use seed parameter** for reproducible research
|
||||
2. **Validate data orientation** (rows = observations, columns = genes)
|
||||
3. **Specify TF list** when known to focus inference and improve speed
|
||||
4. **Monitor with Dask dashboard** for distributed computing
|
||||
5. **Save intermediate results** to avoid re-running long computations
|
||||
6. **Filter results** by importance threshold for downstream analysis
|
||||
7. **Use GRNBoost2 by default** unless specifically requiring GENIE3
|
||||
8. **Include multiprocessing guard** (`if __name__ == '__main__':`) in all scripts
|
||||
|
||||
## Quick Reference
|
||||
|
||||
**Basic inference:**
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
network = grnboost2(expression_data=expr_df, seed=42)
|
||||
```
|
||||
|
||||
**With TF specification:**
|
||||
```python
|
||||
network = grnboost2(expression_data=expr_df, tf_names=tf_list, seed=42)
|
||||
```
|
||||
|
||||
**With custom Dask client:**
|
||||
```python
|
||||
from dask.distributed import Client, LocalCluster
|
||||
cluster = LocalCluster(n_workers=4)
|
||||
client = Client(cluster)
|
||||
network = grnboost2(expression_data=expr_df, client_or_address=client, seed=42)
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
**Load TF names:**
|
||||
```python
|
||||
from arboreto.utils import load_tf_names
|
||||
tf_names = load_tf_names('transcription_factors.txt')
|
||||
```
|
||||
|
||||
**Transpose data:**
|
||||
```python
|
||||
expression_df = pd.read_csv('data.tsv', sep='\t', index_col=0).T
|
||||
```
|
||||
@@ -1,271 +0,0 @@
|
||||
# Arboreto API Reference
|
||||
|
||||
This document provides comprehensive API documentation for the arboreto package, a Python library for gene regulatory network (GRN) inference.
|
||||
|
||||
## Overview
|
||||
|
||||
Arboreto enables inference of gene regulatory networks from expression data using machine learning algorithms. It supports distributed computing via Dask for scalability from single machines to multi-node clusters.
|
||||
|
||||
**Current Version:** 0.1.5
|
||||
**GitHub:** https://github.com/tmoerman/arboreto
|
||||
**License:** BSD 3-Clause
|
||||
|
||||
## Core Algorithms
|
||||
|
||||
### GRNBoost2
|
||||
|
||||
The flagship algorithm for fast gene regulatory network inference using stochastic gradient boosting.
|
||||
|
||||
**Function:** `arboreto.algo.grnboost2()`
|
||||
|
||||
**Parameters:**
|
||||
- `expression_data` (pandas.DataFrame or numpy.ndarray): Expression matrix where rows are observations (cells/samples) and columns are genes. Required.
|
||||
- `gene_names` (list, optional): List of gene names matching column order. If None, uses DataFrame column names.
|
||||
- `tf_names` (list, optional): List of transcription factor names to consider as regulators. If None, all genes are considered potential regulators.
|
||||
- `seed` (int, optional): Random seed for reproducibility. Recommended when consistent results are needed across runs.
|
||||
- `client_or_address` (dask.distributed.Client or str, optional): Custom Dask client or scheduler address for distributed computing. If None, creates a default local client.
|
||||
- `verbose` (bool, optional): Enable verbose output for debugging.
|
||||
|
||||
**Returns:**
|
||||
- pandas.DataFrame with columns `['TF', 'target', 'importance']` representing inferred regulatory links. Each row represents a regulatory relationship with an importance score.
|
||||
|
||||
**Algorithm Details:**
|
||||
- Uses stochastic gradient boosting with early-stopping regularization
|
||||
- Much faster than GENIE3, especially for large datasets (tens of thousands of observations)
|
||||
- Extracts important features from trained regression models to identify regulatory relationships
|
||||
- Recommended as the default choice for most use cases
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
import pandas as pd
|
||||
|
||||
# Load expression data
|
||||
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
|
||||
tf_list = ['TF1', 'TF2', 'TF3'] # Optional: specify TFs
|
||||
|
||||
# Run inference
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_list,
|
||||
seed=42 # For reproducibility
|
||||
)
|
||||
|
||||
# Save results
|
||||
network.to_csv('output_network.tsv', sep='\t', index=False)
|
||||
```
|
||||
|
||||
### GENIE3
|
||||
|
||||
Classical gene regulatory network inference using Random Forest regression.
|
||||
|
||||
**Function:** `arboreto.algo.genie3()`
|
||||
|
||||
**Parameters:**
|
||||
Same as GRNBoost2 (see above).
|
||||
|
||||
**Returns:**
|
||||
Same format as GRNBoost2 (see above).
|
||||
|
||||
**Algorithm Details:**
|
||||
- Uses Random Forest or ExtraTrees regression models
|
||||
- Blueprint for multiple regression GRN inference strategy
|
||||
- More computationally expensive than GRNBoost2
|
||||
- Better suited for smaller datasets or when maximum accuracy is needed
|
||||
|
||||
**When to Use GENIE3 vs GRNBoost2:**
|
||||
- **Use GRNBoost2:** For large datasets, faster results, or when computational resources are limited
|
||||
- **Use GENIE3:** For smaller datasets, when following established protocols, or for comparison with published results
|
||||
|
||||
## Module Structure
|
||||
|
||||
### arboreto.algo
|
||||
|
||||
Primary module for typical users. Contains high-level inference functions.
|
||||
|
||||
**Main Functions:**
|
||||
- `grnboost2()` - Fast GRN inference using gradient boosting
|
||||
- `genie3()` - Classical GRN inference using Random Forest
|
||||
|
||||
### arboreto.core
|
||||
|
||||
Advanced module for power users. Contains low-level framework components for custom implementations.
|
||||
|
||||
**Use cases:**
|
||||
- Custom inference pipelines
|
||||
- Algorithm modifications
|
||||
- Performance tuning
|
||||
|
||||
### arboreto.utils
|
||||
|
||||
Utility functions for common data processing tasks.
|
||||
|
||||
**Key Functions:**
|
||||
- `load_tf_names(filename)` - Load transcription factor names from file
|
||||
- Reads a text file with one TF name per line
|
||||
- Returns a list of TF names
|
||||
- Example: `tf_names = load_tf_names('transcription_factors.txt')`
|
||||
|
||||
## Data Format Requirements
|
||||
|
||||
### Input Format
|
||||
|
||||
**Expression Matrix:**
|
||||
- **Format:** pandas DataFrame or numpy ndarray
|
||||
- **Orientation:** Rows = observations (cells/samples), Columns = genes
|
||||
- **Convention:** Follows scikit-learn format
|
||||
- **Gene Names:** Column names (DataFrame) or separate `gene_names` parameter
|
||||
- **Data Type:** Numeric (float or int)
|
||||
|
||||
**Common Mistake:** If data is transposed (genes as rows), use pandas to transpose:
|
||||
```python
|
||||
expression_df = pd.read_csv('data.tsv', sep='\t', index_col=0).T
|
||||
```
|
||||
|
||||
**Transcription Factor List:**
|
||||
- **Format:** Python list of strings or text file (one TF per line)
|
||||
- **Optional:** If not provided, all genes are considered potential regulators
|
||||
- **Example:** `['Sox2', 'Oct4', 'Nanog']`
|
||||
|
||||
### Output Format
|
||||
|
||||
**Network DataFrame:**
|
||||
- **Columns:**
|
||||
- `TF` (str): Transcription factor (regulator) gene name
|
||||
- `target` (str): Target gene name
|
||||
- `importance` (float): Importance score of the regulatory relationship
|
||||
- **Interpretation:** Higher importance scores indicate stronger predicted regulatory relationships
|
||||
- **Sorting:** Typically sorted by importance (descending) for prioritization
|
||||
|
||||
**Example Output:**
|
||||
```
|
||||
TF target importance
|
||||
Sox2 Gene1 15.234
|
||||
Oct4 Gene1 12.456
|
||||
Sox2 Gene2 8.901
|
||||
```
|
||||
|
||||
## Distributed Computing with Dask
|
||||
|
||||
### Local Execution (Default)
|
||||
|
||||
Arboreto automatically creates a local Dask client if none is provided:
|
||||
|
||||
```python
|
||||
network = grnboost2(expression_data=expr_matrix, tf_names=tf_list)
|
||||
```
|
||||
|
||||
### Custom Local Cluster
|
||||
|
||||
For better control over resources or multiple inferences:
|
||||
|
||||
```python
|
||||
from dask.distributed import Client, LocalCluster
|
||||
|
||||
# Configure cluster
|
||||
cluster = LocalCluster(
|
||||
n_workers=4,
|
||||
threads_per_worker=2,
|
||||
memory_limit='4GB'
|
||||
)
|
||||
client = Client(cluster)
|
||||
|
||||
# Run inference
|
||||
network = grnboost2(
|
||||
expression_data=expr_matrix,
|
||||
tf_names=tf_list,
|
||||
client_or_address=client
|
||||
)
|
||||
|
||||
# Clean up
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
### Distributed Cluster
|
||||
|
||||
For multi-node computation:
|
||||
|
||||
**On scheduler node:**
|
||||
```bash
|
||||
dask-scheduler --no-bokeh # Use --no-bokeh to avoid Bokeh errors
|
||||
```
|
||||
|
||||
**On worker nodes:**
|
||||
```bash
|
||||
dask-worker scheduler-address:8786 --local-dir /tmp
|
||||
```
|
||||
|
||||
**In Python script:**
|
||||
```python
|
||||
from dask.distributed import Client
|
||||
|
||||
client = Client('scheduler-address:8786')
|
||||
network = grnboost2(
|
||||
expression_data=expr_matrix,
|
||||
tf_names=tf_list,
|
||||
client_or_address=client
|
||||
)
|
||||
```
|
||||
|
||||
### Dask Dashboard
|
||||
|
||||
Monitor computation progress via the Dask dashboard:
|
||||
|
||||
```python
|
||||
from dask.distributed import Client, LocalCluster
|
||||
|
||||
cluster = LocalCluster(diagnostics_port=8787)
|
||||
client = Client(cluster)
|
||||
|
||||
# Dashboard available at: http://localhost:8787
|
||||
```
|
||||
|
||||
## Reproducibility
|
||||
|
||||
To ensure reproducible results across runs:
|
||||
|
||||
```python
|
||||
network = grnboost2(
|
||||
expression_data=expr_matrix,
|
||||
tf_names=tf_list,
|
||||
seed=42 # Fixed seed ensures identical results
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** Without a seed parameter, results may vary slightly between runs due to randomness in the algorithms.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Management
|
||||
|
||||
- Expression matrices should fit in memory (RAM)
|
||||
- For very large datasets, consider:
|
||||
- Using a machine with more RAM
|
||||
- Distributing across multiple nodes
|
||||
- Preprocessing to reduce dimensionality
|
||||
|
||||
### Worker Configuration
|
||||
|
||||
- **Local execution:** Number of workers = number of CPU cores (default)
|
||||
- **Custom cluster:** Balance workers and threads based on available resources
|
||||
- **Distributed execution:** Ensure adequate `local_dir` space on worker nodes
|
||||
|
||||
### Algorithm Choice
|
||||
|
||||
- **GRNBoost2:** ~10-100x faster than GENIE3 for large datasets
|
||||
- **GENIE3:** More established but slower, better for small datasets (<10k observations)
|
||||
|
||||
## Integration with pySCENIC
|
||||
|
||||
Arboreto is a core component of the pySCENIC pipeline for single-cell RNA sequencing analysis:
|
||||
|
||||
1. **GRN Inference (Arboreto):** Infer regulatory networks using GRNBoost2
|
||||
2. **Regulon Prediction:** Prune network and identify regulons
|
||||
3. **Cell Type Identification:** Score regulons across cells
|
||||
|
||||
For pySCENIC workflows, arboreto is typically used in the first step to generate the initial regulatory network.
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
See the main SKILL.md for troubleshooting guidance.
|
||||
@@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Basic GRN inference script using arboreto GRNBoost2.
|
||||
|
||||
This script demonstrates the standard workflow for gene regulatory network inference:
|
||||
1. Load expression data
|
||||
2. Optionally load transcription factor names
|
||||
3. Run GRNBoost2 inference
|
||||
4. Save results
|
||||
|
||||
Usage:
|
||||
python basic_grn_inference.py <expression_file> [options]
|
||||
|
||||
Example:
|
||||
python basic_grn_inference.py expression_data.tsv -t tf_names.txt -o network.tsv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Infer gene regulatory network using GRNBoost2'
|
||||
)
|
||||
parser.add_argument(
|
||||
'expression_file',
|
||||
help='Path to expression data file (TSV/CSV format)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-t', '--tf-file',
|
||||
help='Path to file containing transcription factor names (one per line)',
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
'-o', '--output',
|
||||
help='Output file path for network results',
|
||||
default='network_output.tsv'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-s', '--seed',
|
||||
type=int,
|
||||
help='Random seed for reproducibility',
|
||||
default=42
|
||||
)
|
||||
parser.add_argument(
|
||||
'--sep',
|
||||
help='Separator for input file (default: tab)',
|
||||
default='\t'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--transpose',
|
||||
action='store_true',
|
||||
help='Transpose the expression matrix (use if genes are rows)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load expression data
|
||||
print(f"Loading expression data from {args.expression_file}...")
|
||||
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
|
||||
|
||||
# Transpose if needed
|
||||
if args.transpose:
|
||||
print("Transposing expression matrix...")
|
||||
expression_data = expression_data.T
|
||||
|
||||
print(f"Expression data shape: {expression_data.shape}")
|
||||
print(f" Observations (rows): {expression_data.shape[0]}")
|
||||
print(f" Genes (columns): {expression_data.shape[1]}")
|
||||
|
||||
# Load TF names if provided
|
||||
tf_names = None
|
||||
if args.tf_file:
|
||||
print(f"Loading transcription factor names from {args.tf_file}...")
|
||||
tf_names = load_tf_names(args.tf_file)
|
||||
print(f" Found {len(tf_names)} transcription factors")
|
||||
else:
|
||||
print("No TF file provided. Using all genes as potential regulators.")
|
||||
|
||||
# Run GRNBoost2
|
||||
print("\nRunning GRNBoost2 inference...")
|
||||
print(" (This may take a while depending on dataset size)")
|
||||
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
print(f"\nInference complete!")
|
||||
print(f" Total regulatory links inferred: {len(network)}")
|
||||
print(f" Unique TFs: {network['TF'].nunique()}")
|
||||
print(f" Unique targets: {network['target'].nunique()}")
|
||||
|
||||
# Save results
|
||||
print(f"\nSaving results to {args.output}...")
|
||||
network.to_csv(args.output, sep='\t', index=False)
|
||||
|
||||
# Display top 10 predictions
|
||||
print("\nTop 10 predicted regulatory relationships:")
|
||||
print(network.head(10).to_string(index=False))
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compare GRNBoost2 and GENIE3 algorithms on the same dataset.
|
||||
|
||||
This script runs both algorithms on the same expression data and compares:
|
||||
- Runtime
|
||||
- Number of predicted links
|
||||
- Top predicted relationships
|
||||
- Overlap between predictions
|
||||
|
||||
Usage:
|
||||
python compare_algorithms.py <expression_file> [options]
|
||||
|
||||
Example:
|
||||
python compare_algorithms.py expression_data.tsv -t tf_names.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2, genie3
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
|
||||
def compare_networks(network1, network2, name1, name2, top_n=100):
|
||||
"""Compare two inferred networks."""
|
||||
print(f"\n{'='*60}")
|
||||
print("Network Comparison")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Basic statistics
|
||||
print(f"\n{name1} Statistics:")
|
||||
print(f" Total links: {len(network1)}")
|
||||
print(f" Unique TFs: {network1['TF'].nunique()}")
|
||||
print(f" Unique targets: {network1['target'].nunique()}")
|
||||
print(f" Importance range: [{network1['importance'].min():.3f}, {network1['importance'].max():.3f}]")
|
||||
|
||||
print(f"\n{name2} Statistics:")
|
||||
print(f" Total links: {len(network2)}")
|
||||
print(f" Unique TFs: {network2['TF'].nunique()}")
|
||||
print(f" Unique targets: {network2['target'].nunique()}")
|
||||
print(f" Importance range: [{network2['importance'].min():.3f}, {network2['importance'].max():.3f}]")
|
||||
|
||||
# Compare top predictions
|
||||
print(f"\nTop {top_n} Predictions Overlap:")
|
||||
|
||||
# Create edge sets for top N predictions
|
||||
top_edges1 = set(
|
||||
zip(network1.head(top_n)['TF'], network1.head(top_n)['target'])
|
||||
)
|
||||
top_edges2 = set(
|
||||
zip(network2.head(top_n)['TF'], network2.head(top_n)['target'])
|
||||
)
|
||||
|
||||
# Calculate overlap
|
||||
overlap = top_edges1 & top_edges2
|
||||
only_net1 = top_edges1 - top_edges2
|
||||
only_net2 = top_edges2 - top_edges1
|
||||
|
||||
overlap_pct = (len(overlap) / top_n) * 100
|
||||
|
||||
print(f" Shared edges: {len(overlap)} ({overlap_pct:.1f}%)")
|
||||
print(f" Only in {name1}: {len(only_net1)}")
|
||||
print(f" Only in {name2}: {len(only_net2)}")
|
||||
|
||||
# Show some example overlapping edges
|
||||
if overlap:
|
||||
print(f"\nExample overlapping predictions:")
|
||||
for i, (tf, target) in enumerate(list(overlap)[:5], 1):
|
||||
print(f" {i}. {tf} -> {target}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Compare GRNBoost2 and GENIE3 algorithms'
|
||||
)
|
||||
parser.add_argument(
|
||||
'expression_file',
|
||||
help='Path to expression data file (TSV/CSV format)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-t', '--tf-file',
|
||||
help='Path to file containing transcription factor names (one per line)',
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
'--grnboost2-output',
|
||||
help='Output file path for GRNBoost2 results',
|
||||
default='grnboost2_network.tsv'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--genie3-output',
|
||||
help='Output file path for GENIE3 results',
|
||||
default='genie3_network.tsv'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-s', '--seed',
|
||||
type=int,
|
||||
help='Random seed for reproducibility',
|
||||
default=42
|
||||
)
|
||||
parser.add_argument(
|
||||
'--sep',
|
||||
help='Separator for input file (default: tab)',
|
||||
default='\t'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--transpose',
|
||||
action='store_true',
|
||||
help='Transpose the expression matrix (use if genes are rows)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--top-n',
|
||||
type=int,
|
||||
help='Number of top predictions to compare (default: 100)',
|
||||
default=100
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load expression data
|
||||
print(f"Loading expression data from {args.expression_file}...")
|
||||
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
|
||||
|
||||
# Transpose if needed
|
||||
if args.transpose:
|
||||
print("Transposing expression matrix...")
|
||||
expression_data = expression_data.T
|
||||
|
||||
print(f"Expression data shape: {expression_data.shape}")
|
||||
print(f" Observations (rows): {expression_data.shape[0]}")
|
||||
print(f" Genes (columns): {expression_data.shape[1]}")
|
||||
|
||||
# Load TF names if provided
|
||||
tf_names = None
|
||||
if args.tf_file:
|
||||
print(f"Loading transcription factor names from {args.tf_file}...")
|
||||
tf_names = load_tf_names(args.tf_file)
|
||||
print(f" Found {len(tf_names)} transcription factors")
|
||||
else:
|
||||
print("No TF file provided. Using all genes as potential regulators.")
|
||||
|
||||
# Run GRNBoost2
|
||||
print("\n" + "="*60)
|
||||
print("Running GRNBoost2...")
|
||||
print("="*60)
|
||||
start_time = time.time()
|
||||
|
||||
grnboost2_network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
grnboost2_time = time.time() - start_time
|
||||
print(f"GRNBoost2 completed in {grnboost2_time:.2f} seconds")
|
||||
|
||||
# Save GRNBoost2 results
|
||||
grnboost2_network.to_csv(args.grnboost2_output, sep='\t', index=False)
|
||||
print(f"Results saved to {args.grnboost2_output}")
|
||||
|
||||
# Run GENIE3
|
||||
print("\n" + "="*60)
|
||||
print("Running GENIE3...")
|
||||
print("="*60)
|
||||
start_time = time.time()
|
||||
|
||||
genie3_network = genie3(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
genie3_time = time.time() - start_time
|
||||
print(f"GENIE3 completed in {genie3_time:.2f} seconds")
|
||||
|
||||
# Save GENIE3 results
|
||||
genie3_network.to_csv(args.genie3_output, sep='\t', index=False)
|
||||
print(f"Results saved to {args.genie3_output}")
|
||||
|
||||
# Compare runtimes
|
||||
print("\n" + "="*60)
|
||||
print("Runtime Comparison")
|
||||
print("="*60)
|
||||
print(f"GRNBoost2: {grnboost2_time:.2f} seconds")
|
||||
print(f"GENIE3: {genie3_time:.2f} seconds")
|
||||
speedup = genie3_time / grnboost2_time
|
||||
print(f"Speedup: {speedup:.2f}x (GRNBoost2 is {speedup:.2f}x faster)")
|
||||
|
||||
# Compare networks
|
||||
compare_networks(
|
||||
grnboost2_network,
|
||||
genie3_network,
|
||||
"GRNBoost2",
|
||||
"GENIE3",
|
||||
top_n=args.top_n
|
||||
)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("Comparison complete!")
|
||||
print("="*60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Distributed GRN inference script using arboreto with custom Dask configuration.
|
||||
|
||||
This script demonstrates how to use arboreto with a custom Dask LocalCluster
|
||||
for better control over computational resources.
|
||||
|
||||
Usage:
|
||||
python distributed_inference.py <expression_file> [options]
|
||||
|
||||
Example:
|
||||
python distributed_inference.py expression_data.tsv -t tf_names.txt -w 8 -m 4GB
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from dask.distributed import Client, LocalCluster
|
||||
from arboreto.algo import grnboost2
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Distributed GRN inference using GRNBoost2 with custom Dask cluster'
|
||||
)
|
||||
parser.add_argument(
|
||||
'expression_file',
|
||||
help='Path to expression data file (TSV/CSV format)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-t', '--tf-file',
|
||||
help='Path to file containing transcription factor names (one per line)',
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
'-o', '--output',
|
||||
help='Output file path for network results',
|
||||
default='network_output.tsv'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-s', '--seed',
|
||||
type=int,
|
||||
help='Random seed for reproducibility',
|
||||
default=42
|
||||
)
|
||||
parser.add_argument(
|
||||
'-w', '--workers',
|
||||
type=int,
|
||||
help='Number of Dask workers',
|
||||
default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
'-m', '--memory-limit',
|
||||
help='Memory limit per worker (e.g., "4GB", "2000MB")',
|
||||
default='4GB'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--threads',
|
||||
type=int,
|
||||
help='Threads per worker',
|
||||
default=2
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dashboard-port',
|
||||
type=int,
|
||||
help='Port for Dask dashboard (default: 8787)',
|
||||
default=8787
|
||||
)
|
||||
parser.add_argument(
|
||||
'--sep',
|
||||
help='Separator for input file (default: tab)',
|
||||
default='\t'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--transpose',
|
||||
action='store_true',
|
||||
help='Transpose the expression matrix (use if genes are rows)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load expression data
|
||||
print(f"Loading expression data from {args.expression_file}...")
|
||||
expression_data = pd.read_csv(args.expression_file, sep=args.sep, index_col=0)
|
||||
|
||||
# Transpose if needed
|
||||
if args.transpose:
|
||||
print("Transposing expression matrix...")
|
||||
expression_data = expression_data.T
|
||||
|
||||
print(f"Expression data shape: {expression_data.shape}")
|
||||
print(f" Observations (rows): {expression_data.shape[0]}")
|
||||
print(f" Genes (columns): {expression_data.shape[1]}")
|
||||
|
||||
# Load TF names if provided
|
||||
tf_names = None
|
||||
if args.tf_file:
|
||||
print(f"Loading transcription factor names from {args.tf_file}...")
|
||||
tf_names = load_tf_names(args.tf_file)
|
||||
print(f" Found {len(tf_names)} transcription factors")
|
||||
else:
|
||||
print("No TF file provided. Using all genes as potential regulators.")
|
||||
|
||||
# Set up Dask cluster
|
||||
print(f"\nSetting up Dask LocalCluster...")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Threads per worker: {args.threads}")
|
||||
print(f" Memory limit per worker: {args.memory_limit}")
|
||||
print(f" Dashboard: http://localhost:{args.dashboard_port}")
|
||||
|
||||
cluster = LocalCluster(
|
||||
n_workers=args.workers,
|
||||
threads_per_worker=args.threads,
|
||||
memory_limit=args.memory_limit,
|
||||
diagnostics_port=args.dashboard_port
|
||||
)
|
||||
client = Client(cluster)
|
||||
|
||||
print(f"\nDask cluster ready!")
|
||||
print(f" Dashboard available at: {client.dashboard_link}")
|
||||
|
||||
# Run GRNBoost2
|
||||
print("\nRunning GRNBoost2 inference with distributed computation...")
|
||||
print(" (Monitor progress via the Dask dashboard)")
|
||||
|
||||
try:
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=args.seed,
|
||||
client_or_address=client
|
||||
)
|
||||
|
||||
print(f"\nInference complete!")
|
||||
print(f" Total regulatory links inferred: {len(network)}")
|
||||
print(f" Unique TFs: {network['TF'].nunique()}")
|
||||
print(f" Unique targets: {network['target'].nunique()}")
|
||||
|
||||
# Save results
|
||||
print(f"\nSaving results to {args.output}...")
|
||||
network.to_csv(args.output, sep='\t', index=False)
|
||||
|
||||
# Display top 10 predictions
|
||||
print("\nTop 10 predicted regulatory relationships:")
|
||||
print(network.head(10).to_string(index=False))
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
finally:
|
||||
# Clean up Dask resources
|
||||
print("\nClosing Dask cluster...")
|
||||
client.close()
|
||||
cluster.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,790 +0,0 @@
|
||||
---
|
||||
name: astropy
|
||||
description: "Astronomy toolkit. FITS I/O, celestial coordinate transforms, cosmology calculations, time systems, WCS, units, astronomical tables, for astronomical data analysis and imaging."
|
||||
---
|
||||
|
||||
# Astropy
|
||||
|
||||
## Overview
|
||||
|
||||
Astropy is the community standard Python library for astronomy, providing core functionality for astronomical data analysis and computation. This skill provides comprehensive guidance and tools for working with astropy's extensive capabilities across coordinate systems, file I/O, units and quantities, time systems, cosmology, modeling, and more.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Working with FITS files (reading, writing, inspecting, modifying)
|
||||
- Performing coordinate transformations between astronomical reference frames
|
||||
- Calculating cosmological distances, ages, or other quantities
|
||||
- Handling astronomical time systems and conversions
|
||||
- Working with physical units and dimensional analysis
|
||||
- Processing astronomical data tables with specialized column types
|
||||
- Fitting models to astronomical data
|
||||
- Converting between pixel and world coordinates (WCS)
|
||||
- Performing robust statistical analysis on astronomical data
|
||||
- Visualizing astronomical images with proper scaling and stretching
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. FITS File Operations
|
||||
|
||||
FITS (Flexible Image Transport System) is the standard file format in astronomy. Astropy provides comprehensive FITS support.
|
||||
|
||||
**Quick FITS Inspection**:
|
||||
Use the included `scripts/fits_info.py` script for rapid file inspection:
|
||||
```bash
|
||||
python scripts/fits_info.py observation.fits
|
||||
python scripts/fits_info.py observation.fits --detailed
|
||||
python scripts/fits_info.py observation.fits --ext 1
|
||||
```
|
||||
|
||||
**Common FITS workflows**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
|
||||
# Read FITS file
|
||||
with fits.open('image.fits') as hdul:
|
||||
hdul.info() # Display structure
|
||||
data = hdul[0].data
|
||||
header = hdul[0].header
|
||||
|
||||
# Write FITS file
|
||||
fits.writeto('output.fits', data, header, overwrite=True)
|
||||
|
||||
# Quick access (less efficient for multiple operations)
|
||||
data = fits.getdata('image.fits', ext=0)
|
||||
header = fits.getheader('image.fits', ext=0)
|
||||
|
||||
# Update specific header keyword
|
||||
fits.setval('image.fits', 'OBJECT', value='M31')
|
||||
```
|
||||
|
||||
**Multi-extension FITS**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
|
||||
# Create multi-extension FITS
|
||||
primary = fits.PrimaryHDU(primary_data)
|
||||
image_ext = fits.ImageHDU(science_data, name='SCI')
|
||||
error_ext = fits.ImageHDU(error_data, name='ERR')
|
||||
|
||||
hdul = fits.HDUList([primary, image_ext, error_ext])
|
||||
hdul.writeto('multi_ext.fits', overwrite=True)
|
||||
```
|
||||
|
||||
**Binary tables**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
|
||||
# Read binary table
|
||||
with fits.open('catalog.fits') as hdul:
|
||||
table_data = hdul[1].data
|
||||
ra = table_data['RA']
|
||||
dec = table_data['DEC']
|
||||
|
||||
# Better: use astropy.table for table operations (see section 5)
|
||||
```
|
||||
|
||||
### 2. Coordinate Systems and Transformations
|
||||
|
||||
Astropy supports ~25 coordinate frames with seamless transformations.
|
||||
|
||||
**Quick Coordinate Conversion**:
|
||||
Use the included `scripts/coord_convert.py` script:
|
||||
```bash
|
||||
python scripts/coord_convert.py 10.68 41.27 --from icrs --to galactic
|
||||
python scripts/coord_convert.py --file coords.txt --from icrs --to galactic --output sexagesimal
|
||||
```
|
||||
|
||||
**Basic coordinate operations**:
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Create coordinate (multiple input formats supported)
|
||||
c = SkyCoord(ra=10.68*u.degree, dec=41.27*u.degree, frame='icrs')
|
||||
c = SkyCoord('00:42:44.3 +41:16:09', unit=(u.hourangle, u.deg))
|
||||
c = SkyCoord('00h42m44.3s +41d16m09s')
|
||||
|
||||
# Transform between frames
|
||||
c_galactic = c.galactic
|
||||
c_fk5 = c.fk5
|
||||
|
||||
print(f"Galactic: l={c_galactic.l.deg:.3f}, b={c_galactic.b.deg:.3f}")
|
||||
```
|
||||
|
||||
**Working with coordinate arrays**:
|
||||
```python
|
||||
import numpy as np
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Arrays of coordinates
|
||||
ra = np.array([10.1, 10.2, 10.3]) * u.degree
|
||||
dec = np.array([40.1, 40.2, 40.3]) * u.degree
|
||||
coords = SkyCoord(ra=ra, dec=dec, frame='icrs')
|
||||
|
||||
# Calculate separations
|
||||
sep = coords[0].separation(coords[1])
|
||||
print(f"Separation: {sep.to(u.arcmin)}")
|
||||
|
||||
# Position angle
|
||||
pa = coords[0].position_angle(coords[1])
|
||||
```
|
||||
|
||||
**Catalog matching**:
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
catalog1 = SkyCoord(ra=[10, 11, 12]*u.degree, dec=[40, 41, 42]*u.degree)
|
||||
catalog2 = SkyCoord(ra=[10.01, 11.02, 13]*u.degree, dec=[40.01, 41.01, 43]*u.degree)
|
||||
|
||||
# Find nearest neighbors
|
||||
idx, sep2d, dist3d = catalog1.match_to_catalog_sky(catalog2)
|
||||
|
||||
# Filter by separation threshold
|
||||
max_sep = 1 * u.arcsec
|
||||
matched = sep2d < max_sep
|
||||
```
|
||||
|
||||
**Horizontal coordinates (Alt/Az)**:
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord, EarthLocation, AltAz
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
location = EarthLocation(lat=40*u.deg, lon=-70*u.deg, height=300*u.m)
|
||||
obstime = Time('2023-01-01 03:00:00')
|
||||
target = SkyCoord(ra=10*u.degree, dec=40*u.degree, frame='icrs')
|
||||
|
||||
altaz_frame = AltAz(obstime=obstime, location=location)
|
||||
target_altaz = target.transform_to(altaz_frame)
|
||||
|
||||
print(f"Alt: {target_altaz.alt.deg:.2f}°, Az: {target_altaz.az.deg:.2f}°")
|
||||
```
|
||||
|
||||
**Available coordinate frames**:
|
||||
- `icrs` - International Celestial Reference System (default, preferred)
|
||||
- `fk5`, `fk4` - Fifth/Fourth Fundamental Katalog
|
||||
- `galactic` - Galactic coordinates
|
||||
- `supergalactic` - Supergalactic coordinates
|
||||
- `altaz` - Horizontal (altitude-azimuth) coordinates
|
||||
- `gcrs`, `cirs`, `itrs` - Earth-based systems
|
||||
- Ecliptic frames: `BarycentricMeanEcliptic`, `HeliocentricMeanEcliptic`, `GeocentricMeanEcliptic`
|
||||
|
||||
### 3. Units and Quantities
|
||||
|
||||
Physical units are fundamental to astronomical calculations. Astropy's units system provides dimensional analysis and automatic conversions.
|
||||
|
||||
**Basic unit operations**:
|
||||
```python
|
||||
import astropy.units as u
|
||||
|
||||
# Create quantities
|
||||
distance = 5.2 * u.parsec
|
||||
velocity = 300 * u.km / u.s
|
||||
time = 10 * u.year
|
||||
|
||||
# Convert units
|
||||
distance_ly = distance.to(u.lightyear)
|
||||
velocity_mps = velocity.to(u.m / u.s)
|
||||
|
||||
# Arithmetic with units
|
||||
wavelength = 500 * u.nm
|
||||
frequency = wavelength.to(u.Hz, equivalencies=u.spectral())
|
||||
```
|
||||
|
||||
**Working with arrays**:
|
||||
```python
|
||||
import numpy as np
|
||||
import astropy.units as u
|
||||
|
||||
wavelengths = np.array([400, 500, 600]) * u.nm
|
||||
frequencies = wavelengths.to(u.THz, equivalencies=u.spectral())
|
||||
|
||||
fluxes = np.array([1.2, 2.3, 1.8]) * u.Jy
|
||||
luminosities = 4 * np.pi * (10*u.pc)**2 * fluxes
|
||||
```
|
||||
|
||||
**Important equivalencies**:
|
||||
- `u.spectral()` - Convert wavelength ↔ frequency ↔ energy
|
||||
- `u.doppler_optical(rest)` - Optical Doppler velocity
|
||||
- `u.doppler_radio(rest)` - Radio Doppler velocity
|
||||
- `u.doppler_relativistic(rest)` - Relativistic Doppler
|
||||
- `u.temperature()` - Temperature unit conversions
|
||||
- `u.brightness_temperature(freq)` - Brightness temperature
|
||||
|
||||
**Physical constants**:
|
||||
```python
|
||||
from astropy import constants as const
|
||||
|
||||
print(const.c) # Speed of light
|
||||
print(const.G) # Gravitational constant
|
||||
print(const.M_sun) # Solar mass
|
||||
print(const.R_sun) # Solar radius
|
||||
print(const.L_sun) # Solar luminosity
|
||||
```
|
||||
|
||||
**Performance tip**: Use the `<<` operator for fast unit assignment to arrays:
|
||||
```python
|
||||
# Fast
|
||||
result = large_array << u.m
|
||||
|
||||
# Slower
|
||||
result = large_array * u.m
|
||||
```
|
||||
|
||||
### 4. Time Systems
|
||||
|
||||
Astronomical time systems require high precision and multiple time scales.
|
||||
|
||||
**Creating time objects**:
|
||||
```python
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
# Various input formats
|
||||
t1 = Time('2023-01-01T00:00:00', format='isot', scale='utc')
|
||||
t2 = Time(2459945.5, format='jd', scale='utc')
|
||||
t3 = Time(['2023-01-01', '2023-06-01'], format='iso')
|
||||
|
||||
# Convert formats
|
||||
print(t1.jd) # Julian Date
|
||||
print(t1.mjd) # Modified Julian Date
|
||||
print(t1.unix) # Unix timestamp
|
||||
print(t1.iso) # ISO format
|
||||
|
||||
# Convert time scales
|
||||
print(t1.tai) # International Atomic Time
|
||||
print(t1.tt) # Terrestrial Time
|
||||
print(t1.tdb) # Barycentric Dynamical Time
|
||||
```
|
||||
|
||||
**Time arithmetic**:
|
||||
```python
|
||||
from astropy.time import Time, TimeDelta
|
||||
import astropy.units as u
|
||||
|
||||
t1 = Time('2023-01-01T00:00:00')
|
||||
dt = TimeDelta(1*u.day)
|
||||
|
||||
t2 = t1 + dt
|
||||
diff = t2 - t1
|
||||
print(diff.to(u.hour))
|
||||
|
||||
# Array of times
|
||||
times = t1 + np.arange(10) * u.day
|
||||
```
|
||||
|
||||
**Astronomical time calculations**:
|
||||
```python
|
||||
from astropy.time import Time
|
||||
from astropy.coordinates import SkyCoord, EarthLocation
|
||||
import astropy.units as u
|
||||
|
||||
location = EarthLocation(lat=40*u.deg, lon=-70*u.deg)
|
||||
t = Time('2023-01-01T00:00:00')
|
||||
|
||||
# Local sidereal time
|
||||
lst = t.sidereal_time('apparent', longitude=location.lon)
|
||||
|
||||
# Barycentric correction
|
||||
target = SkyCoord(ra=10*u.deg, dec=40*u.deg)
|
||||
ltt = t.light_travel_time(target, location=location)
|
||||
t_bary = t.tdb + ltt
|
||||
```
|
||||
|
||||
**Available time scales**:
|
||||
- `utc` - Coordinated Universal Time
|
||||
- `tai` - International Atomic Time
|
||||
- `tt` - Terrestrial Time
|
||||
- `tcb`, `tcg` - Barycentric/Geocentric Coordinate Time
|
||||
- `tdb` - Barycentric Dynamical Time
|
||||
- `ut1` - Universal Time
|
||||
|
||||
### 5. Data Tables
|
||||
|
||||
Astropy tables provide astronomy-specific enhancements over pandas.
|
||||
|
||||
**Creating and manipulating tables**:
|
||||
```python
|
||||
from astropy.table import Table
|
||||
import astropy.units as u
|
||||
|
||||
# Create table
|
||||
t = Table()
|
||||
t['name'] = ['Star1', 'Star2', 'Star3']
|
||||
t['ra'] = [10.5, 11.2, 12.3] * u.degree
|
||||
t['dec'] = [41.2, 42.1, 43.5] * u.degree
|
||||
t['flux'] = [1.2, 2.3, 0.8] * u.Jy
|
||||
|
||||
# Column metadata
|
||||
t['flux'].description = 'Flux at 1.4 GHz'
|
||||
t['flux'].format = '.2f'
|
||||
|
||||
# Add calculated column
|
||||
t['flux_mJy'] = t['flux'].to(u.mJy)
|
||||
|
||||
# Filter and sort
|
||||
bright = t[t['flux'] > 1.0 * u.Jy]
|
||||
t.sort('flux')
|
||||
```
|
||||
|
||||
**Table I/O**:
|
||||
```python
|
||||
from astropy.table import Table
|
||||
|
||||
# Read (format auto-detected from extension)
|
||||
t = Table.read('data.fits')
|
||||
t = Table.read('data.csv', format='ascii.csv')
|
||||
t = Table.read('data.ecsv', format='ascii.ecsv') # Preserves units!
|
||||
t = Table.read('data.votable', format='votable')
|
||||
|
||||
# Write
|
||||
t.write('output.fits', overwrite=True)
|
||||
t.write('output.ecsv', format='ascii.ecsv', overwrite=True)
|
||||
```
|
||||
|
||||
**Advanced operations**:
|
||||
```python
|
||||
from astropy.table import Table, join, vstack, hstack
|
||||
|
||||
# Join tables (like SQL)
|
||||
joined = join(table1, table2, keys='id')
|
||||
|
||||
# Stack tables
|
||||
combined_rows = vstack([t1, t2])
|
||||
combined_cols = hstack([t1, t2])
|
||||
|
||||
# Grouping and aggregation
|
||||
t.group_by('category').groups.aggregate(np.mean)
|
||||
```
|
||||
|
||||
**Tables with astronomical objects**:
|
||||
```python
|
||||
from astropy.table import Table
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
coords = SkyCoord(ra=[10, 11, 12]*u.deg, dec=[40, 41, 42]*u.deg)
|
||||
times = Time(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||
|
||||
t = Table([coords, times], names=['coords', 'obstime'])
|
||||
print(t['coords'][0].ra) # Access coordinate properties
|
||||
```
|
||||
|
||||
### 6. Cosmological Calculations
|
||||
|
||||
Quick cosmology calculations using standard models.
|
||||
|
||||
**Using the cosmology calculator**:
|
||||
```bash
|
||||
python scripts/cosmo_calc.py 0.5 1.0 1.5
|
||||
python scripts/cosmo_calc.py --range 0 3 0.5 --cosmology Planck18
|
||||
python scripts/cosmo_calc.py 0.5 --verbose
|
||||
python scripts/cosmo_calc.py --convert 1000 --from luminosity_distance
|
||||
```
|
||||
|
||||
**Programmatic usage**:
|
||||
```python
|
||||
from astropy.cosmology import Planck18
|
||||
import astropy.units as u
|
||||
import numpy as np
|
||||
|
||||
cosmo = Planck18
|
||||
|
||||
# Calculate distances
|
||||
z = 1.5
|
||||
d_L = cosmo.luminosity_distance(z)
|
||||
d_A = cosmo.angular_diameter_distance(z)
|
||||
d_C = cosmo.comoving_distance(z)
|
||||
|
||||
# Time calculations
|
||||
age = cosmo.age(z)
|
||||
lookback = cosmo.lookback_time(z)
|
||||
|
||||
# Hubble parameter
|
||||
H_z = cosmo.H(z)
|
||||
|
||||
print(f"At z={z}:")
|
||||
print(f" Luminosity distance: {d_L:.2f}")
|
||||
print(f" Age of universe: {age:.2f}")
|
||||
```
|
||||
|
||||
**Convert observables**:
|
||||
```python
|
||||
from astropy.cosmology import Planck18
|
||||
import astropy.units as u
|
||||
|
||||
cosmo = Planck18
|
||||
z = 1.5
|
||||
|
||||
# Angular size to physical size
|
||||
d_A = cosmo.angular_diameter_distance(z)
|
||||
angular_size = 1 * u.arcsec
|
||||
physical_size = (angular_size.to(u.radian) * d_A).to(u.kpc)
|
||||
|
||||
# Flux to luminosity
|
||||
flux = 1e-17 * u.erg / u.s / u.cm**2
|
||||
d_L = cosmo.luminosity_distance(z)
|
||||
luminosity = flux * 4 * np.pi * d_L**2
|
||||
|
||||
# Find redshift for given distance
|
||||
from astropy.cosmology import z_at_value
|
||||
z = z_at_value(cosmo.luminosity_distance, 1000*u.Mpc)
|
||||
```
|
||||
|
||||
**Available cosmologies**:
|
||||
- `Planck18`, `Planck15`, `Planck13` - Planck satellite parameters
|
||||
- `WMAP9`, `WMAP7`, `WMAP5` - WMAP satellite parameters
|
||||
- Custom: `FlatLambdaCDM(H0=70*u.km/u.s/u.Mpc, Om0=0.3)`
|
||||
|
||||
### 7. Model Fitting
|
||||
|
||||
Fit mathematical models to astronomical data.
|
||||
|
||||
**1D fitting example**:
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
import numpy as np
|
||||
|
||||
# Generate data
|
||||
x = np.linspace(0, 10, 100)
|
||||
y_data = 10 * np.exp(-0.5 * ((x - 5) / 1)**2) + np.random.normal(0, 0.5, x.shape)
|
||||
|
||||
# Create and fit model
|
||||
g_init = models.Gaussian1D(amplitude=8, mean=4.5, stddev=0.8)
|
||||
fitter = fitting.LevMarLSQFitter()
|
||||
g_fit = fitter(g_init, x, y_data)
|
||||
|
||||
# Results
|
||||
print(f"Amplitude: {g_fit.amplitude.value:.3f}")
|
||||
print(f"Mean: {g_fit.mean.value:.3f}")
|
||||
print(f"Stddev: {g_fit.stddev.value:.3f}")
|
||||
|
||||
# Evaluate fitted model
|
||||
y_fit = g_fit(x)
|
||||
```
|
||||
|
||||
**Common 1D models**:
|
||||
- `Gaussian1D` - Gaussian profile
|
||||
- `Lorentz1D` - Lorentzian profile
|
||||
- `Voigt1D` - Voigt profile
|
||||
- `Moffat1D` - Moffat profile (PSF modeling)
|
||||
- `Polynomial1D` - Polynomial
|
||||
- `PowerLaw1D` - Power law
|
||||
- `BlackBody` - Blackbody spectrum
|
||||
|
||||
**Common 2D models**:
|
||||
- `Gaussian2D` - 2D Gaussian
|
||||
- `Moffat2D` - 2D Moffat (stellar PSF)
|
||||
- `AiryDisk2D` - Airy disk (diffraction pattern)
|
||||
- `Disk2D` - Circular disk
|
||||
|
||||
**Fitting with constraints**:
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
|
||||
g = models.Gaussian1D(amplitude=10, mean=5, stddev=1)
|
||||
|
||||
# Set bounds
|
||||
g.amplitude.bounds = (0, None) # Positive only
|
||||
g.mean.bounds = (4, 6) # Constrain center
|
||||
|
||||
# Fix parameters
|
||||
g.stddev.fixed = True
|
||||
|
||||
# Compound models
|
||||
model = models.Gaussian1D() + models.Polynomial1D(degree=1)
|
||||
```
|
||||
|
||||
**Available fitters**:
|
||||
- `LinearLSQFitter` - Linear least squares (fast, for linear models)
|
||||
- `LevMarLSQFitter` - Levenberg-Marquardt (most common)
|
||||
- `SimplexLSQFitter` - Downhill simplex
|
||||
- `SLSQPLSQFitter` - Sequential Least Squares with constraints
|
||||
|
||||
### 8. World Coordinate System (WCS)
|
||||
|
||||
Transform between pixel and world coordinates in images.
|
||||
|
||||
**Basic WCS usage**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
|
||||
# Read FITS with WCS
|
||||
hdu = fits.open('image.fits')[0]
|
||||
wcs = WCS(hdu.header)
|
||||
|
||||
# Pixel to world
|
||||
ra, dec = wcs.pixel_to_world_values(100, 200)
|
||||
|
||||
# World to pixel
|
||||
x, y = wcs.world_to_pixel_values(ra, dec)
|
||||
|
||||
# Using SkyCoord (more powerful)
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
coord = SkyCoord(ra=150*u.deg, dec=-30*u.deg)
|
||||
x, y = wcs.world_to_pixel(coord)
|
||||
```
|
||||
|
||||
**Plotting with WCS**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
from astropy.visualization import ImageNormalize, LogStretch, PercentileInterval
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
hdu = fits.open('image.fits')[0]
|
||||
wcs = WCS(hdu.header)
|
||||
data = hdu.data
|
||||
|
||||
# Create figure with WCS projection
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection=wcs)
|
||||
|
||||
# Plot with coordinate grid
|
||||
norm = ImageNormalize(data, interval=PercentileInterval(99.5),
|
||||
stretch=LogStretch())
|
||||
ax.imshow(data, norm=norm, origin='lower', cmap='viridis')
|
||||
|
||||
# Coordinate labels and grid
|
||||
ax.set_xlabel('RA')
|
||||
ax.set_ylabel('Dec')
|
||||
ax.coords.grid(color='white', alpha=0.5)
|
||||
```
|
||||
|
||||
### 9. Statistics and Data Processing
|
||||
|
||||
Robust statistical tools for astronomical data.
|
||||
|
||||
**Sigma clipping** (remove outliers):
|
||||
```python
|
||||
from astropy.stats import sigma_clip, sigma_clipped_stats
|
||||
|
||||
# Remove outliers
|
||||
clipped = sigma_clip(data, sigma=3, maxiters=5)
|
||||
|
||||
# Get statistics on cleaned data
|
||||
mean, median, std = sigma_clipped_stats(data, sigma=3)
|
||||
|
||||
# Use clipped data
|
||||
background = median
|
||||
signal = data - background
|
||||
snr = signal / std
|
||||
```
|
||||
|
||||
**Other statistical functions**:
|
||||
```python
|
||||
from astropy.stats import mad_std, biweight_location, biweight_scale
|
||||
|
||||
# Robust standard deviation
|
||||
std_robust = mad_std(data)
|
||||
|
||||
# Robust central location
|
||||
center = biweight_location(data)
|
||||
|
||||
# Robust scale
|
||||
scale = biweight_scale(data)
|
||||
```
|
||||
|
||||
### 10. Visualization
|
||||
|
||||
Display astronomical images with proper scaling.
|
||||
|
||||
**Image normalization and stretching**:
|
||||
```python
|
||||
from astropy.visualization import (ImageNormalize, MinMaxInterval,
|
||||
PercentileInterval, ZScaleInterval,
|
||||
SqrtStretch, LogStretch, PowerStretch,
|
||||
AsinhStretch)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Common combination: percentile interval + sqrt stretch
|
||||
norm = ImageNormalize(data,
|
||||
interval=PercentileInterval(99),
|
||||
stretch=SqrtStretch())
|
||||
|
||||
plt.imshow(data, norm=norm, origin='lower', cmap='gray')
|
||||
plt.colorbar()
|
||||
```
|
||||
|
||||
**Available intervals** (determine min/max):
|
||||
- `MinMaxInterval()` - Use actual min/max
|
||||
- `PercentileInterval(percentile)` - Clip to percentile (e.g., 99%)
|
||||
- `ZScaleInterval()` - IRAF's zscale algorithm
|
||||
- `ManualInterval(vmin, vmax)` - Specify manually
|
||||
|
||||
**Available stretches** (nonlinear scaling):
|
||||
- `LinearStretch()` - Linear (default)
|
||||
- `SqrtStretch()` - Square root (common for images)
|
||||
- `LogStretch()` - Logarithmic (for high dynamic range)
|
||||
- `PowerStretch(power)` - Power law
|
||||
- `AsinhStretch()` - Arcsinh (good for wide range)
|
||||
|
||||
## Bundled Resources
|
||||
|
||||
### scripts/
|
||||
|
||||
**`fits_info.py`** - Comprehensive FITS file inspection tool
|
||||
```bash
|
||||
python scripts/fits_info.py observation.fits
|
||||
python scripts/fits_info.py observation.fits --detailed
|
||||
python scripts/fits_info.py observation.fits --ext 1
|
||||
```
|
||||
|
||||
**`coord_convert.py`** - Batch coordinate transformation utility
|
||||
```bash
|
||||
python scripts/coord_convert.py 10.68 41.27 --from icrs --to galactic
|
||||
python scripts/coord_convert.py --file coords.txt --from icrs --to galactic
|
||||
```
|
||||
|
||||
**`cosmo_calc.py`** - Cosmological calculator
|
||||
```bash
|
||||
python scripts/cosmo_calc.py 0.5 1.0 1.5
|
||||
python scripts/cosmo_calc.py --range 0 3 0.5 --cosmology Planck18
|
||||
```
|
||||
|
||||
### references/
|
||||
|
||||
**`module_overview.md`** - Comprehensive reference of all astropy subpackages, classes, and methods. Consult this for detailed API information, available functions, and module capabilities.
|
||||
|
||||
**`common_workflows.md`** - Complete working examples for common astronomical data analysis tasks. Contains full code examples for FITS operations, coordinate transformations, cosmology, modeling, and complete analysis pipelines.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use context managers for FITS files**:
|
||||
```python
|
||||
with fits.open('file.fits') as hdul:
|
||||
# Work with file
|
||||
```
|
||||
|
||||
2. **Prefer astropy.table over raw FITS tables** for better unit/metadata support
|
||||
|
||||
3. **Use SkyCoord for coordinates** (high-level interface) rather than low-level frame classes
|
||||
|
||||
4. **Always attach units** to quantities when possible for dimensional safety
|
||||
|
||||
5. **Use ECSV format** for saving tables when you want to preserve units and metadata
|
||||
|
||||
6. **Vectorize coordinate operations** rather than looping for performance
|
||||
|
||||
7. **Use memmap=True** when opening large FITS files to save memory
|
||||
|
||||
8. **Install Bottleneck** package for faster statistics operations
|
||||
|
||||
9. **Pre-compute composite units** for repeated operations to improve performance
|
||||
|
||||
10. **Consult `references/module_overview.md`** for detailed module information and `references/common_workflows.md`** for complete working examples
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern: FITS → Process → FITS
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.stats import sigma_clipped_stats
|
||||
|
||||
# Read
|
||||
with fits.open('input.fits') as hdul:
|
||||
data = hdul[0].data
|
||||
header = hdul[0].header
|
||||
|
||||
# Process
|
||||
mean, median, std = sigma_clipped_stats(data, sigma=3)
|
||||
processed = (data - median) / std
|
||||
|
||||
# Write
|
||||
fits.writeto('output.fits', processed, header, overwrite=True)
|
||||
```
|
||||
|
||||
### Pattern: Catalog Matching
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.table import Table
|
||||
import astropy.units as u
|
||||
|
||||
# Load catalogs
|
||||
cat1 = Table.read('catalog1.fits')
|
||||
cat2 = Table.read('catalog2.fits')
|
||||
|
||||
# Create coordinate objects
|
||||
coords1 = SkyCoord(ra=cat1['RA'], dec=cat1['DEC'], unit=u.degree)
|
||||
coords2 = SkyCoord(ra=cat2['RA'], dec=cat2['DEC'], unit=u.degree)
|
||||
|
||||
# Match
|
||||
idx, sep2d, dist3d = coords1.match_to_catalog_sky(coords2)
|
||||
|
||||
# Filter by separation
|
||||
max_sep = 1 * u.arcsec
|
||||
matched_mask = sep2d < max_sep
|
||||
|
||||
# Create matched catalog
|
||||
matched_cat1 = cat1[matched_mask]
|
||||
matched_cat2 = cat2[idx[matched_mask]]
|
||||
```
|
||||
|
||||
### Pattern: Time Series Analysis
|
||||
```python
|
||||
from astropy.time import Time
|
||||
from astropy.timeseries import TimeSeries
|
||||
import astropy.units as u
|
||||
|
||||
# Create time series
|
||||
times = Time(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||
flux = [1.2, 2.3, 1.8] * u.Jy
|
||||
|
||||
ts = TimeSeries(time=times)
|
||||
ts['flux'] = flux
|
||||
|
||||
# Fold on period
|
||||
from astropy.timeseries import aggregate_downsample
|
||||
period = 1.5 * u.day
|
||||
folded = ts.fold(period=period)
|
||||
```
|
||||
|
||||
### Pattern: Image Display with WCS
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
hdu = fits.open('image.fits')[0]
|
||||
wcs = WCS(hdu.header)
|
||||
data = hdu.data
|
||||
|
||||
fig = plt.figure(figsize=(10, 10))
|
||||
ax = fig.add_subplot(111, projection=wcs)
|
||||
|
||||
norm = ImageNormalize(data, interval=PercentileInterval(99),
|
||||
stretch=SqrtStretch())
|
||||
im = ax.imshow(data, norm=norm, origin='lower', cmap='viridis')
|
||||
|
||||
ax.set_xlabel('RA')
|
||||
ax.set_ylabel('Dec')
|
||||
ax.coords.grid(color='white', alpha=0.5, linestyle='solid')
|
||||
plt.colorbar(im, ax=ax)
|
||||
```
|
||||
|
||||
## Installation Note
|
||||
|
||||
Ensure astropy is installed in the Python environment:
|
||||
```bash
|
||||
pip install astropy
|
||||
```
|
||||
|
||||
For additional performance and features:
|
||||
```bash
|
||||
pip install astropy[all] # Includes optional dependencies
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Official documentation: https://docs.astropy.org
|
||||
- Tutorials: https://learn.astropy.org
|
||||
- API reference: Consult `references/module_overview.md` in this skill
|
||||
- Working examples: Consult `references/common_workflows.md` in this skill
|
||||
@@ -1,618 +0,0 @@
|
||||
# Common Astropy Workflows
|
||||
|
||||
This document describes frequently used workflows when working with astronomical data using astropy.
|
||||
|
||||
## 1. Working with FITS Files
|
||||
|
||||
### Basic FITS Reading
|
||||
```python
|
||||
from astropy.io import fits
|
||||
import numpy as np
|
||||
|
||||
# Open and examine structure
|
||||
with fits.open('observation.fits') as hdul:
|
||||
hdul.info()
|
||||
|
||||
# Access primary HDU
|
||||
primary_hdr = hdul[0].header
|
||||
primary_data = hdul[0].data
|
||||
|
||||
# Access extension
|
||||
ext_data = hdul[1].data
|
||||
ext_hdr = hdul[1].header
|
||||
|
||||
# Read specific header keywords
|
||||
object_name = primary_hdr['OBJECT']
|
||||
exposure = primary_hdr['EXPTIME']
|
||||
```
|
||||
|
||||
### Writing FITS Files
|
||||
```python
|
||||
# Create new FITS file
|
||||
from astropy.io import fits
|
||||
import numpy as np
|
||||
|
||||
# Create data
|
||||
data = np.random.random((100, 100))
|
||||
|
||||
# Create primary HDU
|
||||
hdu = fits.PrimaryHDU(data)
|
||||
hdu.header['OBJECT'] = 'M31'
|
||||
hdu.header['EXPTIME'] = 300.0
|
||||
|
||||
# Write to file
|
||||
hdu.writeto('output.fits', overwrite=True)
|
||||
|
||||
# Multi-extension FITS
|
||||
hdul = fits.HDUList([
|
||||
fits.PrimaryHDU(data1),
|
||||
fits.ImageHDU(data2, name='SCI'),
|
||||
fits.ImageHDU(data3, name='ERR')
|
||||
])
|
||||
hdul.writeto('multi_ext.fits', overwrite=True)
|
||||
```
|
||||
|
||||
### FITS Table Operations
|
||||
```python
|
||||
from astropy.io import fits
|
||||
|
||||
# Read binary table
|
||||
with fits.open('catalog.fits') as hdul:
|
||||
table_data = hdul[1].data
|
||||
|
||||
# Access columns
|
||||
ra = table_data['RA']
|
||||
dec = table_data['DEC']
|
||||
mag = table_data['MAG']
|
||||
|
||||
# Filter data
|
||||
bright = table_data[table_data['MAG'] < 15]
|
||||
|
||||
# Write binary table
|
||||
from astropy.table import Table
|
||||
import astropy.units as u
|
||||
|
||||
t = Table([ra, dec, mag], names=['RA', 'DEC', 'MAG'])
|
||||
t['RA'].unit = u.degree
|
||||
t['DEC'].unit = u.degree
|
||||
t.write('output_catalog.fits', format='fits', overwrite=True)
|
||||
```
|
||||
|
||||
## 2. Coordinate Transformations
|
||||
|
||||
### Basic Coordinate Creation and Transformation
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Create from RA/Dec
|
||||
c = SkyCoord(ra=10.68458*u.degree, dec=41.26917*u.degree, frame='icrs')
|
||||
|
||||
# Alternative creation methods
|
||||
c = SkyCoord('00:42:44.3 +41:16:09', unit=(u.hourangle, u.deg))
|
||||
c = SkyCoord('00h42m44.3s +41d16m09s')
|
||||
|
||||
# Transform to different frames
|
||||
c_gal = c.galactic
|
||||
c_fk5 = c.fk5
|
||||
print(f"Galactic: l={c_gal.l.deg}, b={c_gal.b.deg}")
|
||||
```
|
||||
|
||||
### Coordinate Arrays and Separations
|
||||
```python
|
||||
import numpy as np
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Create array of coordinates
|
||||
ra_array = np.array([10.1, 10.2, 10.3]) * u.degree
|
||||
dec_array = np.array([40.1, 40.2, 40.3]) * u.degree
|
||||
coords = SkyCoord(ra=ra_array, dec=dec_array, frame='icrs')
|
||||
|
||||
# Calculate separations
|
||||
c1 = SkyCoord(ra=10*u.degree, dec=40*u.degree)
|
||||
c2 = SkyCoord(ra=11*u.degree, dec=41*u.degree)
|
||||
sep = c1.separation(c2)
|
||||
print(f"Separation: {sep.to(u.arcmin)}")
|
||||
|
||||
# Position angle
|
||||
pa = c1.position_angle(c2)
|
||||
```
|
||||
|
||||
### Catalog Matching
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord, match_coordinates_sky
|
||||
import astropy.units as u
|
||||
|
||||
# Two catalogs of coordinates
|
||||
catalog1 = SkyCoord(ra=[10, 11, 12]*u.degree, dec=[40, 41, 42]*u.degree)
|
||||
catalog2 = SkyCoord(ra=[10.01, 11.02, 13]*u.degree, dec=[40.01, 41.01, 43]*u.degree)
|
||||
|
||||
# Find nearest neighbors
|
||||
idx, sep2d, dist3d = catalog1.match_to_catalog_sky(catalog2)
|
||||
|
||||
# Filter by separation threshold
|
||||
max_sep = 1 * u.arcsec
|
||||
matched = sep2d < max_sep
|
||||
matching_indices = idx[matched]
|
||||
```
|
||||
|
||||
### Horizontal Coordinates (Alt/Az)
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord, EarthLocation, AltAz
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
# Observer location
|
||||
location = EarthLocation(lat=40*u.deg, lon=-70*u.deg, height=300*u.m)
|
||||
|
||||
# Observation time
|
||||
obstime = Time('2023-01-01 03:00:00')
|
||||
|
||||
# Target coordinate
|
||||
target = SkyCoord(ra=10*u.degree, dec=40*u.degree, frame='icrs')
|
||||
|
||||
# Transform to Alt/Az
|
||||
altaz_frame = AltAz(obstime=obstime, location=location)
|
||||
target_altaz = target.transform_to(altaz_frame)
|
||||
|
||||
print(f"Altitude: {target_altaz.alt.deg}")
|
||||
print(f"Azimuth: {target_altaz.az.deg}")
|
||||
```
|
||||
|
||||
## 3. Units and Quantities
|
||||
|
||||
### Basic Unit Operations
|
||||
```python
|
||||
import astropy.units as u
|
||||
|
||||
# Create quantities
|
||||
distance = 5.2 * u.parsec
|
||||
time = 10 * u.year
|
||||
velocity = 300 * u.km / u.s
|
||||
|
||||
# Unit conversion
|
||||
distance_ly = distance.to(u.lightyear)
|
||||
velocity_mps = velocity.to(u.m / u.s)
|
||||
|
||||
# Arithmetic with units
|
||||
wavelength = 500 * u.nm
|
||||
frequency = wavelength.to(u.Hz, equivalencies=u.spectral())
|
||||
|
||||
# Compose/decompose units
|
||||
composite = (1 * u.kg * u.m**2 / u.s**2)
|
||||
print(composite.decompose()) # Base SI units
|
||||
print(composite.compose()) # Known compound units (Joule)
|
||||
```
|
||||
|
||||
### Working with Arrays
|
||||
```python
|
||||
import numpy as np
|
||||
import astropy.units as u
|
||||
|
||||
# Quantity arrays
|
||||
wavelengths = np.array([400, 500, 600]) * u.nm
|
||||
frequencies = wavelengths.to(u.THz, equivalencies=u.spectral())
|
||||
|
||||
# Mathematical operations preserve units
|
||||
fluxes = np.array([1.2, 2.3, 1.8]) * u.Jy
|
||||
luminosities = 4 * np.pi * (10*u.pc)**2 * fluxes
|
||||
```
|
||||
|
||||
### Custom Units and Equivalencies
|
||||
```python
|
||||
import astropy.units as u
|
||||
|
||||
# Define custom unit
|
||||
beam = u.def_unit('beam', 1.5e-10 * u.steradian)
|
||||
|
||||
# Register for session
|
||||
u.add_enabled_units([beam])
|
||||
|
||||
# Use in calculations
|
||||
flux_per_beam = 1.5 * u.Jy / beam
|
||||
|
||||
# Doppler equivalencies
|
||||
rest_wavelength = 656.3 * u.nm # H-alpha
|
||||
observed = 656.5 * u.nm
|
||||
velocity = observed.to(u.km/u.s,
|
||||
equivalencies=u.doppler_optical(rest_wavelength))
|
||||
```
|
||||
|
||||
## 4. Time Handling
|
||||
|
||||
### Time Creation and Conversion
|
||||
```python
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
# Create time objects
|
||||
t1 = Time('2023-01-01T00:00:00', format='isot', scale='utc')
|
||||
t2 = Time(2459945.5, format='jd', scale='utc')
|
||||
t3 = Time(['2023-01-01', '2023-06-01'], format='iso')
|
||||
|
||||
# Convert formats
|
||||
print(t1.jd) # Julian Date
|
||||
print(t1.mjd) # Modified Julian Date
|
||||
print(t1.unix) # Unix timestamp
|
||||
print(t1.iso) # ISO format
|
||||
|
||||
# Convert time scales
|
||||
print(t1.tai) # Convert to TAI
|
||||
print(t1.tt) # Convert to TT
|
||||
print(t1.tdb) # Convert to TDB
|
||||
```
|
||||
|
||||
### Time Arithmetic
|
||||
```python
|
||||
from astropy.time import Time, TimeDelta
|
||||
import astropy.units as u
|
||||
|
||||
t1 = Time('2023-01-01T00:00:00')
|
||||
dt = TimeDelta(1*u.day)
|
||||
|
||||
# Add time delta
|
||||
t2 = t1 + dt
|
||||
|
||||
# Difference between times
|
||||
diff = t2 - t1
|
||||
print(diff.to(u.hour))
|
||||
|
||||
# Array of times
|
||||
times = t1 + np.arange(10) * u.day
|
||||
```
|
||||
|
||||
### Sidereal Time and Astronomical Calculations
|
||||
```python
|
||||
from astropy.time import Time
|
||||
from astropy.coordinates import EarthLocation
|
||||
import astropy.units as u
|
||||
|
||||
location = EarthLocation(lat=40*u.deg, lon=-70*u.deg)
|
||||
t = Time('2023-01-01T00:00:00')
|
||||
|
||||
# Local sidereal time
|
||||
lst = t.sidereal_time('apparent', longitude=location.lon)
|
||||
|
||||
# Light travel time correction
|
||||
from astropy.coordinates import SkyCoord
|
||||
target = SkyCoord(ra=10*u.deg, dec=40*u.deg)
|
||||
ltt_bary = t.light_travel_time(target, location=location)
|
||||
t_bary = t + ltt_bary
|
||||
```
|
||||
|
||||
## 5. Tables and Data Management
|
||||
|
||||
### Creating and Manipulating Tables
|
||||
```python
|
||||
from astropy.table import Table, Column
|
||||
import astropy.units as u
|
||||
import numpy as np
|
||||
|
||||
# Create table
|
||||
t = Table()
|
||||
t['name'] = ['Star1', 'Star2', 'Star3']
|
||||
t['ra'] = [10.5, 11.2, 12.3] * u.degree
|
||||
t['dec'] = [41.2, 42.1, 43.5] * u.degree
|
||||
t['flux'] = [1.2, 2.3, 0.8] * u.Jy
|
||||
|
||||
# Add column metadata
|
||||
t['flux'].description = 'Flux at 1.4 GHz'
|
||||
t['flux'].format = '.2f'
|
||||
|
||||
# Add new column
|
||||
t['flux_mJy'] = t['flux'].to(u.mJy)
|
||||
|
||||
# Filter rows
|
||||
bright = t[t['flux'] > 1.0 * u.Jy]
|
||||
|
||||
# Sort
|
||||
t.sort('flux')
|
||||
```
|
||||
|
||||
### Table I/O
|
||||
```python
|
||||
from astropy.table import Table
|
||||
|
||||
# Read various formats
|
||||
t = Table.read('data.fits')
|
||||
t = Table.read('data.csv', format='ascii.csv')
|
||||
t = Table.read('data.ecsv', format='ascii.ecsv') # Preserves units
|
||||
t = Table.read('data.votable', format='votable')
|
||||
|
||||
# Write various formats
|
||||
t.write('output.fits', overwrite=True)
|
||||
t.write('output.csv', format='ascii.csv', overwrite=True)
|
||||
t.write('output.ecsv', format='ascii.ecsv', overwrite=True)
|
||||
t.write('output.votable', format='votable', overwrite=True)
|
||||
```
|
||||
|
||||
### Advanced Table Operations
|
||||
```python
|
||||
from astropy.table import Table, join, vstack, hstack
|
||||
|
||||
# Join tables
|
||||
t1 = Table([[1, 2], ['a', 'b']], names=['id', 'val1'])
|
||||
t2 = Table([[1, 2], ['c', 'd']], names=['id', 'val2'])
|
||||
joined = join(t1, t2, keys='id')
|
||||
|
||||
# Stack tables vertically
|
||||
combined = vstack([t1, t2])
|
||||
|
||||
# Stack horizontally
|
||||
combined = hstack([t1, t2])
|
||||
|
||||
# Grouping
|
||||
t.group_by('category').groups.aggregate(np.mean)
|
||||
```
|
||||
|
||||
### Tables with Astronomical Objects
|
||||
```python
|
||||
from astropy.table import Table
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.time import Time
|
||||
import astropy.units as u
|
||||
|
||||
# Table with SkyCoord column
|
||||
coords = SkyCoord(ra=[10, 11, 12]*u.deg, dec=[40, 41, 42]*u.deg)
|
||||
times = Time(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||
|
||||
t = Table([coords, times], names=['coords', 'obstime'])
|
||||
|
||||
# Access individual coordinates
|
||||
print(t['coords'][0].ra)
|
||||
print(t['coords'][0].dec)
|
||||
```
|
||||
|
||||
## 6. Cosmological Calculations
|
||||
|
||||
### Distance Calculations
|
||||
```python
|
||||
from astropy.cosmology import Planck18, FlatLambdaCDM
|
||||
import astropy.units as u
|
||||
import numpy as np
|
||||
|
||||
# Use built-in cosmology
|
||||
cosmo = Planck18
|
||||
|
||||
# Redshifts
|
||||
z = np.linspace(0, 5, 50)
|
||||
|
||||
# Calculate distances
|
||||
comoving_dist = cosmo.comoving_distance(z)
|
||||
angular_diam_dist = cosmo.angular_diameter_distance(z)
|
||||
luminosity_dist = cosmo.luminosity_distance(z)
|
||||
|
||||
# Age of universe
|
||||
age_at_z = cosmo.age(z)
|
||||
lookback_time = cosmo.lookback_time(z)
|
||||
|
||||
# Hubble parameter
|
||||
H_z = cosmo.H(z)
|
||||
```
|
||||
|
||||
### Converting Observables
|
||||
```python
|
||||
from astropy.cosmology import Planck18
|
||||
import astropy.units as u
|
||||
|
||||
cosmo = Planck18
|
||||
z = 1.5
|
||||
|
||||
# Angular diameter distance
|
||||
d_A = cosmo.angular_diameter_distance(z)
|
||||
|
||||
# Convert angular size to physical size
|
||||
angular_size = 1 * u.arcsec
|
||||
physical_size = (angular_size.to(u.radian) * d_A).to(u.kpc)
|
||||
|
||||
# Convert flux to luminosity
|
||||
flux = 1e-17 * u.erg / u.s / u.cm**2
|
||||
d_L = cosmo.luminosity_distance(z)
|
||||
luminosity = flux * 4 * np.pi * d_L**2
|
||||
|
||||
# Find redshift for given distance
|
||||
from astropy.cosmology import z_at_value
|
||||
z_result = z_at_value(cosmo.luminosity_distance, 1000*u.Mpc)
|
||||
```
|
||||
|
||||
### Custom Cosmology
|
||||
```python
|
||||
from astropy.cosmology import FlatLambdaCDM
|
||||
import astropy.units as u
|
||||
|
||||
# Define custom cosmology
|
||||
my_cosmo = FlatLambdaCDM(H0=70 * u.km/u.s/u.Mpc,
|
||||
Om0=0.3,
|
||||
Tcmb0=2.725 * u.K)
|
||||
|
||||
# Use it for calculations
|
||||
print(my_cosmo.age(0))
|
||||
print(my_cosmo.luminosity_distance(1.5))
|
||||
```
|
||||
|
||||
## 7. Model Fitting
|
||||
|
||||
### Fitting 1D Models
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Generate data with noise
|
||||
x = np.linspace(0, 10, 100)
|
||||
true_model = models.Gaussian1D(amplitude=10, mean=5, stddev=1)
|
||||
y = true_model(x) + np.random.normal(0, 0.5, x.shape)
|
||||
|
||||
# Create and fit model
|
||||
g_init = models.Gaussian1D(amplitude=8, mean=4.5, stddev=0.8)
|
||||
fitter = fitting.LevMarLSQFitter()
|
||||
g_fit = fitter(g_init, x, y)
|
||||
|
||||
# Plot results
|
||||
plt.plot(x, y, 'o', label='Data')
|
||||
plt.plot(x, g_fit(x), label='Fit')
|
||||
plt.legend()
|
||||
|
||||
# Get fitted parameters
|
||||
print(f"Amplitude: {g_fit.amplitude.value}")
|
||||
print(f"Mean: {g_fit.mean.value}")
|
||||
print(f"Stddev: {g_fit.stddev.value}")
|
||||
```
|
||||
|
||||
### Fitting with Constraints
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
|
||||
# Set parameter bounds
|
||||
g = models.Gaussian1D(amplitude=10, mean=5, stddev=1)
|
||||
g.amplitude.bounds = (0, None) # Positive only
|
||||
g.mean.bounds = (4, 6) # Constrain center
|
||||
g.stddev.fixed = True # Fix width
|
||||
|
||||
# Tie parameters (for multi-component models)
|
||||
g1 = models.Gaussian1D(amplitude=10, mean=5, stddev=1, name='g1')
|
||||
g2 = models.Gaussian1D(amplitude=5, mean=6, stddev=1, name='g2')
|
||||
g2.stddev.tied = lambda model: model.g1.stddev
|
||||
|
||||
# Compound model
|
||||
model = g1 + g2
|
||||
```
|
||||
|
||||
### 2D Image Fitting
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
import numpy as np
|
||||
|
||||
# Create 2D data
|
||||
y, x = np.mgrid[0:100, 0:100]
|
||||
z = models.Gaussian2D(amplitude=100, x_mean=50, y_mean=50,
|
||||
x_stddev=5, y_stddev=5)(x, y)
|
||||
z += np.random.normal(0, 5, z.shape)
|
||||
|
||||
# Fit 2D Gaussian
|
||||
g_init = models.Gaussian2D(amplitude=90, x_mean=48, y_mean=48,
|
||||
x_stddev=4, y_stddev=4)
|
||||
fitter = fitting.LevMarLSQFitter()
|
||||
g_fit = fitter(g_init, x, y, z)
|
||||
|
||||
# Get parameters
|
||||
print(f"Center: ({g_fit.x_mean.value}, {g_fit.y_mean.value})")
|
||||
print(f"Width: ({g_fit.x_stddev.value}, {g_fit.y_stddev.value})")
|
||||
```
|
||||
|
||||
## 8. Image Processing and Visualization
|
||||
|
||||
### Image Display with Proper Scaling
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Read FITS image
|
||||
data = fits.getdata('image.fits')
|
||||
|
||||
# Apply normalization
|
||||
norm = ImageNormalize(data,
|
||||
interval=PercentileInterval(99),
|
||||
stretch=SqrtStretch())
|
||||
|
||||
# Display
|
||||
plt.imshow(data, norm=norm, origin='lower', cmap='gray')
|
||||
plt.colorbar()
|
||||
```
|
||||
|
||||
### WCS Plotting
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
from astropy.visualization import ImageNormalize, LogStretch, PercentileInterval
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Read FITS with WCS
|
||||
hdu = fits.open('image.fits')[0]
|
||||
wcs = WCS(hdu.header)
|
||||
data = hdu.data
|
||||
|
||||
# Create figure with WCS projection
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection=wcs)
|
||||
|
||||
# Plot with coordinate grid
|
||||
norm = ImageNormalize(data, interval=PercentileInterval(99.5),
|
||||
stretch=LogStretch())
|
||||
im = ax.imshow(data, norm=norm, origin='lower', cmap='viridis')
|
||||
|
||||
# Add coordinate labels
|
||||
ax.set_xlabel('RA')
|
||||
ax.set_ylabel('Dec')
|
||||
ax.coords.grid(color='white', alpha=0.5)
|
||||
plt.colorbar(im)
|
||||
```
|
||||
|
||||
### Sigma Clipping and Statistics
|
||||
```python
|
||||
from astropy.stats import sigma_clip, sigma_clipped_stats
|
||||
import numpy as np
|
||||
|
||||
# Data with outliers
|
||||
data = np.random.normal(100, 15, 1000)
|
||||
data[0:50] = np.random.normal(200, 10, 50) # Add outliers
|
||||
|
||||
# Sigma clipping
|
||||
clipped = sigma_clip(data, sigma=3, maxiters=5)
|
||||
|
||||
# Get statistics on clipped data
|
||||
mean, median, std = sigma_clipped_stats(data, sigma=3)
|
||||
|
||||
print(f"Mean: {mean:.2f}")
|
||||
print(f"Median: {median:.2f}")
|
||||
print(f"Std: {std:.2f}")
|
||||
print(f"Clipped {clipped.mask.sum()} values")
|
||||
```
|
||||
|
||||
## 9. Complete Analysis Example
|
||||
|
||||
### Photometry Pipeline
|
||||
```python
|
||||
from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.stats import sigma_clipped_stats
|
||||
from astropy.visualization import ImageNormalize, LogStretch
|
||||
import astropy.units as u
|
||||
import numpy as np
|
||||
|
||||
# Read FITS file
|
||||
hdu = fits.open('observation.fits')[0]
|
||||
data = hdu.data
|
||||
header = hdu.header
|
||||
wcs = WCS(header)
|
||||
|
||||
# Calculate background statistics
|
||||
mean, median, std = sigma_clipped_stats(data, sigma=3.0)
|
||||
print(f"Background: {median:.2f} +/- {std:.2f}")
|
||||
|
||||
# Subtract background
|
||||
data_sub = data - median
|
||||
|
||||
# Known source coordinates
|
||||
source_coord = SkyCoord(ra='10:42:30', dec='+41:16:09', unit=(u.hourangle, u.deg))
|
||||
|
||||
# Convert to pixel coordinates
|
||||
x_pix, y_pix = wcs.world_to_pixel(source_coord)
|
||||
|
||||
# Simple aperture photometry
|
||||
aperture_radius = 10 # pixels
|
||||
y, x = np.ogrid[:data.shape[0], :data.shape[1]]
|
||||
mask = (x - x_pix)**2 + (y - y_pix)**2 <= aperture_radius**2
|
||||
|
||||
aperture_sum = np.sum(data_sub[mask])
|
||||
npix = np.sum(mask)
|
||||
|
||||
print(f"Source position: ({x_pix:.1f}, {y_pix:.1f})")
|
||||
print(f"Aperture sum: {aperture_sum:.2f}")
|
||||
print(f"S/N: {aperture_sum / (std * np.sqrt(npix)):.2f}")
|
||||
```
|
||||
|
||||
This workflow document provides practical examples for common astronomical data analysis tasks using astropy.
|
||||
@@ -1,340 +0,0 @@
|
||||
# Astropy Module Overview
|
||||
|
||||
This document provides a comprehensive reference of all major astropy subpackages and their capabilities.
|
||||
|
||||
## Core Data Structures
|
||||
|
||||
### astropy.units
|
||||
**Purpose**: Handle physical units and dimensional analysis in computations.
|
||||
|
||||
**Key Classes**:
|
||||
- `Quantity` - Combines numerical values with units
|
||||
- `Unit` - Represents physical units
|
||||
|
||||
**Common Operations**:
|
||||
```python
|
||||
import astropy.units as u
|
||||
distance = 5 * u.meter
|
||||
time = 2 * u.second
|
||||
velocity = distance / time # Returns Quantity in m/s
|
||||
wavelength = 500 * u.nm
|
||||
frequency = wavelength.to(u.Hz, equivalencies=u.spectral())
|
||||
```
|
||||
|
||||
**Equivalencies**:
|
||||
- `u.spectral()` - Convert wavelength ↔ frequency
|
||||
- `u.doppler_optical()`, `u.doppler_radio()` - Velocity conversions
|
||||
- `u.temperature()` - Temperature unit conversions
|
||||
- `u.pixel_scale()` - Pixel to physical units
|
||||
|
||||
### astropy.constants
|
||||
**Purpose**: Provide physical and astronomical constants.
|
||||
|
||||
**Common Constants**:
|
||||
- `c` - Speed of light
|
||||
- `G` - Gravitational constant
|
||||
- `h` - Planck constant
|
||||
- `M_sun`, `R_sun`, `L_sun` - Solar mass, radius, luminosity
|
||||
- `M_earth`, `R_earth` - Earth mass, radius
|
||||
- `pc`, `au` - Parsec, astronomical unit
|
||||
|
||||
### astropy.time
|
||||
**Purpose**: Represent and manipulate times and dates with astronomical precision.
|
||||
|
||||
**Time Scales**:
|
||||
- `UTC` - Coordinated Universal Time
|
||||
- `TAI` - International Atomic Time
|
||||
- `TT` - Terrestrial Time
|
||||
- `TCB`, `TCG` - Barycentric/Geocentric Coordinate Time
|
||||
- `TDB` - Barycentric Dynamical Time
|
||||
- `UT1` - Universal Time
|
||||
|
||||
**Common Formats**:
|
||||
- `iso`, `isot` - ISO 8601 strings
|
||||
- `jd`, `mjd` - Julian/Modified Julian Date
|
||||
- `unix`, `gps` - Unix/GPS timestamps
|
||||
- `datetime` - Python datetime objects
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.time import Time
|
||||
t = Time('2023-01-01T00:00:00', format='isot', scale='utc')
|
||||
print(t.mjd) # Modified Julian Date
|
||||
print(t.jd) # Julian Date
|
||||
print(t.tt) # Convert to TT scale
|
||||
```
|
||||
|
||||
### astropy.table
|
||||
**Purpose**: Work with tabular data optimized for astronomical applications.
|
||||
|
||||
**Key Features**:
|
||||
- Native support for astropy Quantity, Time, and SkyCoord columns
|
||||
- Multi-dimensional columns
|
||||
- Metadata preservation (units, descriptions, formats)
|
||||
- Advanced operations: joins, grouping, binning
|
||||
- File I/O via unified interface
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.table import Table
|
||||
import astropy.units as u
|
||||
|
||||
t = Table()
|
||||
t['name'] = ['Star1', 'Star2', 'Star3']
|
||||
t['ra'] = [10.5, 11.2, 12.3] * u.degree
|
||||
t['dec'] = [41.2, 42.1, 43.5] * u.degree
|
||||
t['flux'] = [1.2, 2.3, 0.8] * u.Jy
|
||||
```
|
||||
|
||||
## Coordinates and World Coordinate Systems
|
||||
|
||||
### astropy.coordinates
|
||||
**Purpose**: Represent and transform celestial coordinates.
|
||||
|
||||
**Primary Interface**: `SkyCoord` - High-level class for sky positions
|
||||
|
||||
**Coordinate Frames**:
|
||||
- `ICRS` - International Celestial Reference System (default)
|
||||
- `FK5`, `FK4` - Fifth/Fourth Fundamental Katalog
|
||||
- `Galactic`, `Supergalactic` - Galactic coordinates
|
||||
- `AltAz` - Horizontal (altitude-azimuth) coordinates
|
||||
- `GCRS`, `CIRS`, `ITRS` - Earth-based systems
|
||||
- `BarycentricMeanEcliptic`, `HeliocentricMeanEcliptic`, `GeocentricMeanEcliptic` - Ecliptic coordinates
|
||||
|
||||
**Common Operations**:
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Create coordinate
|
||||
c = SkyCoord(ra=10.625*u.degree, dec=41.2*u.degree, frame='icrs')
|
||||
|
||||
# Transform to galactic
|
||||
c_gal = c.galactic
|
||||
|
||||
# Calculate separation
|
||||
c2 = SkyCoord(ra=11*u.degree, dec=42*u.degree, frame='icrs')
|
||||
sep = c.separation(c2)
|
||||
|
||||
# Match catalogs
|
||||
idx, sep2d, dist3d = c.match_to_catalog_sky(catalog_coords)
|
||||
```
|
||||
|
||||
### astropy.wcs
|
||||
**Purpose**: Handle World Coordinate System transformations for astronomical images.
|
||||
|
||||
**Key Class**: `WCS` - Maps between pixel and world coordinates
|
||||
|
||||
**Common Use Cases**:
|
||||
- Convert pixel coordinates to RA/Dec
|
||||
- Convert RA/Dec to pixel coordinates
|
||||
- Handle distortion corrections (SIP, lookup tables)
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.wcs import WCS
|
||||
from astropy.io import fits
|
||||
|
||||
hdu = fits.open('image.fits')[0]
|
||||
wcs = WCS(hdu.header)
|
||||
|
||||
# Pixel to world
|
||||
ra, dec = wcs.pixel_to_world_values(100, 200)
|
||||
|
||||
# World to pixel
|
||||
x, y = wcs.world_to_pixel_values(ra, dec)
|
||||
```
|
||||
|
||||
## File I/O
|
||||
|
||||
### astropy.io.fits
|
||||
**Purpose**: Read and write FITS (Flexible Image Transport System) files.
|
||||
|
||||
**Key Classes**:
|
||||
- `HDUList` - Container for all HDUs in a file
|
||||
- `PrimaryHDU` - Primary header data unit
|
||||
- `ImageHDU` - Image extension
|
||||
- `BinTableHDU` - Binary table extension
|
||||
- `Header` - FITS header keywords
|
||||
|
||||
**Common Operations**:
|
||||
```python
|
||||
from astropy.io import fits
|
||||
|
||||
# Read FITS file
|
||||
with fits.open('file.fits') as hdul:
|
||||
hdul.info() # Display structure
|
||||
header = hdul[0].header
|
||||
data = hdul[0].data
|
||||
|
||||
# Write FITS file
|
||||
fits.writeto('output.fits', data, header)
|
||||
|
||||
# Update header keyword
|
||||
fits.setval('file.fits', 'OBJECT', value='M31')
|
||||
```
|
||||
|
||||
### astropy.io.ascii
|
||||
**Purpose**: Read and write ASCII tables in various formats.
|
||||
|
||||
**Supported Formats**:
|
||||
- Basic, CSV, tab-delimited
|
||||
- CDS/MRT (Machine Readable Tables)
|
||||
- IPAC, Daophot, SExtractor
|
||||
- LaTeX tables
|
||||
- HTML tables
|
||||
|
||||
### astropy.io.votable
|
||||
**Purpose**: Handle Virtual Observatory (VO) table format.
|
||||
|
||||
### astropy.io.misc
|
||||
**Purpose**: Additional formats including HDF5, Parquet, and YAML.
|
||||
|
||||
## Scientific Calculations
|
||||
|
||||
### astropy.cosmology
|
||||
**Purpose**: Perform cosmological calculations.
|
||||
|
||||
**Common Models**:
|
||||
- `FlatLambdaCDM` - Flat universe with cosmological constant (most common)
|
||||
- `LambdaCDM` - Universe with cosmological constant
|
||||
- `Planck18`, `Planck15`, `Planck13` - Pre-defined Planck parameters
|
||||
- `WMAP9`, `WMAP7`, `WMAP5` - Pre-defined WMAP parameters
|
||||
|
||||
**Common Methods**:
|
||||
```python
|
||||
from astropy.cosmology import FlatLambdaCDM, Planck18
|
||||
import astropy.units as u
|
||||
|
||||
cosmo = FlatLambdaCDM(H0=70, Om0=0.3)
|
||||
# Or use built-in: cosmo = Planck18
|
||||
|
||||
z = 1.5
|
||||
print(cosmo.age(z)) # Age of universe at z
|
||||
print(cosmo.luminosity_distance(z)) # Luminosity distance
|
||||
print(cosmo.angular_diameter_distance(z)) # Angular diameter distance
|
||||
print(cosmo.comoving_distance(z)) # Comoving distance
|
||||
print(cosmo.H(z)) # Hubble parameter at z
|
||||
```
|
||||
|
||||
### astropy.modeling
|
||||
**Purpose**: Framework for model evaluation and fitting.
|
||||
|
||||
**Model Categories**:
|
||||
- 1D models: Gaussian1D, Lorentz1D, Voigt1D, Polynomial1D
|
||||
- 2D models: Gaussian2D, Disk2D, Moffat2D
|
||||
- Physical models: BlackBody, Drude1D, NFW
|
||||
- Polynomial models: Chebyshev, Legendre
|
||||
|
||||
**Common Fitters**:
|
||||
- `LinearLSQFitter` - Linear least squares
|
||||
- `LevMarLSQFitter` - Levenberg-Marquardt
|
||||
- `SimplexLSQFitter` - Downhill simplex
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.modeling import models, fitting
|
||||
|
||||
# Create model
|
||||
g = models.Gaussian1D(amplitude=10, mean=5, stddev=1)
|
||||
|
||||
# Fit to data
|
||||
fitter = fitting.LevMarLSQFitter()
|
||||
fitted_model = fitter(g, x_data, y_data)
|
||||
```
|
||||
|
||||
### astropy.convolution
|
||||
**Purpose**: Convolve and filter astronomical data.
|
||||
|
||||
**Common Kernels**:
|
||||
- `Gaussian2DKernel` - 2D Gaussian smoothing
|
||||
- `Box2DKernel` - 2D boxcar smoothing
|
||||
- `Tophat2DKernel` - 2D tophat filter
|
||||
- Custom kernels via arrays
|
||||
|
||||
### astropy.stats
|
||||
**Purpose**: Statistical tools for astronomical data analysis.
|
||||
|
||||
**Key Functions**:
|
||||
- `sigma_clip()` - Remove outliers via sigma clipping
|
||||
- `sigma_clipped_stats()` - Compute mean, median, std with clipping
|
||||
- `mad_std()` - Median Absolute Deviation
|
||||
- `biweight_location()`, `biweight_scale()` - Robust statistics
|
||||
- `circmean()`, `circstd()` - Circular statistics
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.stats import sigma_clip, sigma_clipped_stats
|
||||
|
||||
# Remove outliers
|
||||
filtered_data = sigma_clip(data, sigma=3, maxiters=5)
|
||||
|
||||
# Get statistics
|
||||
mean, median, std = sigma_clipped_stats(data, sigma=3)
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### astropy.nddata
|
||||
**Purpose**: Handle N-dimensional datasets with metadata.
|
||||
|
||||
**Key Class**: `NDData` - Container for array data with units, uncertainty, mask, and WCS
|
||||
|
||||
### astropy.timeseries
|
||||
**Purpose**: Work with time series data.
|
||||
|
||||
**Key Classes**:
|
||||
- `TimeSeries` - Time-indexed data table
|
||||
- `BinnedTimeSeries` - Time-binned data
|
||||
|
||||
**Common Operations**:
|
||||
- Period finding (Lomb-Scargle)
|
||||
- Folding time series
|
||||
- Binning data
|
||||
|
||||
### astropy.visualization
|
||||
**Purpose**: Display astronomical data effectively.
|
||||
|
||||
**Key Features**:
|
||||
- Image normalization (LogStretch, PowerStretch, SqrtStretch, etc.)
|
||||
- Interval scaling (MinMaxInterval, PercentileInterval, ZScaleInterval)
|
||||
- WCSAxes for plotting with coordinate overlays
|
||||
- RGB image creation with stretching
|
||||
- Astronomical colormaps
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
from astropy.visualization import ImageNormalize, SqrtStretch, PercentileInterval
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
norm = ImageNormalize(data, interval=PercentileInterval(99),
|
||||
stretch=SqrtStretch())
|
||||
plt.imshow(data, norm=norm, origin='lower')
|
||||
```
|
||||
|
||||
## Utilities
|
||||
|
||||
### astropy.samp
|
||||
**Purpose**: Simple Application Messaging Protocol for inter-application communication.
|
||||
|
||||
**Use Case**: Connect Python scripts with other astronomical tools (e.g., DS9, TOPCAT).
|
||||
|
||||
## Module Import Patterns
|
||||
|
||||
**Standard imports**:
|
||||
```python
|
||||
import astropy.units as u
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.time import Time
|
||||
from astropy.io import fits
|
||||
from astropy.table import Table
|
||||
from astropy import constants as const
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Pre-compute composite units** for repeated operations
|
||||
2. **Use `<<` operator** for fast unit assignments: `array << u.m` instead of `array * u.m`
|
||||
3. **Vectorize operations** rather than looping over coordinates/times
|
||||
4. **Use memmap=True** when opening large FITS files
|
||||
5. **Install Bottleneck** for faster stats operations
|
||||
@@ -1,226 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Coordinate conversion utility for astronomical coordinates.
|
||||
|
||||
This script provides batch coordinate transformations between different
|
||||
astronomical coordinate systems using astropy.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
|
||||
def convert_coordinates(coords_input, input_frame='icrs', output_frame='galactic',
|
||||
input_format='decimal', output_format='decimal'):
|
||||
"""
|
||||
Convert astronomical coordinates between different frames.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coords_input : list of tuples or str
|
||||
Input coordinates as (lon, lat) pairs or strings
|
||||
input_frame : str
|
||||
Input coordinate frame (icrs, fk5, galactic, etc.)
|
||||
output_frame : str
|
||||
Output coordinate frame
|
||||
input_format : str
|
||||
Format of input coordinates ('decimal', 'sexagesimal', 'hourangle')
|
||||
output_format : str
|
||||
Format for output display ('decimal', 'sexagesimal', 'hourangle')
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
Converted coordinates
|
||||
"""
|
||||
results = []
|
||||
|
||||
for coord in coords_input:
|
||||
try:
|
||||
# Parse input coordinate
|
||||
if input_format == 'decimal':
|
||||
if isinstance(coord, str):
|
||||
parts = coord.split()
|
||||
lon, lat = float(parts[0]), float(parts[1])
|
||||
else:
|
||||
lon, lat = coord
|
||||
c = SkyCoord(lon*u.degree, lat*u.degree, frame=input_frame)
|
||||
|
||||
elif input_format == 'sexagesimal':
|
||||
c = SkyCoord(coord, frame=input_frame, unit=(u.hourangle, u.deg))
|
||||
|
||||
elif input_format == 'hourangle':
|
||||
if isinstance(coord, str):
|
||||
parts = coord.split()
|
||||
lon, lat = parts[0], parts[1]
|
||||
else:
|
||||
lon, lat = coord
|
||||
c = SkyCoord(lon, lat, frame=input_frame, unit=(u.hourangle, u.deg))
|
||||
|
||||
# Transform to output frame
|
||||
if output_frame == 'icrs':
|
||||
c_out = c.icrs
|
||||
elif output_frame == 'fk5':
|
||||
c_out = c.fk5
|
||||
elif output_frame == 'fk4':
|
||||
c_out = c.fk4
|
||||
elif output_frame == 'galactic':
|
||||
c_out = c.galactic
|
||||
elif output_frame == 'supergalactic':
|
||||
c_out = c.supergalactic
|
||||
else:
|
||||
c_out = c.transform_to(output_frame)
|
||||
|
||||
results.append(c_out)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error converting coordinate {coord}: {e}", file=sys.stderr)
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def format_output(coords, frame, output_format='decimal'):
|
||||
"""Format coordinates for display."""
|
||||
output = []
|
||||
|
||||
for c in coords:
|
||||
if c is None:
|
||||
output.append("ERROR")
|
||||
continue
|
||||
|
||||
if frame in ['icrs', 'fk5', 'fk4']:
|
||||
lon_name, lat_name = 'RA', 'Dec'
|
||||
lon = c.ra
|
||||
lat = c.dec
|
||||
elif frame == 'galactic':
|
||||
lon_name, lat_name = 'l', 'b'
|
||||
lon = c.l
|
||||
lat = c.b
|
||||
elif frame == 'supergalactic':
|
||||
lon_name, lat_name = 'sgl', 'sgb'
|
||||
lon = c.sgl
|
||||
lat = c.sgb
|
||||
else:
|
||||
lon_name, lat_name = 'lon', 'lat'
|
||||
lon = c.spherical.lon
|
||||
lat = c.spherical.lat
|
||||
|
||||
if output_format == 'decimal':
|
||||
out_str = f"{lon.degree:12.8f} {lat.degree:+12.8f}"
|
||||
elif output_format == 'sexagesimal':
|
||||
if frame in ['icrs', 'fk5', 'fk4']:
|
||||
out_str = f"{lon.to_string(unit=u.hourangle, sep=':', pad=True)} "
|
||||
out_str += f"{lat.to_string(unit=u.degree, sep=':', pad=True)}"
|
||||
else:
|
||||
out_str = f"{lon.to_string(unit=u.degree, sep=':', pad=True)} "
|
||||
out_str += f"{lat.to_string(unit=u.degree, sep=':', pad=True)}"
|
||||
elif output_format == 'hourangle':
|
||||
out_str = f"{lon.to_string(unit=u.hourangle, sep=' ', pad=True)} "
|
||||
out_str += f"{lat.to_string(unit=u.degree, sep=' ', pad=True)}"
|
||||
|
||||
output.append(out_str)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function for command-line usage."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert astronomical coordinates between different frames',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Supported frames: icrs, fk5, fk4, galactic, supergalactic
|
||||
|
||||
Input formats:
|
||||
decimal : Degrees (e.g., "10.68 41.27")
|
||||
sexagesimal : HMS/DMS (e.g., "00:42:44.3 +41:16:09")
|
||||
hourangle : Hours and degrees (e.g., "10.5h 41.5d")
|
||||
|
||||
Examples:
|
||||
%(prog)s --from icrs --to galactic "10.68 41.27"
|
||||
%(prog)s --from icrs --to galactic --input decimal --output sexagesimal "150.5 -30.2"
|
||||
%(prog)s --from galactic --to icrs "120.5 45.3"
|
||||
%(prog)s --file coords.txt --from icrs --to galactic
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument('coordinates', nargs='*',
|
||||
help='Coordinates to convert (lon lat pairs)')
|
||||
parser.add_argument('-f', '--from', dest='input_frame', default='icrs',
|
||||
help='Input coordinate frame (default: icrs)')
|
||||
parser.add_argument('-t', '--to', dest='output_frame', default='galactic',
|
||||
help='Output coordinate frame (default: galactic)')
|
||||
parser.add_argument('-i', '--input', dest='input_format', default='decimal',
|
||||
choices=['decimal', 'sexagesimal', 'hourangle'],
|
||||
help='Input format (default: decimal)')
|
||||
parser.add_argument('-o', '--output', dest='output_format', default='decimal',
|
||||
choices=['decimal', 'sexagesimal', 'hourangle'],
|
||||
help='Output format (default: decimal)')
|
||||
parser.add_argument('--file', dest='input_file',
|
||||
help='Read coordinates from file (one per line)')
|
||||
parser.add_argument('--header', action='store_true',
|
||||
help='Print header line with coordinate names')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get coordinates from file or command line
|
||||
if args.input_file:
|
||||
try:
|
||||
with open(args.input_file, 'r') as f:
|
||||
coords = [line.strip() for line in f if line.strip()]
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File '{args.input_file}' not found.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
if not args.coordinates:
|
||||
print("Error: No coordinates provided.", file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
# Combine pairs of arguments
|
||||
if args.input_format == 'decimal':
|
||||
coords = []
|
||||
i = 0
|
||||
while i < len(args.coordinates):
|
||||
if i + 1 < len(args.coordinates):
|
||||
coords.append(f"{args.coordinates[i]} {args.coordinates[i+1]}")
|
||||
i += 2
|
||||
else:
|
||||
print(f"Warning: Odd number of coordinates, skipping last value",
|
||||
file=sys.stderr)
|
||||
break
|
||||
else:
|
||||
coords = args.coordinates
|
||||
|
||||
# Convert coordinates
|
||||
converted = convert_coordinates(coords,
|
||||
input_frame=args.input_frame,
|
||||
output_frame=args.output_frame,
|
||||
input_format=args.input_format,
|
||||
output_format=args.output_format)
|
||||
|
||||
# Format and print output
|
||||
formatted = format_output(converted, args.output_frame, args.output_format)
|
||||
|
||||
# Print header if requested
|
||||
if args.header:
|
||||
if args.output_frame in ['icrs', 'fk5', 'fk4']:
|
||||
if args.output_format == 'decimal':
|
||||
print(f"{'RA (deg)':>12s} {'Dec (deg)':>13s}")
|
||||
else:
|
||||
print(f"{'RA':>25s} {'Dec':>26s}")
|
||||
elif args.output_frame == 'galactic':
|
||||
if args.output_format == 'decimal':
|
||||
print(f"{'l (deg)':>12s} {'b (deg)':>13s}")
|
||||
else:
|
||||
print(f"{'l':>25s} {'b':>26s}")
|
||||
|
||||
for line in formatted:
|
||||
print(line)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,250 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Cosmological calculator using astropy.cosmology.
|
||||
|
||||
This script provides quick calculations of cosmological distances,
|
||||
ages, and other quantities for given redshifts.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
from astropy.cosmology import FlatLambdaCDM, Planck18, Planck15, WMAP9
|
||||
import astropy.units as u
|
||||
|
||||
|
||||
def calculate_cosmology(redshifts, cosmology='Planck18', H0=None, Om0=None):
|
||||
"""
|
||||
Calculate cosmological quantities for given redshifts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
redshifts : array-like
|
||||
Redshift values
|
||||
cosmology : str
|
||||
Cosmology to use ('Planck18', 'Planck15', 'WMAP9', 'custom')
|
||||
H0 : float, optional
|
||||
Hubble constant for custom cosmology (km/s/Mpc)
|
||||
Om0 : float, optional
|
||||
Matter density parameter for custom cosmology
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Dictionary containing calculated quantities
|
||||
"""
|
||||
# Select cosmology
|
||||
if cosmology == 'Planck18':
|
||||
cosmo = Planck18
|
||||
elif cosmology == 'Planck15':
|
||||
cosmo = Planck15
|
||||
elif cosmology == 'WMAP9':
|
||||
cosmo = WMAP9
|
||||
elif cosmology == 'custom':
|
||||
if H0 is None or Om0 is None:
|
||||
raise ValueError("Must provide H0 and Om0 for custom cosmology")
|
||||
cosmo = FlatLambdaCDM(H0=H0 * u.km/u.s/u.Mpc, Om0=Om0)
|
||||
else:
|
||||
raise ValueError(f"Unknown cosmology: {cosmology}")
|
||||
|
||||
z = np.atleast_1d(redshifts)
|
||||
|
||||
results = {
|
||||
'redshift': z,
|
||||
'cosmology': str(cosmo),
|
||||
'luminosity_distance': cosmo.luminosity_distance(z),
|
||||
'angular_diameter_distance': cosmo.angular_diameter_distance(z),
|
||||
'comoving_distance': cosmo.comoving_distance(z),
|
||||
'comoving_volume': cosmo.comoving_volume(z),
|
||||
'age': cosmo.age(z),
|
||||
'lookback_time': cosmo.lookback_time(z),
|
||||
'H': cosmo.H(z),
|
||||
'scale_factor': 1.0 / (1.0 + z)
|
||||
}
|
||||
|
||||
return results, cosmo
|
||||
|
||||
|
||||
def print_results(results, verbose=False, csv=False):
|
||||
"""Print calculation results."""
|
||||
|
||||
z = results['redshift']
|
||||
|
||||
if csv:
|
||||
# CSV output
|
||||
print("z,D_L(Mpc),D_A(Mpc),D_C(Mpc),Age(Gyr),t_lookback(Gyr),H(km/s/Mpc)")
|
||||
for i in range(len(z)):
|
||||
print(f"{z[i]:.6f},"
|
||||
f"{results['luminosity_distance'][i].value:.6f},"
|
||||
f"{results['angular_diameter_distance'][i].value:.6f},"
|
||||
f"{results['comoving_distance'][i].value:.6f},"
|
||||
f"{results['age'][i].value:.6f},"
|
||||
f"{results['lookback_time'][i].value:.6f},"
|
||||
f"{results['H'][i].value:.6f}")
|
||||
else:
|
||||
# Formatted table output
|
||||
if verbose:
|
||||
print(f"\nCosmology: {results['cosmology']}")
|
||||
print("-" * 80)
|
||||
|
||||
print(f"\n{'z':>8s} {'D_L':>12s} {'D_A':>12s} {'D_C':>12s} "
|
||||
f"{'Age':>10s} {'t_lb':>10s} {'H(z)':>10s}")
|
||||
print(f"{'':>8s} {'(Mpc)':>12s} {'(Mpc)':>12s} {'(Mpc)':>12s} "
|
||||
f"{'(Gyr)':>10s} {'(Gyr)':>10s} {'(km/s/Mpc)':>10s}")
|
||||
print("-" * 80)
|
||||
|
||||
for i in range(len(z)):
|
||||
print(f"{z[i]:8.4f} "
|
||||
f"{results['luminosity_distance'][i].value:12.3f} "
|
||||
f"{results['angular_diameter_distance'][i].value:12.3f} "
|
||||
f"{results['comoving_distance'][i].value:12.3f} "
|
||||
f"{results['age'][i].value:10.4f} "
|
||||
f"{results['lookback_time'][i].value:10.4f} "
|
||||
f"{results['H'][i].value:10.4f}")
|
||||
|
||||
if verbose:
|
||||
print("\nLegend:")
|
||||
print(" z : Redshift")
|
||||
print(" D_L : Luminosity distance")
|
||||
print(" D_A : Angular diameter distance")
|
||||
print(" D_C : Comoving distance")
|
||||
print(" Age : Age of universe at z")
|
||||
print(" t_lb : Lookback time to z")
|
||||
print(" H(z) : Hubble parameter at z")
|
||||
|
||||
|
||||
def convert_quantity(value, quantity_type, cosmo, to_redshift=False):
|
||||
"""
|
||||
Convert between redshift and cosmological quantity.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : float
|
||||
Value to convert
|
||||
quantity_type : str
|
||||
Type of quantity ('luminosity_distance', 'age', etc.)
|
||||
cosmo : Cosmology
|
||||
Cosmology object
|
||||
to_redshift : bool
|
||||
If True, convert quantity to redshift; else convert z to quantity
|
||||
"""
|
||||
from astropy.cosmology import z_at_value
|
||||
|
||||
if to_redshift:
|
||||
# Convert quantity to redshift
|
||||
if quantity_type == 'luminosity_distance':
|
||||
z = z_at_value(cosmo.luminosity_distance, value * u.Mpc)
|
||||
elif quantity_type == 'age':
|
||||
z = z_at_value(cosmo.age, value * u.Gyr)
|
||||
elif quantity_type == 'lookback_time':
|
||||
z = z_at_value(cosmo.lookback_time, value * u.Gyr)
|
||||
elif quantity_type == 'comoving_distance':
|
||||
z = z_at_value(cosmo.comoving_distance, value * u.Mpc)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantity type: {quantity_type}")
|
||||
return z
|
||||
else:
|
||||
# Convert redshift to quantity
|
||||
if quantity_type == 'luminosity_distance':
|
||||
return cosmo.luminosity_distance(value)
|
||||
elif quantity_type == 'age':
|
||||
return cosmo.age(value)
|
||||
elif quantity_type == 'lookback_time':
|
||||
return cosmo.lookback_time(value)
|
||||
elif quantity_type == 'comoving_distance':
|
||||
return cosmo.comoving_distance(value)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantity type: {quantity_type}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function for command-line usage."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Calculate cosmological quantities for given redshifts',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Available cosmologies: Planck18, Planck15, WMAP9, custom
|
||||
|
||||
Examples:
|
||||
%(prog)s 0.5 1.0 1.5
|
||||
%(prog)s 0.5 --cosmology Planck15
|
||||
%(prog)s 0.5 --cosmology custom --H0 70 --Om0 0.3
|
||||
%(prog)s --range 0 3 0.5
|
||||
%(prog)s 0.5 --verbose
|
||||
%(prog)s 0.5 1.0 --csv
|
||||
%(prog)s --convert 1000 --from luminosity_distance --cosmology Planck18
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument('redshifts', nargs='*', type=float,
|
||||
help='Redshift values to calculate')
|
||||
parser.add_argument('-c', '--cosmology', default='Planck18',
|
||||
choices=['Planck18', 'Planck15', 'WMAP9', 'custom'],
|
||||
help='Cosmology to use (default: Planck18)')
|
||||
parser.add_argument('--H0', type=float,
|
||||
help='Hubble constant for custom cosmology (km/s/Mpc)')
|
||||
parser.add_argument('--Om0', type=float,
|
||||
help='Matter density parameter for custom cosmology')
|
||||
parser.add_argument('-r', '--range', nargs=3, type=float, metavar=('START', 'STOP', 'STEP'),
|
||||
help='Generate redshift range (start stop step)')
|
||||
parser.add_argument('-v', '--verbose', action='store_true',
|
||||
help='Print verbose output with cosmology details')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Output in CSV format')
|
||||
parser.add_argument('--convert', type=float,
|
||||
help='Convert a quantity to redshift')
|
||||
parser.add_argument('--from', dest='from_quantity',
|
||||
choices=['luminosity_distance', 'age', 'lookback_time', 'comoving_distance'],
|
||||
help='Type of quantity to convert from')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle conversion mode
|
||||
if args.convert is not None:
|
||||
if args.from_quantity is None:
|
||||
print("Error: Must specify --from when using --convert", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Get cosmology
|
||||
if args.cosmology == 'Planck18':
|
||||
cosmo = Planck18
|
||||
elif args.cosmology == 'Planck15':
|
||||
cosmo = Planck15
|
||||
elif args.cosmology == 'WMAP9':
|
||||
cosmo = WMAP9
|
||||
elif args.cosmology == 'custom':
|
||||
if args.H0 is None or args.Om0 is None:
|
||||
print("Error: Must provide --H0 and --Om0 for custom cosmology",
|
||||
file=sys.stderr)
|
||||
sys.exit(1)
|
||||
cosmo = FlatLambdaCDM(H0=args.H0 * u.km/u.s/u.Mpc, Om0=args.Om0)
|
||||
|
||||
z = convert_quantity(args.convert, args.from_quantity, cosmo, to_redshift=True)
|
||||
print(f"\n{args.from_quantity.replace('_', ' ').title()} = {args.convert}")
|
||||
print(f"Redshift z = {z:.6f}")
|
||||
print(f"(using {args.cosmology} cosmology)")
|
||||
return
|
||||
|
||||
# Get redshifts
|
||||
if args.range:
|
||||
start, stop, step = args.range
|
||||
redshifts = np.arange(start, stop + step/2, step)
|
||||
elif args.redshifts:
|
||||
redshifts = np.array(args.redshifts)
|
||||
else:
|
||||
print("Error: No redshifts provided.", file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
# Calculate
|
||||
try:
|
||||
results, cosmo = calculate_cosmology(redshifts, args.cosmology,
|
||||
H0=args.H0, Om0=args.Om0)
|
||||
print_results(results, verbose=args.verbose, csv=args.csv)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick FITS file inspection tool.
|
||||
|
||||
This script provides a convenient way to inspect FITS file structure,
|
||||
headers, and basic statistics without writing custom code each time.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from astropy.io import fits
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_fits_info(filename, detailed=False, ext=None):
|
||||
"""
|
||||
Print comprehensive information about a FITS file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
Path to FITS file
|
||||
detailed : bool
|
||||
If True, print detailed statistics for each HDU
|
||||
ext : int or str, optional
|
||||
Specific extension to examine in detail
|
||||
"""
|
||||
print(f"\n{'='*70}")
|
||||
print(f"FITS File: {filename}")
|
||||
print(f"{'='*70}\n")
|
||||
|
||||
try:
|
||||
with fits.open(filename) as hdul:
|
||||
# Print file structure
|
||||
print("File Structure:")
|
||||
print("-" * 70)
|
||||
hdul.info()
|
||||
print()
|
||||
|
||||
# If specific extension requested
|
||||
if ext is not None:
|
||||
print(f"\nDetailed view of extension: {ext}")
|
||||
print("-" * 70)
|
||||
hdu = hdul[ext]
|
||||
print_hdu_details(hdu, detailed=True)
|
||||
return
|
||||
|
||||
# Print header and data info for each HDU
|
||||
for i, hdu in enumerate(hdul):
|
||||
print(f"\n{'='*70}")
|
||||
print(f"HDU {i}: {hdu.name}")
|
||||
print(f"{'='*70}")
|
||||
print_hdu_details(hdu, detailed=detailed)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File '{filename}' not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error reading FITS file: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def print_hdu_details(hdu, detailed=False):
|
||||
"""Print details for a single HDU."""
|
||||
|
||||
# Header information
|
||||
print("\nHeader Information:")
|
||||
print("-" * 70)
|
||||
|
||||
# Key header keywords
|
||||
important_keywords = ['SIMPLE', 'BITPIX', 'NAXIS', 'EXTEND',
|
||||
'OBJECT', 'TELESCOP', 'INSTRUME', 'OBSERVER',
|
||||
'DATE-OBS', 'EXPTIME', 'FILTER', 'AIRMASS',
|
||||
'RA', 'DEC', 'EQUINOX', 'CTYPE1', 'CTYPE2']
|
||||
|
||||
header = hdu.header
|
||||
for key in important_keywords:
|
||||
if key in header:
|
||||
value = header[key]
|
||||
comment = header.comments[key]
|
||||
print(f" {key:12s} = {str(value):20s} / {comment}")
|
||||
|
||||
# NAXIS keywords
|
||||
if 'NAXIS' in header:
|
||||
naxis = header['NAXIS']
|
||||
for i in range(1, naxis + 1):
|
||||
key = f'NAXIS{i}'
|
||||
if key in header:
|
||||
print(f" {key:12s} = {str(header[key]):20s} / {header.comments[key]}")
|
||||
|
||||
# Data information
|
||||
if hdu.data is not None:
|
||||
print("\nData Information:")
|
||||
print("-" * 70)
|
||||
|
||||
data = hdu.data
|
||||
print(f" Data type: {data.dtype}")
|
||||
print(f" Shape: {data.shape}")
|
||||
|
||||
# For image data
|
||||
if hasattr(data, 'ndim') and data.ndim >= 1:
|
||||
try:
|
||||
# Calculate statistics
|
||||
finite_data = data[np.isfinite(data)]
|
||||
if len(finite_data) > 0:
|
||||
print(f" Min: {np.min(finite_data):.6g}")
|
||||
print(f" Max: {np.max(finite_data):.6g}")
|
||||
print(f" Mean: {np.mean(finite_data):.6g}")
|
||||
print(f" Median: {np.median(finite_data):.6g}")
|
||||
print(f" Std: {np.std(finite_data):.6g}")
|
||||
|
||||
# Count special values
|
||||
n_nan = np.sum(np.isnan(data))
|
||||
n_inf = np.sum(np.isinf(data))
|
||||
if n_nan > 0:
|
||||
print(f" NaN values: {n_nan}")
|
||||
if n_inf > 0:
|
||||
print(f" Inf values: {n_inf}")
|
||||
except Exception as e:
|
||||
print(f" Could not calculate statistics: {e}")
|
||||
|
||||
# For table data
|
||||
if hasattr(data, 'columns'):
|
||||
print(f"\n Table Columns ({len(data.columns)}):")
|
||||
for col in data.columns:
|
||||
print(f" {col.name:20s} {col.format:10s} {col.unit or ''}")
|
||||
|
||||
if detailed:
|
||||
print(f"\n First few rows:")
|
||||
print(data[:min(5, len(data))])
|
||||
else:
|
||||
print("\n No data in this HDU")
|
||||
|
||||
# WCS information if present
|
||||
try:
|
||||
from astropy.wcs import WCS
|
||||
wcs = WCS(hdu.header)
|
||||
if wcs.has_celestial:
|
||||
print("\nWCS Information:")
|
||||
print("-" * 70)
|
||||
print(f" Has celestial WCS: Yes")
|
||||
print(f" CTYPE: {wcs.wcs.ctype}")
|
||||
if wcs.wcs.crval is not None:
|
||||
print(f" CRVAL: {wcs.wcs.crval}")
|
||||
if wcs.wcs.crpix is not None:
|
||||
print(f" CRPIX: {wcs.wcs.crpix}")
|
||||
if wcs.wcs.cdelt is not None:
|
||||
print(f" CDELT: {wcs.wcs.cdelt}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function for command-line usage."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Inspect FITS file structure and contents',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
%(prog)s image.fits
|
||||
%(prog)s image.fits --detailed
|
||||
%(prog)s image.fits --ext 1
|
||||
%(prog)s image.fits --ext SCI
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument('filename', help='FITS file to inspect')
|
||||
parser.add_argument('-d', '--detailed', action='store_true',
|
||||
help='Show detailed statistics for each HDU')
|
||||
parser.add_argument('-e', '--ext', type=str, default=None,
|
||||
help='Show details for specific extension only (number or name)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert extension to int if numeric
|
||||
ext = args.ext
|
||||
if ext is not None:
|
||||
try:
|
||||
ext = int(ext)
|
||||
except ValueError:
|
||||
pass # Keep as string for extension name
|
||||
|
||||
print_fits_info(args.filename, detailed=args.detailed, ext=ext)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,530 +0,0 @@
|
||||
---
|
||||
name: pyopenms
|
||||
description: "Mass spectrometry toolkit (OpenMS Python). Process mzML/mzXML, peak picking, feature detection, peptide ID, proteomics/metabolomics workflows, for LC-MS/MS analysis."
|
||||
---
|
||||
|
||||
# pyOpenMS
|
||||
|
||||
## Overview
|
||||
|
||||
pyOpenMS is an open-source Python library for mass spectrometry data analysis in proteomics and metabolomics. Process LC-MS/MS data, perform peptide identification, detect and quantify features, and integrate with common proteomics tools (Comet, Mascot, MSGF+, Percolator, MSstats) using Python bindings to the OpenMS C++ library.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Processing mass spectrometry data (mzML, mzXML files)
|
||||
- Performing peak picking and feature detection in LC-MS data
|
||||
- Conducting peptide and protein identification workflows
|
||||
- Quantifying metabolites or proteins
|
||||
- Integrating proteomics or metabolomics tools into Python pipelines
|
||||
- Working with OpenMS tools and file formats
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. File I/O and Data Import/Export
|
||||
|
||||
Handle diverse mass spectrometry file formats efficiently:
|
||||
|
||||
**Supported Formats:**
|
||||
- **mzML/mzXML**: Primary raw MS data formats (profile or centroid)
|
||||
- **FASTA**: Protein/peptide sequence databases
|
||||
- **mzTab**: Standardized reporting format for identification and quantification
|
||||
- **mzIdentML**: Peptide and protein identification data
|
||||
- **TraML**: Transition lists for targeted experiments
|
||||
- **pepXML/protXML**: Search engine results
|
||||
|
||||
**Reading mzML Files:**
|
||||
```python
|
||||
import pyopenms as oms
|
||||
|
||||
# Load MS data
|
||||
exp = oms.MSExperiment()
|
||||
oms.MzMLFile().load("input_data.mzML", exp)
|
||||
|
||||
# Access basic information
|
||||
print(f"Number of spectra: {exp.getNrSpectra()}")
|
||||
print(f"Number of chromatograms: {exp.getNrChromatograms()}")
|
||||
```
|
||||
|
||||
**Writing mzML Files:**
|
||||
```python
|
||||
# Save processed data
|
||||
oms.MzMLFile().store("output_data.mzML", exp)
|
||||
```
|
||||
|
||||
**File Encoding:** pyOpenMS automatically handles Base64 encoding, zlib compression, and Numpress compression internally.
|
||||
|
||||
### 2. MS Data Structures and Manipulation
|
||||
|
||||
Work with core mass spectrometry data structures. See `references/data_structures.md` for comprehensive details.
|
||||
|
||||
**MSSpectrum** - Individual mass spectrum:
|
||||
```python
|
||||
# Create spectrum with metadata
|
||||
spectrum = oms.MSSpectrum()
|
||||
spectrum.setRT(205.2) # Retention time in seconds
|
||||
spectrum.setMSLevel(2) # MS2 spectrum
|
||||
|
||||
# Set peak data (m/z, intensity arrays)
|
||||
mz_array = [100.5, 200.3, 300.7, 400.2]
|
||||
intensity_array = [1000, 5000, 3000, 2000]
|
||||
spectrum.set_peaks((mz_array, intensity_array))
|
||||
|
||||
# Add precursor information for MS2
|
||||
precursor = oms.Precursor()
|
||||
precursor.setMZ(450.5)
|
||||
precursor.setCharge(2)
|
||||
spectrum.setPrecursors([precursor])
|
||||
```
|
||||
|
||||
**MSExperiment** - Complete LC-MS/MS run:
|
||||
```python
|
||||
# Create experiment and add spectra
|
||||
exp = oms.MSExperiment()
|
||||
exp.addSpectrum(spectrum)
|
||||
|
||||
# Access spectra
|
||||
first_spectrum = exp.getSpectrum(0)
|
||||
for spec in exp:
|
||||
print(f"RT: {spec.getRT()}, MS Level: {spec.getMSLevel()}")
|
||||
```
|
||||
|
||||
**MSChromatogram** - Extracted ion chromatogram:
|
||||
```python
|
||||
# Create chromatogram
|
||||
chrom = oms.MSChromatogram()
|
||||
chrom.set_peaks(([10.5, 11.2, 11.8], [1000, 5000, 3000])) # RT, intensity
|
||||
exp.addChromatogram(chrom)
|
||||
```
|
||||
|
||||
**Efficient Peak Access:**
|
||||
```python
|
||||
# Get peaks as numpy arrays for fast processing
|
||||
mz_array, intensity_array = spectrum.get_peaks()
|
||||
|
||||
# Modify and set back
|
||||
intensity_array *= 2 # Double all intensities
|
||||
spectrum.set_peaks((mz_array, intensity_array))
|
||||
```
|
||||
|
||||
### 3. Chemistry and Peptide Handling
|
||||
|
||||
Perform chemical calculations for proteomics and metabolomics. See `references/chemistry.md` for detailed examples.
|
||||
|
||||
**Molecular Formulas and Mass Calculations:**
|
||||
```python
|
||||
# Create empirical formula
|
||||
formula = oms.EmpiricalFormula("C6H12O6") # Glucose
|
||||
print(f"Monoisotopic mass: {formula.getMonoWeight()}")
|
||||
print(f"Average mass: {formula.getAverageWeight()}")
|
||||
|
||||
# Formula arithmetic
|
||||
water = oms.EmpiricalFormula("H2O")
|
||||
dehydrated = formula - water
|
||||
|
||||
# Isotope-specific formulas
|
||||
heavy_carbon = oms.EmpiricalFormula("(13)C6H12O6")
|
||||
```
|
||||
|
||||
**Isotopic Distributions:**
|
||||
```python
|
||||
# Generate coarse isotope pattern (unit mass resolution)
|
||||
coarse_gen = oms.CoarseIsotopePatternGenerator()
|
||||
pattern = coarse_gen.run(formula)
|
||||
|
||||
# Generate fine structure (high resolution)
|
||||
fine_gen = oms.FineIsotopePatternGenerator(0.01) # 0.01 Da resolution
|
||||
fine_pattern = fine_gen.run(formula)
|
||||
```
|
||||
|
||||
**Amino Acids and Residues:**
|
||||
```python
|
||||
# Access residue information
|
||||
res_db = oms.ResidueDB()
|
||||
leucine = res_db.getResidue("Leucine")
|
||||
print(f"L monoisotopic mass: {leucine.getMonoWeight()}")
|
||||
print(f"L formula: {leucine.getFormula()}")
|
||||
print(f"L pKa: {leucine.getPka()}")
|
||||
```
|
||||
|
||||
**Peptide Sequences:**
|
||||
```python
|
||||
# Create peptide sequence
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
print(f"Peptide mass: {peptide.getMonoWeight()}")
|
||||
print(f"Formula: {peptide.getFormula()}")
|
||||
|
||||
# Add modifications
|
||||
modified = oms.AASequence.fromString("PEPTIDEM(Oxidation)")
|
||||
print(f"Modified mass: {modified.getMonoWeight()}")
|
||||
|
||||
# Theoretical fragmentation
|
||||
ions = []
|
||||
for i in range(1, peptide.size()):
|
||||
b_ion = peptide.getPrefix(i)
|
||||
y_ion = peptide.getSuffix(i)
|
||||
ions.append(('b', i, b_ion.getMonoWeight()))
|
||||
ions.append(('y', i, y_ion.getMonoWeight()))
|
||||
```
|
||||
|
||||
**Protein Digestion:**
|
||||
```python
|
||||
# Enzymatic digestion
|
||||
dig = oms.ProteaseDigestion()
|
||||
dig.setEnzyme("Trypsin")
|
||||
dig.setMissedCleavages(2)
|
||||
|
||||
protein_seq = oms.AASequence.fromString("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK")
|
||||
peptides = []
|
||||
dig.digest(protein_seq, peptides)
|
||||
|
||||
for pep in peptides:
|
||||
print(f"{pep.toString()}: {pep.getMonoWeight():.2f} Da")
|
||||
```
|
||||
|
||||
**Modifications:**
|
||||
```python
|
||||
# Access modification database
|
||||
mod_db = oms.ModificationsDB()
|
||||
oxidation = mod_db.getModification("Oxidation")
|
||||
print(f"Oxidation mass diff: {oxidation.getDiffMonoMass()}")
|
||||
print(f"Residues: {oxidation.getResidues()}")
|
||||
```
|
||||
|
||||
### 4. Signal Processing and Filtering
|
||||
|
||||
Apply algorithms to process and filter MS data. See `references/algorithms.md` for comprehensive coverage.
|
||||
|
||||
**Spectral Smoothing:**
|
||||
```python
|
||||
# Gaussian smoothing
|
||||
gauss_filter = oms.GaussFilter()
|
||||
params = gauss_filter.getParameters()
|
||||
params.setValue("gaussian_width", 0.2)
|
||||
gauss_filter.setParameters(params)
|
||||
gauss_filter.filterExperiment(exp)
|
||||
|
||||
# Savitzky-Golay filter
|
||||
sg_filter = oms.SavitzkyGolayFilter()
|
||||
sg_filter.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Peak Filtering:**
|
||||
```python
|
||||
# Keep only N largest peaks per spectrum
|
||||
n_largest = oms.NLargest()
|
||||
params = n_largest.getParameters()
|
||||
params.setValue("n", 100) # Keep top 100 peaks
|
||||
n_largest.setParameters(params)
|
||||
n_largest.filterExperiment(exp)
|
||||
|
||||
# Threshold filtering
|
||||
threshold_filter = oms.ThresholdMower()
|
||||
params = threshold_filter.getParameters()
|
||||
params.setValue("threshold", 1000.0) # Remove peaks below 1000 intensity
|
||||
threshold_filter.setParameters(params)
|
||||
threshold_filter.filterExperiment(exp)
|
||||
|
||||
# Window-based filtering
|
||||
window_filter = oms.WindowMower()
|
||||
params = window_filter.getParameters()
|
||||
params.setValue("windowsize", 50.0) # 50 m/z windows
|
||||
params.setValue("peakcount", 10) # Keep 10 highest per window
|
||||
window_filter.setParameters(params)
|
||||
window_filter.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Spectrum Normalization:**
|
||||
```python
|
||||
normalizer = oms.Normalizer()
|
||||
normalizer.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**MS Level Filtering:**
|
||||
```python
|
||||
# Keep only MS2 spectra
|
||||
exp.filterMSLevel(2)
|
||||
|
||||
# Filter by retention time range
|
||||
exp.filterRT(100.0, 500.0) # Keep RT between 100-500 seconds
|
||||
|
||||
# Filter by m/z range
|
||||
exp.filterMZ(400.0, 1500.0) # Keep m/z between 400-1500
|
||||
```
|
||||
|
||||
### 5. Feature Detection and Quantification
|
||||
|
||||
Detect and quantify features in LC-MS data:
|
||||
|
||||
**Peak Picking (Centroiding):**
|
||||
```python
|
||||
# Convert profile data to centroid
|
||||
picker = oms.PeakPickerHiRes()
|
||||
params = picker.getParameters()
|
||||
params.setValue("signal_to_noise", 1.0)
|
||||
picker.setParameters(params)
|
||||
|
||||
exp_centroided = oms.MSExperiment()
|
||||
picker.pickExperiment(exp, exp_centroided)
|
||||
```
|
||||
|
||||
**Feature Detection:**
|
||||
```python
|
||||
# Detect features across LC-MS runs
|
||||
feature_finder = oms.FeatureFinderMultiplex()
|
||||
|
||||
features = oms.FeatureMap()
|
||||
feature_finder.run(exp, features, params)
|
||||
|
||||
print(f"Found {features.size()} features")
|
||||
for feature in features:
|
||||
print(f"m/z: {feature.getMZ():.4f}, RT: {feature.getRT():.2f}, "
|
||||
f"Intensity: {feature.getIntensity():.0f}")
|
||||
```
|
||||
|
||||
**Feature Linking (Map Alignment):**
|
||||
```python
|
||||
# Link features across multiple samples
|
||||
feature_grouper = oms.FeatureGroupingAlgorithmQT()
|
||||
consensus_map = oms.ConsensusMap()
|
||||
|
||||
# Provide multiple feature maps from different samples
|
||||
feature_maps = [features1, features2, features3]
|
||||
feature_grouper.group(feature_maps, consensus_map)
|
||||
```
|
||||
|
||||
### 6. Peptide Identification Workflows
|
||||
|
||||
Integrate with search engines and process identification results:
|
||||
|
||||
**Database Searching:**
|
||||
```python
|
||||
# Prepare parameters for search engine
|
||||
params = oms.Param()
|
||||
params.setValue("database", "uniprot_human.fasta")
|
||||
params.setValue("precursor_mass_tolerance", 10.0) # ppm
|
||||
params.setValue("fragment_mass_tolerance", 0.5) # Da
|
||||
params.setValue("enzyme", "Trypsin")
|
||||
params.setValue("missed_cleavages", 2)
|
||||
|
||||
# Variable modifications
|
||||
params.setValue("variable_modifications", ["Oxidation (M)", "Phospho (STY)"])
|
||||
|
||||
# Fixed modifications
|
||||
params.setValue("fixed_modifications", ["Carbamidomethyl (C)"])
|
||||
```
|
||||
|
||||
**FDR Control:**
|
||||
```python
|
||||
# False discovery rate estimation
|
||||
fdr = oms.FalseDiscoveryRate()
|
||||
fdr_threshold = 0.01 # 1% FDR
|
||||
|
||||
# Apply to peptide identifications
|
||||
protein_ids = []
|
||||
peptide_ids = []
|
||||
oms.IdXMLFile().load("search_results.idXML", protein_ids, peptide_ids)
|
||||
|
||||
fdr.apply(protein_ids, peptide_ids)
|
||||
```
|
||||
|
||||
### 7. Metabolomics Workflows
|
||||
|
||||
Analyze small molecule data:
|
||||
|
||||
**Adduct Detection:**
|
||||
```python
|
||||
# Common metabolite adducts
|
||||
adducts = ["[M+H]+", "[M+Na]+", "[M+K]+", "[M-H]-", "[M+Cl]-"]
|
||||
|
||||
# Feature annotation with adducts
|
||||
for feature in features:
|
||||
mz = feature.getMZ()
|
||||
# Calculate neutral mass for each adduct hypothesis
|
||||
for adduct in adducts:
|
||||
# Annotation logic
|
||||
pass
|
||||
```
|
||||
|
||||
**Isotope Pattern Matching:**
|
||||
```python
|
||||
# Compare experimental to theoretical isotope patterns
|
||||
experimental_pattern = [] # Extract from feature
|
||||
theoretical = coarse_gen.run(formula)
|
||||
|
||||
# Calculate similarity score
|
||||
similarity = compare_isotope_patterns(experimental_pattern, theoretical)
|
||||
```
|
||||
|
||||
### 8. Quality Control and Visualization
|
||||
|
||||
Monitor data quality and visualize results:
|
||||
|
||||
**Basic Statistics:**
|
||||
```python
|
||||
# Calculate TIC (Total Ion Current)
|
||||
tic_values = []
|
||||
rt_values = []
|
||||
for spectrum in exp:
|
||||
if spectrum.getMSLevel() == 1:
|
||||
tic = sum(spectrum.get_peaks()[1]) # Sum intensities
|
||||
tic_values.append(tic)
|
||||
rt_values.append(spectrum.getRT())
|
||||
|
||||
# Base peak chromatogram
|
||||
bpc_values = []
|
||||
for spectrum in exp:
|
||||
if spectrum.getMSLevel() == 1:
|
||||
max_intensity = max(spectrum.get_peaks()[1]) if spectrum.size() > 0 else 0
|
||||
bpc_values.append(max_intensity)
|
||||
```
|
||||
|
||||
**Plotting (with pyopenms.plotting or matplotlib):**
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Plot TIC
|
||||
plt.figure(figsize=(10, 4))
|
||||
plt.plot(rt_values, tic_values)
|
||||
plt.xlabel('Retention Time (s)')
|
||||
plt.ylabel('Total Ion Current')
|
||||
plt.title('TIC')
|
||||
plt.show()
|
||||
|
||||
# Plot single spectrum
|
||||
spectrum = exp.getSpectrum(0)
|
||||
mz, intensity = spectrum.get_peaks()
|
||||
plt.stem(mz, intensity, basefmt=' ')
|
||||
plt.xlabel('m/z')
|
||||
plt.ylabel('Intensity')
|
||||
plt.title(f'Spectrum at RT {spectrum.getRT():.2f}s')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Complete LC-MS/MS Processing Pipeline
|
||||
|
||||
```python
|
||||
import pyopenms as oms
|
||||
|
||||
# 1. Load data
|
||||
exp = oms.MSExperiment()
|
||||
oms.MzMLFile().load("raw_data.mzML", exp)
|
||||
|
||||
# 2. Filter and smooth
|
||||
exp.filterMSLevel(1) # Keep only MS1 for feature detection
|
||||
gauss = oms.GaussFilter()
|
||||
gauss.filterExperiment(exp)
|
||||
|
||||
# 3. Peak picking
|
||||
picker = oms.PeakPickerHiRes()
|
||||
exp_centroid = oms.MSExperiment()
|
||||
picker.pickExperiment(exp, exp_centroid)
|
||||
|
||||
# 4. Feature detection
|
||||
ff = oms.FeatureFinderMultiplex()
|
||||
features = oms.FeatureMap()
|
||||
ff.run(exp_centroid, features, oms.Param())
|
||||
|
||||
# 5. Export results
|
||||
oms.FeatureXMLFile().store("features.featureXML", features)
|
||||
print(f"Detected {features.size()} features")
|
||||
```
|
||||
|
||||
### Theoretical Peptide Mass Calculation
|
||||
|
||||
```python
|
||||
# Calculate masses for peptide with modifications
|
||||
peptide = oms.AASequence.fromString("PEPTIDEK")
|
||||
print(f"Unmodified [M+H]+: {peptide.getMonoWeight() + 1.007276:.4f}")
|
||||
|
||||
# With modification
|
||||
modified = oms.AASequence.fromString("PEPTIDEM(Oxidation)K")
|
||||
print(f"Oxidized [M+H]+: {modified.getMonoWeight() + 1.007276:.4f}")
|
||||
|
||||
# Calculate for different charge states
|
||||
for z in [1, 2, 3]:
|
||||
mz = (peptide.getMonoWeight() + z * 1.007276) / z
|
||||
print(f"[M+{z}H]^{z}+: {mz:.4f}")
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
Ensure pyOpenMS is installed before using this skill:
|
||||
|
||||
```bash
|
||||
# Via conda (recommended)
|
||||
conda install -c bioconda pyopenms
|
||||
|
||||
# Via pip
|
||||
pip install pyopenms
|
||||
```
|
||||
|
||||
## Integration with Other Tools
|
||||
|
||||
pyOpenMS integrates seamlessly with:
|
||||
|
||||
- **Search Engines**: Comet, Mascot, MSGF+, MSFragger, Sage, SpectraST
|
||||
- **Post-processing**: Percolator, MSstats, Epiphany
|
||||
- **Metabolomics**: SIRIUS, CSI:FingerID
|
||||
- **Data Analysis**: Pandas, NumPy, SciPy for downstream analysis
|
||||
- **Visualization**: Matplotlib, Seaborn for plotting
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
|
||||
Detailed documentation on core concepts:
|
||||
|
||||
- **data_structures.md** - Comprehensive guide to MSExperiment, MSSpectrum, MSChromatogram, and peak data handling
|
||||
- **algorithms.md** - Complete reference for signal processing, filtering, feature detection, and quantification algorithms
|
||||
- **chemistry.md** - In-depth coverage of chemistry calculations, peptide handling, modifications, and isotope distributions
|
||||
|
||||
Load these references when needing detailed information about specific pyOpenMS capabilities.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **File Format**: Always use mzML for raw MS data (standardized, well-supported)
|
||||
2. **Peak Access**: Use `get_peaks()` and `set_peaks()` with numpy arrays for efficient processing
|
||||
3. **Parameters**: Always check and configure algorithm parameters via `getParameters()` and `setParameters()`
|
||||
4. **Memory**: For large datasets, process spectra iteratively rather than loading entire experiments
|
||||
5. **Validation**: Check data integrity (MS levels, RT ordering, precursor information) after loading
|
||||
6. **Modifications**: Use standard modification names from UniMod database
|
||||
7. **Units**: RT in seconds, m/z in Thomson (Da/charge), intensity in arbitrary units
|
||||
|
||||
## Common Patterns
|
||||
|
||||
**Algorithm Application Pattern:**
|
||||
```python
|
||||
# 1. Instantiate algorithm
|
||||
algorithm = oms.SomeAlgorithm()
|
||||
|
||||
# 2. Get and configure parameters
|
||||
params = algorithm.getParameters()
|
||||
params.setValue("parameter_name", value)
|
||||
algorithm.setParameters(params)
|
||||
|
||||
# 3. Apply to data
|
||||
algorithm.filterExperiment(exp) # or .process(), .run(), depending on algorithm
|
||||
```
|
||||
|
||||
**File I/O Pattern:**
|
||||
```python
|
||||
# Read
|
||||
data_container = oms.DataContainer() # MSExperiment, FeatureMap, etc.
|
||||
oms.FileHandler().load("input.format", data_container)
|
||||
|
||||
# Process
|
||||
# ... manipulate data_container ...
|
||||
|
||||
# Write
|
||||
oms.FileHandler().store("output.format", data_container)
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
- **Documentation**: https://pyopenms.readthedocs.io/
|
||||
- **API Reference**: Browse class documentation for detailed method signatures
|
||||
- **OpenMS Website**: https://www.openms.org/
|
||||
- **GitHub Issues**: https://github.com/OpenMS/OpenMS/issues
|
||||
@@ -1,643 +0,0 @@
|
||||
# pyOpenMS Algorithms Reference
|
||||
|
||||
This document provides comprehensive coverage of algorithms available in pyOpenMS for signal processing, feature detection, and quantification.
|
||||
|
||||
## Algorithm Usage Pattern
|
||||
|
||||
Most pyOpenMS algorithms follow a consistent pattern:
|
||||
|
||||
```python
|
||||
import pyopenms as oms
|
||||
|
||||
# 1. Instantiate algorithm
|
||||
algorithm = oms.AlgorithmName()
|
||||
|
||||
# 2. Get parameters
|
||||
params = algorithm.getParameters()
|
||||
|
||||
# 3. Modify parameters
|
||||
params.setValue("parameter_name", value)
|
||||
|
||||
# 4. Set parameters back
|
||||
algorithm.setParameters(params)
|
||||
|
||||
# 5. Apply to data
|
||||
algorithm.filterExperiment(exp) # or .process(), .run(), etc.
|
||||
```
|
||||
|
||||
## Signal Processing Algorithms
|
||||
|
||||
### Smoothing Filters
|
||||
|
||||
#### GaussFilter - Gaussian Smoothing
|
||||
|
||||
Applies Gaussian smoothing to reduce noise.
|
||||
|
||||
```python
|
||||
gauss = oms.GaussFilter()
|
||||
|
||||
# Configure parameters
|
||||
params = gauss.getParameters()
|
||||
params.setValue("gaussian_width", 0.2) # Gaussian width (larger = more smoothing)
|
||||
params.setValue("ppm_tolerance", 10.0) # PPM tolerance for spacing
|
||||
params.setValue("use_ppm_tolerance", "true")
|
||||
gauss.setParameters(params)
|
||||
|
||||
# Apply to experiment
|
||||
gauss.filterExperiment(exp)
|
||||
|
||||
# Or apply to single spectrum
|
||||
spectrum_smoothed = oms.MSSpectrum()
|
||||
gauss.filter(spectrum, spectrum_smoothed)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `gaussian_width`: Width of Gaussian kernel (default: 0.2 Da)
|
||||
- `ppm_tolerance`: Tolerance in ppm for spacing
|
||||
- `use_ppm_tolerance`: Whether to use ppm instead of absolute spacing
|
||||
|
||||
#### SavitzkyGolayFilter
|
||||
|
||||
Applies Savitzky-Golay smoothing (polynomial fitting).
|
||||
|
||||
```python
|
||||
sg_filter = oms.SavitzkyGolayFilter()
|
||||
|
||||
params = sg_filter.getParameters()
|
||||
params.setValue("frame_length", 11) # Window size (must be odd)
|
||||
params.setValue("polynomial_order", 3) # Polynomial degree
|
||||
sg_filter.setParameters(params)
|
||||
|
||||
sg_filter.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `frame_length`: Size of smoothing window (must be odd)
|
||||
- `polynomial_order`: Degree of polynomial (typically 2-4)
|
||||
|
||||
### Peak Filtering
|
||||
|
||||
#### NLargest - Keep Top N Peaks
|
||||
|
||||
Retains only the N most intense peaks per spectrum.
|
||||
|
||||
```python
|
||||
n_largest = oms.NLargest()
|
||||
|
||||
params = n_largest.getParameters()
|
||||
params.setValue("n", 100) # Keep top 100 peaks
|
||||
params.setValue("threshold", 0.0) # Optional minimum intensity
|
||||
n_largest.setParameters(params)
|
||||
|
||||
n_largest.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `n`: Number of peaks to keep per spectrum
|
||||
- `threshold`: Minimum absolute intensity threshold
|
||||
|
||||
#### ThresholdMower - Intensity Threshold Filtering
|
||||
|
||||
Removes peaks below a specified intensity threshold.
|
||||
|
||||
```python
|
||||
threshold_filter = oms.ThresholdMower()
|
||||
|
||||
params = threshold_filter.getParameters()
|
||||
params.setValue("threshold", 1000.0) # Absolute intensity threshold
|
||||
threshold_filter.setParameters(params)
|
||||
|
||||
threshold_filter.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `threshold`: Absolute intensity cutoff
|
||||
|
||||
#### WindowMower - Window-Based Peak Selection
|
||||
|
||||
Divides m/z range into windows and keeps top N peaks per window.
|
||||
|
||||
```python
|
||||
window_mower = oms.WindowMower()
|
||||
|
||||
params = window_mower.getParameters()
|
||||
params.setValue("windowsize", 50.0) # Window size in Da (or Thomson)
|
||||
params.setValue("peakcount", 10) # Peaks to keep per window
|
||||
params.setValue("movetype", "jump") # "jump" or "slide"
|
||||
window_mower.setParameters(params)
|
||||
|
||||
window_mower.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `windowsize`: Size of m/z window (Da)
|
||||
- `peakcount`: Number of peaks to retain per window
|
||||
- `movetype`: "jump" (non-overlapping) or "slide" (overlapping windows)
|
||||
|
||||
#### BernNorm - Bernoulli Normalization
|
||||
|
||||
Statistical normalization based on Bernoulli distribution.
|
||||
|
||||
```python
|
||||
bern_norm = oms.BernNorm()
|
||||
|
||||
params = bern_norm.getParameters()
|
||||
params.setValue("threshold", 0.7) # Threshold for normalization
|
||||
bern_norm.setParameters(params)
|
||||
|
||||
bern_norm.filterExperiment(exp)
|
||||
```
|
||||
|
||||
### Spectrum Normalization
|
||||
|
||||
#### Normalizer
|
||||
|
||||
Normalizes spectrum intensities to unit total intensity or maximum intensity.
|
||||
|
||||
```python
|
||||
normalizer = oms.Normalizer()
|
||||
|
||||
params = normalizer.getParameters()
|
||||
params.setValue("method", "to_one") # "to_one" or "to_TIC"
|
||||
normalizer.setParameters(params)
|
||||
|
||||
normalizer.filterExperiment(exp)
|
||||
```
|
||||
|
||||
**Methods:**
|
||||
- `to_one`: Normalize max peak to 1.0
|
||||
- `to_TIC`: Normalize to total ion current = 1.0
|
||||
|
||||
#### Scaler
|
||||
|
||||
Scales intensities by a constant factor.
|
||||
|
||||
```python
|
||||
scaler = oms.Scaler()
|
||||
|
||||
params = scaler.getParameters()
|
||||
params.setValue("scaling", 1000.0) # Scaling factor
|
||||
scaler.setParameters(params)
|
||||
|
||||
scaler.filterExperiment(exp)
|
||||
```
|
||||
|
||||
## Centroiding and Peak Picking
|
||||
|
||||
### PeakPickerHiRes - High-Resolution Peak Picking
|
||||
|
||||
Converts profile spectra to centroid mode for high-resolution data.
|
||||
|
||||
```python
|
||||
picker = oms.PeakPickerHiRes()
|
||||
|
||||
params = picker.getParameters()
|
||||
params.setValue("signal_to_noise", 1.0) # S/N threshold
|
||||
params.setValue("spacing_difference", 1.5) # Peak spacing factor
|
||||
params.setValue("sn_win_len", 20.0) # S/N window length
|
||||
params.setValue("sn_bin_count", 30) # Bins for S/N estimation
|
||||
params.setValue("ms1_only", "false") # Process only MS1
|
||||
params.setValue("ms_levels", [1, 2]) # MS levels to process
|
||||
picker.setParameters(params)
|
||||
|
||||
# Pick peaks
|
||||
exp_centroided = oms.MSExperiment()
|
||||
picker.pickExperiment(exp, exp_centroided)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `signal_to_noise`: Minimum signal-to-noise ratio
|
||||
- `spacing_difference`: Minimum spacing between peaks
|
||||
- `ms_levels`: List of MS levels to process
|
||||
|
||||
### PeakPickerWavelet - Wavelet-Based Peak Picking
|
||||
|
||||
Uses continuous wavelet transform for peak detection.
|
||||
|
||||
```python
|
||||
wavelet_picker = oms.PeakPickerWavelet()
|
||||
|
||||
params = wavelet_picker.getParameters()
|
||||
params.setValue("signal_to_noise", 1.0)
|
||||
params.setValue("peak_width", 0.15) # Expected peak width (Da)
|
||||
wavelet_picker.setParameters(params)
|
||||
|
||||
wavelet_picker.pickExperiment(exp, exp_centroided)
|
||||
```
|
||||
|
||||
## Feature Detection
|
||||
|
||||
### FeatureFinder Algorithms
|
||||
|
||||
Feature finders detect 2D features (m/z and RT) in LC-MS data.
|
||||
|
||||
#### FeatureFinderMultiplex
|
||||
|
||||
For multiplex labeling experiments (SILAC, dimethyl labeling).
|
||||
|
||||
```python
|
||||
ff = oms.FeatureFinderMultiplex()
|
||||
|
||||
params = ff.getParameters()
|
||||
params.setValue("algorithm:labels", "[]") # Empty for label-free
|
||||
params.setValue("algorithm:charge", "2:4") # Charge range
|
||||
params.setValue("algorithm:rt_typical", 40.0) # Expected feature RT width
|
||||
params.setValue("algorithm:rt_min", 2.0) # Minimum RT width
|
||||
params.setValue("algorithm:mz_tolerance", 10.0) # m/z tolerance (ppm)
|
||||
params.setValue("algorithm:intensity_cutoff", 1000.0) # Minimum intensity
|
||||
ff.setParameters(params)
|
||||
|
||||
# Run feature detection
|
||||
features = oms.FeatureMap()
|
||||
ff.run(exp, features, oms.Param())
|
||||
|
||||
print(f"Found {features.size()} features")
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `algorithm:charge`: Charge state range to consider
|
||||
- `algorithm:rt_typical`: Expected peak width in RT dimension
|
||||
- `algorithm:mz_tolerance`: Mass tolerance in ppm
|
||||
- `algorithm:intensity_cutoff`: Minimum intensity threshold
|
||||
|
||||
#### FeatureFinderCentroided
|
||||
|
||||
For centroided data, identifies isotope patterns and traces over RT.
|
||||
|
||||
```python
|
||||
ff_centroided = oms.FeatureFinderCentroided()
|
||||
|
||||
params = ff_centroided.getParameters()
|
||||
params.setValue("mass_trace:mz_tolerance", 10.0) # ppm
|
||||
params.setValue("mass_trace:min_spectra", 5) # Min consecutive spectra
|
||||
params.setValue("isotopic_pattern:charge_low", 1)
|
||||
params.setValue("isotopic_pattern:charge_high", 4)
|
||||
params.setValue("seed:min_score", 0.5)
|
||||
ff_centroided.setParameters(params)
|
||||
|
||||
features = oms.FeatureMap()
|
||||
seeds = oms.FeatureMap() # Optional seed features
|
||||
ff_centroided.run(exp, features, params, seeds)
|
||||
```
|
||||
|
||||
#### FeatureFinderIdentification
|
||||
|
||||
Uses peptide identifications to guide feature detection.
|
||||
|
||||
```python
|
||||
ff_id = oms.FeatureFinderIdentification()
|
||||
|
||||
params = ff_id.getParameters()
|
||||
params.setValue("extract:mz_window", 10.0) # ppm
|
||||
params.setValue("extract:rt_window", 60.0) # seconds
|
||||
params.setValue("detect:peak_width", 30.0) # Expected peak width
|
||||
ff_id.setParameters(params)
|
||||
|
||||
# Requires peptide identifications
|
||||
protein_ids = []
|
||||
peptide_ids = []
|
||||
features = oms.FeatureMap()
|
||||
|
||||
ff_id.run(exp, protein_ids, peptide_ids, features)
|
||||
```
|
||||
|
||||
## Charge and Isotope Deconvolution
|
||||
|
||||
### Decharging and Charge State Deconvolution
|
||||
|
||||
#### FeatureDeconvolution
|
||||
|
||||
Resolves charge states and combines features.
|
||||
|
||||
```python
|
||||
deconv = oms.FeatureDeconvolution()
|
||||
|
||||
params = deconv.getParameters()
|
||||
params.setValue("charge_min", 1)
|
||||
params.setValue("charge_max", 4)
|
||||
params.setValue("q_value", 0.01) # FDR threshold
|
||||
deconv.setParameters(params)
|
||||
|
||||
features_deconv = oms.FeatureMap()
|
||||
consensus_map = oms.ConsensusMap()
|
||||
deconv.compute(features, features_deconv, consensus_map)
|
||||
```
|
||||
|
||||
## Map Alignment
|
||||
|
||||
### MapAlignmentAlgorithm
|
||||
|
||||
Aligns retention times across multiple LC-MS runs.
|
||||
|
||||
#### MapAlignmentAlgorithmPoseClustering
|
||||
|
||||
Pose clustering-based RT alignment.
|
||||
|
||||
```python
|
||||
aligner = oms.MapAlignmentAlgorithmPoseClustering()
|
||||
|
||||
params = aligner.getParameters()
|
||||
params.setValue("max_num_peaks_considered", 1000)
|
||||
params.setValue("pairfinder:distance_MZ:max_difference", 0.3) # Da
|
||||
params.setValue("pairfinder:distance_RT:max_difference", 60.0) # seconds
|
||||
aligner.setParameters(params)
|
||||
|
||||
# Align multiple feature maps
|
||||
feature_maps = [features1, features2, features3]
|
||||
transformations = []
|
||||
|
||||
# Create reference (e.g., use first map)
|
||||
reference = oms.FeatureMap(feature_maps[0])
|
||||
|
||||
# Align others to reference
|
||||
for fm in feature_maps[1:]:
|
||||
transformation = oms.TransformationDescription()
|
||||
aligner.align(fm, reference, transformation)
|
||||
transformations.append(transformation)
|
||||
|
||||
# Apply transformation
|
||||
transformer = oms.MapAlignmentTransformer()
|
||||
transformer.transformRetentionTimes(fm, transformation)
|
||||
```
|
||||
|
||||
## Feature Linking
|
||||
|
||||
### FeatureGroupingAlgorithm
|
||||
|
||||
Links features across samples to create consensus features.
|
||||
|
||||
#### FeatureGroupingAlgorithmQT
|
||||
|
||||
Quality threshold-based feature linking.
|
||||
|
||||
```python
|
||||
grouper = oms.FeatureGroupingAlgorithmQT()
|
||||
|
||||
params = grouper.getParameters()
|
||||
params.setValue("distance_RT:max_difference", 60.0) # seconds
|
||||
params.setValue("distance_MZ:max_difference", 10.0) # ppm
|
||||
params.setValue("distance_MZ:unit", "ppm")
|
||||
grouper.setParameters(params)
|
||||
|
||||
# Create consensus map
|
||||
consensus_map = oms.ConsensusMap()
|
||||
|
||||
# Group features from multiple samples
|
||||
feature_maps = [features1, features2, features3]
|
||||
grouper.group(feature_maps, consensus_map)
|
||||
|
||||
print(f"Created {consensus_map.size()} consensus features")
|
||||
```
|
||||
|
||||
#### FeatureGroupingAlgorithmKD
|
||||
|
||||
KD-tree based linking (faster for large datasets).
|
||||
|
||||
```python
|
||||
grouper_kd = oms.FeatureGroupingAlgorithmKD()
|
||||
|
||||
params = grouper_kd.getParameters()
|
||||
params.setValue("mz_unit", "ppm")
|
||||
params.setValue("mz_tolerance", 10.0)
|
||||
params.setValue("rt_tolerance", 30.0)
|
||||
grouper_kd.setParameters(params)
|
||||
|
||||
consensus_map = oms.ConsensusMap()
|
||||
grouper_kd.group(feature_maps, consensus_map)
|
||||
```
|
||||
|
||||
## Chromatographic Analysis
|
||||
|
||||
### ElutionPeakDetection
|
||||
|
||||
Detects elution peaks in chromatograms.
|
||||
|
||||
```python
|
||||
epd = oms.ElutionPeakDetection()
|
||||
|
||||
params = epd.getParameters()
|
||||
params.setValue("chrom_peak_snr", 3.0) # Signal-to-noise threshold
|
||||
params.setValue("chrom_fwhm", 5.0) # Expected FWHM (seconds)
|
||||
epd.setParameters(params)
|
||||
|
||||
# Apply to chromatograms
|
||||
for chrom in exp.getChromatograms():
|
||||
peaks = epd.detectPeaks(chrom)
|
||||
```
|
||||
|
||||
### MRMFeatureFinderScoring
|
||||
|
||||
Scoring and peak picking for targeted (MRM/SRM) experiments.
|
||||
|
||||
```python
|
||||
mrm_finder = oms.MRMFeatureFinderScoring()
|
||||
|
||||
params = mrm_finder.getParameters()
|
||||
params.setValue("TransitionGroupPicker:min_peak_width", 2.0)
|
||||
params.setValue("TransitionGroupPicker:recalculate_peaks", "true")
|
||||
params.setValue("TransitionGroupPicker:PeakPickerMRM:signal_to_noise", 1.0)
|
||||
mrm_finder.setParameters(params)
|
||||
|
||||
# Requires chromatograms
|
||||
features = oms.FeatureMap()
|
||||
mrm_finder.pickExperiment(chrom_exp, features, targets, transformation, swath_maps)
|
||||
```
|
||||
|
||||
## Quantification
|
||||
|
||||
### ProteinInference
|
||||
|
||||
Infers proteins from peptide identifications.
|
||||
|
||||
```python
|
||||
protein_inference = oms.BasicProteinInferenceAlgorithm()
|
||||
|
||||
# Apply to identification results
|
||||
protein_inference.run(peptide_ids, protein_ids)
|
||||
```
|
||||
|
||||
### IsobaricQuantification
|
||||
|
||||
Quantification for isobaric labeling (TMT, iTRAQ).
|
||||
|
||||
```python
|
||||
# For TMT/iTRAQ quantification
|
||||
iso_quant = oms.IsobaricQuantification()
|
||||
|
||||
params = iso_quant.getParameters()
|
||||
params.setValue("channel_116_description", "Sample1")
|
||||
params.setValue("channel_117_description", "Sample2")
|
||||
# ... configure all channels
|
||||
iso_quant.setParameters(params)
|
||||
|
||||
# Run quantification
|
||||
quant_method = oms.IsobaricQuantitationMethod.TMT_10PLEX
|
||||
quant_info = oms.IsobaricQuantifierStatistics()
|
||||
iso_quant.quantify(exp, quant_info)
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### BaselineFiltering
|
||||
|
||||
Removes baseline from spectra.
|
||||
|
||||
```python
|
||||
baseline = oms.TopHatFilter()
|
||||
|
||||
params = baseline.getParameters()
|
||||
params.setValue("struc_elem_length", 3.0) # Structuring element size
|
||||
params.setValue("struc_elem_unit", "Thomson")
|
||||
baseline.setParameters(params)
|
||||
|
||||
baseline.filterExperiment(exp)
|
||||
```
|
||||
|
||||
### SpectraMerger
|
||||
|
||||
Merges consecutive similar spectra.
|
||||
|
||||
```python
|
||||
merger = oms.SpectraMerger()
|
||||
|
||||
params = merger.getParameters()
|
||||
params.setValue("mz_binning_width", 0.05) # Binning width (Da)
|
||||
params.setValue("sort_blocks", "RT_ascending")
|
||||
merger.setParameters(params)
|
||||
|
||||
merger.mergeSpectra(exp)
|
||||
```
|
||||
|
||||
## Quality Control
|
||||
|
||||
### MzMLFileQuality
|
||||
|
||||
Analyzes mzML file quality.
|
||||
|
||||
```python
|
||||
# Calculate basic QC metrics
|
||||
def calculate_qc_metrics(exp):
|
||||
metrics = {
|
||||
'n_spectra': exp.getNrSpectra(),
|
||||
'n_ms1': sum(1 for s in exp if s.getMSLevel() == 1),
|
||||
'n_ms2': sum(1 for s in exp if s.getMSLevel() == 2),
|
||||
'rt_range': (exp.getMinRT(), exp.getMaxRT()),
|
||||
'mz_range': (exp.getMinMZ(), exp.getMaxMZ()),
|
||||
}
|
||||
|
||||
# Calculate TIC
|
||||
tics = []
|
||||
for spectrum in exp:
|
||||
if spectrum.getMSLevel() == 1:
|
||||
mz, intensity = spectrum.get_peaks()
|
||||
tics.append(sum(intensity))
|
||||
|
||||
metrics['median_tic'] = np.median(tics)
|
||||
metrics['mean_tic'] = np.mean(tics)
|
||||
|
||||
return metrics
|
||||
```
|
||||
|
||||
## FDR Control
|
||||
|
||||
### FalseDiscoveryRate
|
||||
|
||||
Estimates and controls false discovery rate.
|
||||
|
||||
```python
|
||||
fdr = oms.FalseDiscoveryRate()
|
||||
|
||||
params = fdr.getParameters()
|
||||
params.setValue("add_decoy_peptides", "false")
|
||||
params.setValue("add_decoy_proteins", "false")
|
||||
fdr.setParameters(params)
|
||||
|
||||
# Apply to identifications
|
||||
fdr.apply(protein_ids, peptide_ids)
|
||||
|
||||
# Filter by FDR threshold
|
||||
fdr_threshold = 0.01
|
||||
filtered_peptides = [p for p in peptide_ids if p.getMetaValue("q-value") <= fdr_threshold]
|
||||
```
|
||||
|
||||
## Algorithm Selection Guide
|
||||
|
||||
### When to Use Which Algorithm
|
||||
|
||||
**For Smoothing:**
|
||||
- Use `GaussFilter` for general-purpose smoothing
|
||||
- Use `SavitzkyGolayFilter` for preserving peak shapes
|
||||
|
||||
**For Peak Picking:**
|
||||
- Use `PeakPickerHiRes` for high-resolution Orbitrap/FT-ICR data
|
||||
- Use `PeakPickerWavelet` for lower-resolution TOF data
|
||||
|
||||
**For Feature Detection:**
|
||||
- Use `FeatureFinderCentroided` for label-free proteomics (DDA)
|
||||
- Use `FeatureFinderMultiplex` for SILAC/dimethyl labeling
|
||||
- Use `FeatureFinderIdentification` when you have ID information
|
||||
- Use `MRMFeatureFinderScoring` for targeted (MRM/SRM) experiments
|
||||
|
||||
**For Feature Linking:**
|
||||
- Use `FeatureGroupingAlgorithmQT` for small-medium datasets (<10 samples)
|
||||
- Use `FeatureGroupingAlgorithmKD` for large datasets (>10 samples)
|
||||
|
||||
## Parameter Tuning Tips
|
||||
|
||||
1. **S/N Thresholds**: Start with 1-3 for clean data, increase for noisy data
|
||||
2. **m/z Tolerance**: Use 5-10 ppm for high-resolution instruments, 0.5-1 Da for low-res
|
||||
3. **RT Tolerance**: Typically 30-60 seconds depending on chromatographic stability
|
||||
4. **Peak Width**: Measure from real data - varies by instrument and gradient length
|
||||
5. **Charge States**: Set based on expected analytes (1-2 for metabolites, 2-4 for peptides)
|
||||
|
||||
## Common Algorithm Workflows
|
||||
|
||||
### Complete Proteomics Workflow
|
||||
|
||||
```python
|
||||
# 1. Load data
|
||||
exp = oms.MSExperiment()
|
||||
oms.MzMLFile().load("raw.mzML", exp)
|
||||
|
||||
# 2. Smooth
|
||||
gauss = oms.GaussFilter()
|
||||
gauss.filterExperiment(exp)
|
||||
|
||||
# 3. Peak picking
|
||||
picker = oms.PeakPickerHiRes()
|
||||
exp_centroid = oms.MSExperiment()
|
||||
picker.pickExperiment(exp, exp_centroid)
|
||||
|
||||
# 4. Feature detection
|
||||
ff = oms.FeatureFinderCentroided()
|
||||
features = oms.FeatureMap()
|
||||
ff.run(exp_centroid, features, oms.Param(), oms.FeatureMap())
|
||||
|
||||
# 5. Save results
|
||||
oms.FeatureXMLFile().store("features.featureXML", features)
|
||||
```
|
||||
|
||||
### Multi-Sample Quantification
|
||||
|
||||
```python
|
||||
# Load multiple samples
|
||||
feature_maps = []
|
||||
for filename in ["sample1.mzML", "sample2.mzML", "sample3.mzML"]:
|
||||
exp = oms.MSExperiment()
|
||||
oms.MzMLFile().load(filename, exp)
|
||||
|
||||
# Process and detect features
|
||||
features = detect_features(exp) # Your processing function
|
||||
feature_maps.append(features)
|
||||
|
||||
# Align retention times
|
||||
align_feature_maps(feature_maps) # Implement alignment
|
||||
|
||||
# Link features
|
||||
grouper = oms.FeatureGroupingAlgorithmQT()
|
||||
consensus_map = oms.ConsensusMap()
|
||||
grouper.group(feature_maps, consensus_map)
|
||||
|
||||
# Export quantification matrix
|
||||
export_quant_matrix(consensus_map)
|
||||
```
|
||||
@@ -1,715 +0,0 @@
|
||||
# pyOpenMS Chemistry Reference
|
||||
|
||||
This document provides comprehensive coverage of chemistry-related functionality in pyOpenMS, including elements, isotopes, molecular formulas, amino acids, peptides, proteins, and modifications.
|
||||
|
||||
## Elements and Isotopes
|
||||
|
||||
### ElementDB - Element Database
|
||||
|
||||
Access atomic and isotopic data for all elements.
|
||||
|
||||
```python
|
||||
import pyopenms as oms
|
||||
|
||||
# Get element database instance
|
||||
element_db = oms.ElementDB()
|
||||
|
||||
# Get element by symbol
|
||||
carbon = element_db.getElement("C")
|
||||
nitrogen = element_db.getElement("N")
|
||||
oxygen = element_db.getElement("O")
|
||||
|
||||
# Element properties
|
||||
print(f"Carbon monoisotopic weight: {carbon.getMonoWeight()}")
|
||||
print(f"Carbon average weight: {carbon.getAverageWeight()}")
|
||||
print(f"Atomic number: {carbon.getAtomicNumber()}")
|
||||
print(f"Symbol: {carbon.getSymbol()}")
|
||||
print(f"Name: {carbon.getName()}")
|
||||
```
|
||||
|
||||
### Isotope Information
|
||||
|
||||
```python
|
||||
# Get isotope distribution for an element
|
||||
isotopes = carbon.getIsotopeDistribution()
|
||||
|
||||
# Access specific isotope
|
||||
c12 = element_db.getElement("C", 12) # Carbon-12
|
||||
c13 = element_db.getElement("C", 13) # Carbon-13
|
||||
|
||||
print(f"C-12 abundance: {isotopes.getContainer()[0].getIntensity()}")
|
||||
print(f"C-13 abundance: {isotopes.getContainer()[1].getIntensity()}")
|
||||
|
||||
# Isotope mass
|
||||
print(f"C-12 mass: {c12.getMonoWeight()}")
|
||||
print(f"C-13 mass: {c13.getMonoWeight()}")
|
||||
```
|
||||
|
||||
### Constants
|
||||
|
||||
```python
|
||||
# Physical constants
|
||||
avogadro = oms.Constants.AVOGADRO
|
||||
electron_mass = oms.Constants.ELECTRON_MASS_U
|
||||
proton_mass = oms.Constants.PROTON_MASS_U
|
||||
|
||||
print(f"Avogadro's number: {avogadro}")
|
||||
print(f"Electron mass: {electron_mass} u")
|
||||
print(f"Proton mass: {proton_mass} u")
|
||||
```
|
||||
|
||||
## Empirical Formulas
|
||||
|
||||
### EmpiricalFormula - Molecular Formulas
|
||||
|
||||
Represent and manipulate molecular formulas.
|
||||
|
||||
#### Creating Formulas
|
||||
|
||||
```python
|
||||
# From string
|
||||
glucose = oms.EmpiricalFormula("C6H12O6")
|
||||
water = oms.EmpiricalFormula("H2O")
|
||||
ammonia = oms.EmpiricalFormula("NH3")
|
||||
|
||||
# From element composition
|
||||
formula = oms.EmpiricalFormula()
|
||||
formula.setCharge(1) # Set charge state
|
||||
```
|
||||
|
||||
#### Formula Arithmetic
|
||||
|
||||
```python
|
||||
# Addition
|
||||
sucrose = oms.EmpiricalFormula("C12H22O11")
|
||||
hydrolyzed = sucrose + water # Hydrolysis adds water
|
||||
|
||||
# Subtraction
|
||||
dehydrated = glucose - water # Dehydration removes water
|
||||
|
||||
# Multiplication
|
||||
three_waters = water * 3 # 3 H2O = H6O3
|
||||
|
||||
# Division
|
||||
formula_half = sucrose / 2 # Half the formula
|
||||
```
|
||||
|
||||
#### Mass Calculations
|
||||
|
||||
```python
|
||||
# Monoisotopic mass
|
||||
mono_mass = glucose.getMonoWeight()
|
||||
print(f"Glucose monoisotopic mass: {mono_mass:.6f} Da")
|
||||
|
||||
# Average mass
|
||||
avg_mass = glucose.getAverageWeight()
|
||||
print(f"Glucose average mass: {avg_mass:.6f} Da")
|
||||
|
||||
# Mass difference
|
||||
mass_diff = (glucose - water).getMonoWeight()
|
||||
```
|
||||
|
||||
#### Elemental Composition
|
||||
|
||||
```python
|
||||
# Get element counts
|
||||
formula = oms.EmpiricalFormula("C6H12O6")
|
||||
|
||||
# Access individual elements
|
||||
n_carbon = formula.getNumberOf(element_db.getElement("C"))
|
||||
n_hydrogen = formula.getNumberOf(element_db.getElement("H"))
|
||||
n_oxygen = formula.getNumberOf(element_db.getElement("O"))
|
||||
|
||||
print(f"C: {n_carbon}, H: {n_hydrogen}, O: {n_oxygen}")
|
||||
|
||||
# String representation
|
||||
print(f"Formula: {formula.toString()}")
|
||||
```
|
||||
|
||||
#### Isotope-Specific Formulas
|
||||
|
||||
```python
|
||||
# Specify specific isotopes using parentheses
|
||||
labeled_glucose = oms.EmpiricalFormula("(13)C6H12O6") # All carbons are C-13
|
||||
partially_labeled = oms.EmpiricalFormula("C5(13)CH12O6") # One C-13
|
||||
|
||||
# Deuterium labeling
|
||||
deuterated = oms.EmpiricalFormula("C6D12O6") # D2O instead of H2O
|
||||
```
|
||||
|
||||
#### Charge States
|
||||
|
||||
```python
|
||||
# Set charge
|
||||
formula = oms.EmpiricalFormula("C6H12O6")
|
||||
formula.setCharge(1) # Positive charge
|
||||
|
||||
# Get charge
|
||||
charge = formula.getCharge()
|
||||
|
||||
# Calculate m/z for charged molecule
|
||||
mz = formula.getMonoWeight() / abs(charge) if charge != 0 else formula.getMonoWeight()
|
||||
```
|
||||
|
||||
### Isotope Pattern Generation
|
||||
|
||||
Generate theoretical isotope patterns for formulas.
|
||||
|
||||
#### CoarseIsotopePatternGenerator
|
||||
|
||||
For unit mass resolution (low-resolution instruments).
|
||||
|
||||
```python
|
||||
# Create generator
|
||||
coarse_gen = oms.CoarseIsotopePatternGenerator()
|
||||
|
||||
# Generate pattern
|
||||
formula = oms.EmpiricalFormula("C6H12O6")
|
||||
pattern = coarse_gen.run(formula)
|
||||
|
||||
# Access isotope peaks
|
||||
iso_dist = pattern.getContainer()
|
||||
for peak in iso_dist:
|
||||
mass = peak.getMZ()
|
||||
abundance = peak.getIntensity()
|
||||
print(f"m/z: {mass:.4f}, Abundance: {abundance:.4f}")
|
||||
```
|
||||
|
||||
#### FineIsotopePatternGenerator
|
||||
|
||||
For high-resolution instruments (hyperfine structure).
|
||||
|
||||
```python
|
||||
# Create generator with resolution
|
||||
fine_gen = oms.FineIsotopePatternGenerator(0.01) # 0.01 Da resolution
|
||||
|
||||
# Generate fine pattern
|
||||
fine_pattern = fine_gen.run(formula)
|
||||
|
||||
# Access fine isotope structure
|
||||
for peak in fine_pattern.getContainer():
|
||||
print(f"m/z: {peak.getMZ():.6f}, Abundance: {peak.getIntensity():.6f}")
|
||||
```
|
||||
|
||||
#### Isotope Pattern Matching
|
||||
|
||||
```python
|
||||
# Compare experimental to theoretical
|
||||
def compare_isotope_patterns(experimental_mz, experimental_int, formula):
|
||||
# Generate theoretical
|
||||
coarse_gen = oms.CoarseIsotopePatternGenerator()
|
||||
theoretical = coarse_gen.run(formula)
|
||||
|
||||
# Extract theoretical peaks
|
||||
theo_peaks = theoretical.getContainer()
|
||||
theo_mz = [p.getMZ() for p in theo_peaks]
|
||||
theo_int = [p.getIntensity() for p in theo_peaks]
|
||||
|
||||
# Normalize both patterns
|
||||
exp_int_norm = [i / max(experimental_int) for i in experimental_int]
|
||||
theo_int_norm = [i / max(theo_int) for i in theo_int]
|
||||
|
||||
# Calculate similarity (e.g., cosine similarity)
|
||||
# ... implement similarity calculation
|
||||
return similarity_score
|
||||
```
|
||||
|
||||
## Amino Acids and Residues
|
||||
|
||||
### Residue - Amino Acid Representation
|
||||
|
||||
Access properties of amino acids.
|
||||
|
||||
```python
|
||||
# Get residue database
|
||||
res_db = oms.ResidueDB()
|
||||
|
||||
# Get specific residue
|
||||
leucine = res_db.getResidue("Leucine")
|
||||
# Or by one-letter code
|
||||
leu = res_db.getResidue("L")
|
||||
|
||||
# Residue properties
|
||||
print(f"Name: {leucine.getName()}")
|
||||
print(f"Three-letter code: {leucine.getThreeLetterCode()}")
|
||||
print(f"One-letter code: {leucine.getOneLetterCode()}")
|
||||
print(f"Monoisotopic mass: {leucine.getMonoWeight():.6f}")
|
||||
print(f"Average mass: {leucine.getAverageWeight():.6f}")
|
||||
|
||||
# Chemical formula
|
||||
formula = leucine.getFormula()
|
||||
print(f"Formula: {formula.toString()}")
|
||||
|
||||
# pKa values
|
||||
print(f"pKa (N-term): {leucine.getPka()}")
|
||||
print(f"pKa (C-term): {leucine.getPkb()}")
|
||||
print(f"pKa (side chain): {leucine.getPkc()}")
|
||||
|
||||
# Side chain basicity/acidity
|
||||
print(f"Basicity: {leucine.getBasicity()}")
|
||||
print(f"Hydrophobicity: {leucine.getHydrophobicity()}")
|
||||
```
|
||||
|
||||
### All Standard Amino Acids
|
||||
|
||||
```python
|
||||
# Iterate over all residues
|
||||
for residue_name in ["Alanine", "Cysteine", "Aspartic acid", "Glutamic acid",
|
||||
"Phenylalanine", "Glycine", "Histidine", "Isoleucine",
|
||||
"Lysine", "Leucine", "Methionine", "Asparagine",
|
||||
"Proline", "Glutamine", "Arginine", "Serine",
|
||||
"Threonine", "Valine", "Tryptophan", "Tyrosine"]:
|
||||
res = res_db.getResidue(residue_name)
|
||||
print(f"{res.getOneLetterCode()}: {res.getMonoWeight():.4f} Da")
|
||||
```
|
||||
|
||||
### Internal Residues vs. Termini
|
||||
|
||||
```python
|
||||
# Get internal residue mass (no terminal groups)
|
||||
internal_mass = leucine.getInternalToFull()
|
||||
|
||||
# Get residue with N-terminal modification
|
||||
n_terminal = res_db.getResidue("L[1]") # With NH2
|
||||
|
||||
# Get residue with C-terminal modification
|
||||
c_terminal = res_db.getResidue("L[2]") # With COOH
|
||||
```
|
||||
|
||||
## Peptide Sequences
|
||||
|
||||
### AASequence - Amino Acid Sequences
|
||||
|
||||
Represent and manipulate peptide sequences.
|
||||
|
||||
#### Creating Sequences
|
||||
|
||||
```python
|
||||
# From string
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
longer = oms.AASequence.fromString("MKTAYIAKQRQISFVK")
|
||||
|
||||
# Empty sequence
|
||||
empty_seq = oms.AASequence()
|
||||
```
|
||||
|
||||
#### Sequence Properties
|
||||
|
||||
```python
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
|
||||
# Length
|
||||
length = peptide.size()
|
||||
print(f"Length: {length} residues")
|
||||
|
||||
# Mass
|
||||
mono_mass = peptide.getMonoWeight()
|
||||
avg_mass = peptide.getAverageWeight()
|
||||
print(f"Monoisotopic mass: {mono_mass:.6f} Da")
|
||||
print(f"Average mass: {avg_mass:.6f} Da")
|
||||
|
||||
# Formula
|
||||
formula = peptide.getFormula()
|
||||
print(f"Formula: {formula.toString()}")
|
||||
|
||||
# String representation
|
||||
seq_str = peptide.toString()
|
||||
print(f"Sequence: {seq_str}")
|
||||
```
|
||||
|
||||
#### Accessing Individual Residues
|
||||
|
||||
```python
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
|
||||
# Access by index
|
||||
first_aa = peptide[0] # Returns Residue object
|
||||
print(f"First amino acid: {first_aa.getOneLetterCode()}")
|
||||
|
||||
# Iterate
|
||||
for i in range(peptide.size()):
|
||||
residue = peptide[i]
|
||||
print(f"Position {i}: {residue.getOneLetterCode()}")
|
||||
```
|
||||
|
||||
#### Modifications
|
||||
|
||||
Add post-translational modifications (PTMs) to sequences.
|
||||
|
||||
```python
|
||||
# Modifications in sequence string
|
||||
# Format: AA(ModificationName)
|
||||
oxidized_met = oms.AASequence.fromString("PEPTIDEM(Oxidation)")
|
||||
phospho = oms.AASequence.fromString("PEPTIDES(Phospho)T(Phospho)")
|
||||
|
||||
# Multiple modifications
|
||||
multi_mod = oms.AASequence.fromString("M(Oxidation)PEPTIDEK(Acetyl)")
|
||||
|
||||
# N-terminal modifications
|
||||
n_term_acetyl = oms.AASequence.fromString("(Acetyl)PEPTIDE")
|
||||
|
||||
# C-terminal modifications
|
||||
c_term_amide = oms.AASequence.fromString("PEPTIDE(Amidated)")
|
||||
|
||||
# Check mass change
|
||||
unmodified = oms.AASequence.fromString("PEPTIDE")
|
||||
modified = oms.AASequence.fromString("PEPTIDEM(Oxidation)")
|
||||
mass_diff = modified.getMonoWeight() - unmodified.getMonoWeight()
|
||||
print(f"Mass shift from oxidation: {mass_diff:.6f} Da")
|
||||
```
|
||||
|
||||
#### Sequence Manipulation
|
||||
|
||||
```python
|
||||
# Prefix (N-terminal fragment)
|
||||
prefix = peptide.getPrefix(3) # First 3 residues
|
||||
print(f"Prefix: {prefix.toString()}")
|
||||
|
||||
# Suffix (C-terminal fragment)
|
||||
suffix = peptide.getSuffix(3) # Last 3 residues
|
||||
print(f"Suffix: {suffix.toString()}")
|
||||
|
||||
# Subsequence
|
||||
subseq = peptide.getSubsequence(2, 4) # Residues 2-4
|
||||
print(f"Subsequence: {subseq.toString()}")
|
||||
```
|
||||
|
||||
#### Theoretical Fragmentation
|
||||
|
||||
Generate theoretical fragment ions for MS/MS.
|
||||
|
||||
```python
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
|
||||
# b-ions (N-terminal fragments)
|
||||
b_ions = []
|
||||
for i in range(1, peptide.size()):
|
||||
b_fragment = peptide.getPrefix(i)
|
||||
b_mass = b_fragment.getMonoWeight()
|
||||
b_ions.append(('b', i, b_mass))
|
||||
print(f"b{i}: {b_mass:.4f}")
|
||||
|
||||
# y-ions (C-terminal fragments)
|
||||
y_ions = []
|
||||
for i in range(1, peptide.size()):
|
||||
y_fragment = peptide.getSuffix(i)
|
||||
y_mass = y_fragment.getMonoWeight()
|
||||
y_ions.append(('y', i, y_mass))
|
||||
print(f"y{i}: {y_mass:.4f}")
|
||||
|
||||
# a-ions (b - CO)
|
||||
a_ions = []
|
||||
CO_mass = 27.994915 # CO loss
|
||||
for ion_type, position, mass in b_ions:
|
||||
a_mass = mass - CO_mass
|
||||
a_ions.append(('a', position, a_mass))
|
||||
|
||||
# c-ions (b + NH3)
|
||||
NH3_mass = 17.026549 # NH3 gain
|
||||
c_ions = []
|
||||
for ion_type, position, mass in b_ions:
|
||||
c_mass = mass + NH3_mass
|
||||
c_ions.append(('c', position, c_mass))
|
||||
|
||||
# z-ions (y - NH3)
|
||||
z_ions = []
|
||||
for ion_type, position, mass in y_ions:
|
||||
z_mass = mass - NH3_mass
|
||||
z_ions.append(('z', position, z_mass))
|
||||
```
|
||||
|
||||
#### Calculate m/z for Charge States
|
||||
|
||||
```python
|
||||
peptide = oms.AASequence.fromString("PEPTIDE")
|
||||
proton_mass = 1.007276
|
||||
|
||||
# [M+H]+
|
||||
mz_1 = peptide.getMonoWeight() + proton_mass
|
||||
print(f"[M+H]+: {mz_1:.4f}")
|
||||
|
||||
# [M+2H]2+
|
||||
mz_2 = (peptide.getMonoWeight() + 2 * proton_mass) / 2
|
||||
print(f"[M+2H]2+: {mz_2:.4f}")
|
||||
|
||||
# [M+3H]3+
|
||||
mz_3 = (peptide.getMonoWeight() + 3 * proton_mass) / 3
|
||||
print(f"[M+3H]3+: {mz_3:.4f}")
|
||||
|
||||
# General formula for any charge
|
||||
def calculate_mz(sequence, charge):
|
||||
proton_mass = 1.007276
|
||||
return (sequence.getMonoWeight() + charge * proton_mass) / charge
|
||||
|
||||
for z in range(1, 5):
|
||||
print(f"[M+{z}H]{z}+: {calculate_mz(peptide, z):.4f}")
|
||||
```
|
||||
|
||||
## Protein Digestion
|
||||
|
||||
### ProteaseDigestion - Enzymatic Cleavage
|
||||
|
||||
Simulate enzymatic protein digestion.
|
||||
|
||||
#### Basic Digestion
|
||||
|
||||
```python
|
||||
# Create digestion object
|
||||
dig = oms.ProteaseDigestion()
|
||||
|
||||
# Set enzyme
|
||||
dig.setEnzyme("Trypsin") # Cleaves after K, R
|
||||
|
||||
# Other common enzymes:
|
||||
# - "Trypsin" (K, R)
|
||||
# - "Lys-C" (K)
|
||||
# - "Arg-C" (R)
|
||||
# - "Asp-N" (D)
|
||||
# - "Glu-C" (E, D)
|
||||
# - "Chymotrypsin" (F, Y, W, L)
|
||||
|
||||
# Set missed cleavages
|
||||
dig.setMissedCleavages(0) # No missed cleavages
|
||||
dig.setMissedCleavages(2) # Allow up to 2 missed cleavages
|
||||
|
||||
# Perform digestion
|
||||
protein = oms.AASequence.fromString("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK")
|
||||
peptides = []
|
||||
dig.digest(protein, peptides)
|
||||
|
||||
# Print results
|
||||
for pep in peptides:
|
||||
print(f"{pep.toString()}: {pep.getMonoWeight():.2f} Da")
|
||||
```
|
||||
|
||||
#### Advanced Digestion Options
|
||||
|
||||
```python
|
||||
# Get enzyme specificity
|
||||
specificity = dig.getSpecificity()
|
||||
# oms.EnzymaticDigestion.SPEC_FULL (both termini)
|
||||
# oms.EnzymaticDigestion.SPEC_SEMI (one terminus)
|
||||
# oms.EnzymaticDigestion.SPEC_NONE (no specificity)
|
||||
|
||||
# Set specificity for semi-tryptic search
|
||||
dig.setSpecificity(oms.EnzymaticDigestion.SPEC_SEMI)
|
||||
|
||||
# Get cleavage sites
|
||||
cleavage_residues = dig.getEnzyme().getCutAfterResidues()
|
||||
restriction_residues = dig.getEnzyme().getRestriction()
|
||||
```
|
||||
|
||||
#### Filter Peptides by Properties
|
||||
|
||||
```python
|
||||
# Filter by mass range
|
||||
min_mass = 600.0
|
||||
max_mass = 4000.0
|
||||
filtered = [p for p in peptides if min_mass <= p.getMonoWeight() <= max_mass]
|
||||
|
||||
# Filter by length
|
||||
min_length = 6
|
||||
max_length = 30
|
||||
length_filtered = [p for p in peptides if min_length <= p.size() <= max_length]
|
||||
|
||||
# Combine filters
|
||||
valid_peptides = [p for p in peptides
|
||||
if min_mass <= p.getMonoWeight() <= max_mass
|
||||
and min_length <= p.size() <= max_length]
|
||||
```
|
||||
|
||||
## Modifications
|
||||
|
||||
### ModificationsDB - Modification Database
|
||||
|
||||
Access and apply post-translational modifications.
|
||||
|
||||
#### Accessing Modifications
|
||||
|
||||
```python
|
||||
# Get modifications database
|
||||
mod_db = oms.ModificationsDB()
|
||||
|
||||
# Get specific modification
|
||||
oxidation = mod_db.getModification("Oxidation")
|
||||
phospho = mod_db.getModification("Phospho")
|
||||
acetyl = mod_db.getModification("Acetyl")
|
||||
|
||||
# Modification properties
|
||||
print(f"Name: {oxidation.getFullName()}")
|
||||
print(f"Mass difference: {oxidation.getDiffMonoMass():.6f} Da")
|
||||
print(f"Formula: {oxidation.getDiffFormula().toString()}")
|
||||
|
||||
# Affected residues
|
||||
print(f"Residues: {oxidation.getResidues()}") # e.g., ['M']
|
||||
|
||||
# Specificity (N-term, C-term, anywhere)
|
||||
print(f"Term specificity: {oxidation.getTermSpecificity()}")
|
||||
```
|
||||
|
||||
#### Common Modifications
|
||||
|
||||
```python
|
||||
# Oxidation (M)
|
||||
oxidation = mod_db.getModification("Oxidation")
|
||||
print(f"Oxidation: +{oxidation.getDiffMonoMass():.4f} Da")
|
||||
|
||||
# Phosphorylation (S, T, Y)
|
||||
phospho = mod_db.getModification("Phospho")
|
||||
print(f"Phospho: +{phospho.getDiffMonoMass():.4f} Da")
|
||||
|
||||
# Carbamidomethylation (C) - common alkylation
|
||||
carbamido = mod_db.getModification("Carbamidomethyl")
|
||||
print(f"Carbamidomethyl: +{carbamido.getDiffMonoMass():.4f} Da")
|
||||
|
||||
# Acetylation (K, N-term)
|
||||
acetyl = mod_db.getModification("Acetyl")
|
||||
print(f"Acetyl: +{acetyl.getDiffMonoMass():.4f} Da")
|
||||
|
||||
# Deamidation (N, Q)
|
||||
deamid = mod_db.getModification("Deamidated")
|
||||
print(f"Deamidation: +{deamid.getDiffMonoMass():.4f} Da")
|
||||
```
|
||||
|
||||
#### Searching Modifications
|
||||
|
||||
```python
|
||||
# Search modifications by mass
|
||||
mass_tolerance = 0.01 # Da
|
||||
target_mass = 15.9949 # Oxidation
|
||||
|
||||
# Get all modifications
|
||||
all_mods = []
|
||||
mod_db.getAllSearchModifications(all_mods)
|
||||
|
||||
# Find matching modifications
|
||||
matching = []
|
||||
for mod_name in all_mods:
|
||||
mod = mod_db.getModification(mod_name)
|
||||
if abs(mod.getDiffMonoMass() - target_mass) < mass_tolerance:
|
||||
matching.append(mod)
|
||||
print(f"Match: {mod.getFullName()} ({mod.getDiffMonoMass():.4f} Da)")
|
||||
```
|
||||
|
||||
#### Variable vs. Fixed Modifications
|
||||
|
||||
```python
|
||||
# In search engines, specify:
|
||||
# Fixed modifications: applied to all occurrences
|
||||
fixed_mods = ["Carbamidomethyl (C)"]
|
||||
|
||||
# Variable modifications: optionally present
|
||||
variable_mods = ["Oxidation (M)", "Phospho (S)", "Phospho (T)", "Phospho (Y)"]
|
||||
```
|
||||
|
||||
## Ribonucleotides (RNA)
|
||||
|
||||
### Ribonucleotide - RNA Building Blocks
|
||||
|
||||
```python
|
||||
# Get ribonucleotide database
|
||||
ribo_db = oms.RibonucleotideDB()
|
||||
|
||||
# Get specific ribonucleotide
|
||||
adenine = ribo_db.getRibonucleotide("A")
|
||||
uracil = ribo_db.getRibonucleotide("U")
|
||||
guanine = ribo_db.getRibonucleotide("G")
|
||||
cytosine = ribo_db.getRibonucleotide("C")
|
||||
|
||||
# Properties
|
||||
print(f"Adenine mono mass: {adenine.getMonoWeight()}")
|
||||
print(f"Formula: {adenine.getFormula().toString()}")
|
||||
|
||||
# Modified ribonucleotides
|
||||
modified_ribo = ribo_db.getRibonucleotide("m6A") # N6-methyladenosine
|
||||
```
|
||||
|
||||
## Practical Examples
|
||||
|
||||
### Calculate Peptide Mass with Modifications
|
||||
|
||||
```python
|
||||
def calculate_peptide_mz(sequence_str, charge):
|
||||
"""Calculate m/z for a peptide sequence string with modifications."""
|
||||
peptide = oms.AASequence.fromString(sequence_str)
|
||||
proton_mass = 1.007276
|
||||
mz = (peptide.getMonoWeight() + charge * proton_mass) / charge
|
||||
return mz
|
||||
|
||||
# Examples
|
||||
print(calculate_peptide_mz("PEPTIDE", 2)) # Unmodified [M+2H]2+
|
||||
print(calculate_peptide_mz("PEPTIDEM(Oxidation)", 2)) # With oxidation
|
||||
print(calculate_peptide_mz("(Acetyl)PEPTIDEK(Acetyl)", 2)) # Acetylated
|
||||
```
|
||||
|
||||
### Generate Complete Fragment Ion Series
|
||||
|
||||
```python
|
||||
def generate_fragment_ions(sequence_str, charge_states=[1, 2]):
|
||||
"""Generate comprehensive fragment ion list."""
|
||||
peptide = oms.AASequence.fromString(sequence_str)
|
||||
proton_mass = 1.007276
|
||||
fragments = []
|
||||
|
||||
for i in range(1, peptide.size()):
|
||||
# b and y ions
|
||||
b_frag = peptide.getPrefix(i)
|
||||
y_frag = peptide.getSuffix(i)
|
||||
|
||||
for z in charge_states:
|
||||
b_mz = (b_frag.getMonoWeight() + z * proton_mass) / z
|
||||
y_mz = (y_frag.getMonoWeight() + z * proton_mass) / z
|
||||
|
||||
fragments.append({
|
||||
'type': 'b',
|
||||
'position': i,
|
||||
'charge': z,
|
||||
'mz': b_mz
|
||||
})
|
||||
fragments.append({
|
||||
'type': 'y',
|
||||
'position': i,
|
||||
'charge': z,
|
||||
'mz': y_mz
|
||||
})
|
||||
|
||||
return fragments
|
||||
|
||||
# Usage
|
||||
ions = generate_fragment_ions("PEPTIDE", charge_states=[1, 2])
|
||||
for ion in ions:
|
||||
print(f"{ion['type']}{ion['position']}^{ion['charge']}+: {ion['mz']:.4f}")
|
||||
```
|
||||
|
||||
### Digest Protein and Calculate Peptide Masses
|
||||
|
||||
```python
|
||||
def digest_and_calculate(protein_seq_str, enzyme="Trypsin", missed_cleavages=2,
|
||||
min_mass=600, max_mass=4000):
|
||||
"""Digest protein and return valid peptides with masses."""
|
||||
dig = oms.ProteaseDigestion()
|
||||
dig.setEnzyme(enzyme)
|
||||
dig.setMissedCleavages(missed_cleavages)
|
||||
|
||||
protein = oms.AASequence.fromString(protein_seq_str)
|
||||
peptides = []
|
||||
dig.digest(protein, peptides)
|
||||
|
||||
results = []
|
||||
for pep in peptides:
|
||||
mass = pep.getMonoWeight()
|
||||
if min_mass <= mass <= max_mass:
|
||||
results.append({
|
||||
'sequence': pep.toString(),
|
||||
'mass': mass,
|
||||
'length': pep.size()
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK"
|
||||
peptides = digest_and_calculate(protein)
|
||||
for pep in peptides:
|
||||
print(f"{pep['sequence']}: {pep['mass']:.2f} Da ({pep['length']} aa)")
|
||||
```
|
||||
@@ -1,560 +0,0 @@
|
||||
# pyOpenMS Data Structures Reference
|
||||
|
||||
This document provides comprehensive coverage of core data structures in pyOpenMS for representing mass spectrometry data.
|
||||
|
||||
## Core Hierarchy
|
||||
|
||||
```
|
||||
MSExperiment # Top-level: Complete LC-MS/MS run
|
||||
├── MSSpectrum[] # Collection of mass spectra
|
||||
│ ├── Peak1D[] # Individual m/z, intensity pairs
|
||||
│ └── SpectrumSettings # Metadata (RT, MS level, precursor)
|
||||
└── MSChromatogram[] # Collection of chromatograms
|
||||
├── ChromatogramPeak[] # RT, intensity pairs
|
||||
└── ChromatogramSettings # Metadata
|
||||
```
|
||||
|
||||
## MSSpectrum
|
||||
|
||||
Represents a single mass spectrum (1-dimensional peak data).
|
||||
|
||||
### Creation and Basic Properties
|
||||
|
||||
```python
|
||||
import pyopenms as oms
|
||||
|
||||
# Create empty spectrum
|
||||
spectrum = oms.MSSpectrum()
|
||||
|
||||
# Set metadata
|
||||
spectrum.setRT(123.45) # Retention time in seconds
|
||||
spectrum.setMSLevel(1) # MS level (1 for MS1, 2 for MS2, etc.)
|
||||
spectrum.setNativeID("scan=1234") # Native ID from file
|
||||
|
||||
# Additional metadata
|
||||
spectrum.setDriftTime(15.2) # Ion mobility drift time
|
||||
spectrum.setName("MyScan") # Optional name
|
||||
```
|
||||
|
||||
### Peak Data Management
|
||||
|
||||
**Setting Peaks (Method 1 - Lists):**
|
||||
```python
|
||||
mz_values = [100.5, 200.3, 300.7, 400.2, 500.1]
|
||||
intensity_values = [1000, 5000, 3000, 2000, 1500]
|
||||
|
||||
spectrum.set_peaks((mz_values, intensity_values))
|
||||
```
|
||||
|
||||
**Setting Peaks (Method 2 - NumPy arrays):**
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
mz_array = np.array([100.5, 200.3, 300.7, 400.2, 500.1])
|
||||
intensity_array = np.array([1000, 5000, 3000, 2000, 1500])
|
||||
|
||||
spectrum.set_peaks((mz_array, intensity_array))
|
||||
```
|
||||
|
||||
**Retrieving Peaks:**
|
||||
```python
|
||||
# Get as numpy arrays (efficient)
|
||||
mz_array, intensity_array = spectrum.get_peaks()
|
||||
|
||||
# Check number of peaks
|
||||
n_peaks = spectrum.size()
|
||||
|
||||
# Get individual peak (slower)
|
||||
for i in range(spectrum.size()):
|
||||
peak = spectrum[i]
|
||||
mz = peak.getMZ()
|
||||
intensity = peak.getIntensity()
|
||||
```
|
||||
|
||||
### Precursor Information (for MS2/MSn spectra)
|
||||
|
||||
```python
|
||||
# Create precursor
|
||||
precursor = oms.Precursor()
|
||||
precursor.setMZ(456.789) # Precursor m/z
|
||||
precursor.setCharge(2) # Precursor charge
|
||||
precursor.setIntensity(50000) # Precursor intensity
|
||||
precursor.setIsolationWindowLowerOffset(1.5) # Lower isolation window
|
||||
precursor.setIsolationWindowUpperOffset(1.5) # Upper isolation window
|
||||
|
||||
# Set activation method
|
||||
activation = oms.Activation()
|
||||
activation.setActivationEnergy(35.0) # Collision energy
|
||||
activation.setMethod(oms.Activation.ActivationMethod.CID)
|
||||
precursor.setActivation(activation)
|
||||
|
||||
# Assign to spectrum
|
||||
spectrum.setPrecursors([precursor])
|
||||
|
||||
# Retrieve precursor information
|
||||
precursors = spectrum.getPrecursors()
|
||||
if len(precursors) > 0:
|
||||
prec = precursors[0]
|
||||
print(f"Precursor m/z: {prec.getMZ()}")
|
||||
print(f"Precursor charge: {prec.getCharge()}")
|
||||
```
|
||||
|
||||
### Spectrum Metadata Access
|
||||
|
||||
```python
|
||||
# Check if spectrum is sorted by m/z
|
||||
is_sorted = spectrum.isSorted()
|
||||
|
||||
# Sort spectrum by m/z
|
||||
spectrum.sortByPosition()
|
||||
|
||||
# Sort by intensity
|
||||
spectrum.sortByIntensity()
|
||||
|
||||
# Clear all peaks
|
||||
spectrum.clear(False) # False = keep metadata, True = clear everything
|
||||
|
||||
# Get retention time
|
||||
rt = spectrum.getRT()
|
||||
|
||||
# Get MS level
|
||||
ms_level = spectrum.getMSLevel()
|
||||
```
|
||||
|
||||
### Spectrum Types and Modes
|
||||
|
||||
```python
|
||||
# Set spectrum type
|
||||
spectrum.setType(oms.SpectrumSettings.SpectrumType.CENTROID) # or PROFILE
|
||||
|
||||
# Get spectrum type
|
||||
spec_type = spectrum.getType()
|
||||
if spec_type == oms.SpectrumSettings.SpectrumType.CENTROID:
|
||||
print("Centroid spectrum")
|
||||
elif spec_type == oms.SpectrumSettings.SpectrumType.PROFILE:
|
||||
print("Profile spectrum")
|
||||
```
|
||||
|
||||
### Data Processing Annotations
|
||||
|
||||
```python
|
||||
# Add processing information
|
||||
processing = oms.DataProcessing()
|
||||
processing.setMetaValue("smoothing", "gaussian")
|
||||
spectrum.setDataProcessing([processing])
|
||||
```
|
||||
|
||||
## MSExperiment
|
||||
|
||||
Represents a complete LC-MS/MS experiment containing multiple spectra and chromatograms.
|
||||
|
||||
### Creation and Population
|
||||
|
||||
```python
|
||||
# Create empty experiment
|
||||
exp = oms.MSExperiment()
|
||||
|
||||
# Add spectra
|
||||
spectrum1 = oms.MSSpectrum()
|
||||
spectrum1.setRT(100.0)
|
||||
spectrum1.set_peaks(([100, 200], [1000, 2000]))
|
||||
|
||||
spectrum2 = oms.MSSpectrum()
|
||||
spectrum2.setRT(200.0)
|
||||
spectrum2.set_peaks(([100, 200], [1500, 2500]))
|
||||
|
||||
exp.addSpectrum(spectrum1)
|
||||
exp.addSpectrum(spectrum2)
|
||||
|
||||
# Add chromatograms
|
||||
chrom = oms.MSChromatogram()
|
||||
chrom.set_peaks(([10.5, 11.0, 11.5], [1000, 5000, 3000]))
|
||||
exp.addChromatogram(chrom)
|
||||
```
|
||||
|
||||
### Accessing Spectra and Chromatograms
|
||||
|
||||
```python
|
||||
# Get number of spectra and chromatograms
|
||||
n_spectra = exp.getNrSpectra()
|
||||
n_chroms = exp.getNrChromatograms()
|
||||
|
||||
# Access by index
|
||||
first_spectrum = exp.getSpectrum(0)
|
||||
last_spectrum = exp.getSpectrum(exp.getNrSpectra() - 1)
|
||||
|
||||
# Iterate over all spectra
|
||||
for spectrum in exp:
|
||||
rt = spectrum.getRT()
|
||||
ms_level = spectrum.getMSLevel()
|
||||
n_peaks = spectrum.size()
|
||||
print(f"RT: {rt:.2f}s, MS{ms_level}, Peaks: {n_peaks}")
|
||||
|
||||
# Get all spectra as list
|
||||
spectra = exp.getSpectra()
|
||||
|
||||
# Access chromatograms
|
||||
chrom = exp.getChromatogram(0)
|
||||
```
|
||||
|
||||
### Filtering Operations
|
||||
|
||||
```python
|
||||
# Filter by MS level
|
||||
exp.filterMSLevel(1) # Keep only MS1 spectra
|
||||
exp.filterMSLevel(2) # Keep only MS2 spectra
|
||||
|
||||
# Filter by retention time range
|
||||
exp.filterRT(100.0, 500.0) # Keep RT between 100-500 seconds
|
||||
|
||||
# Filter by m/z range (all spectra)
|
||||
exp.filterMZ(300.0, 1500.0) # Keep m/z between 300-1500
|
||||
|
||||
# Filter by scan number
|
||||
exp.filterScanNumber(100, 200) # Keep scans 100-200
|
||||
```
|
||||
|
||||
### Metadata and Properties
|
||||
|
||||
```python
|
||||
# Set experiment metadata
|
||||
exp.setMetaValue("operator", "John Doe")
|
||||
exp.setMetaValue("instrument", "Q Exactive HF")
|
||||
|
||||
# Get metadata
|
||||
operator = exp.getMetaValue("operator")
|
||||
|
||||
# Get RT range
|
||||
rt_range = exp.getMinRT(), exp.getMaxRT()
|
||||
|
||||
# Get m/z range
|
||||
mz_range = exp.getMinMZ(), exp.getMaxMZ()
|
||||
|
||||
# Clear all data
|
||||
exp.clear(False) # False = keep metadata
|
||||
```
|
||||
|
||||
### Sorting and Organization
|
||||
|
||||
```python
|
||||
# Sort spectra by retention time
|
||||
exp.sortSpectra()
|
||||
|
||||
# Update ranges (call after modifications)
|
||||
exp.updateRanges()
|
||||
|
||||
# Check if experiment is empty
|
||||
is_empty = exp.empty()
|
||||
|
||||
# Reset (clear everything)
|
||||
exp.reset()
|
||||
```
|
||||
|
||||
## MSChromatogram
|
||||
|
||||
Represents an extracted or reconstructed chromatogram (retention time vs. intensity).
|
||||
|
||||
### Creation and Basic Usage
|
||||
|
||||
```python
|
||||
# Create chromatogram
|
||||
chrom = oms.MSChromatogram()
|
||||
|
||||
# Set peaks (RT, intensity pairs)
|
||||
rt_values = [10.0, 10.5, 11.0, 11.5, 12.0]
|
||||
intensity_values = [1000, 5000, 8000, 6000, 2000]
|
||||
chrom.set_peaks((rt_values, intensity_values))
|
||||
|
||||
# Get peaks
|
||||
rt_array, int_array = chrom.get_peaks()
|
||||
|
||||
# Get size
|
||||
n_points = chrom.size()
|
||||
```
|
||||
|
||||
### Chromatogram Types
|
||||
|
||||
```python
|
||||
# Set chromatogram type
|
||||
chrom.setChromatogramType(oms.ChromatogramSettings.ChromatogramType.SELECTED_ION_CURRENT_CHROMATOGRAM)
|
||||
|
||||
# Other types:
|
||||
# - TOTAL_ION_CURRENT_CHROMATOGRAM
|
||||
# - BASEPEAK_CHROMATOGRAM
|
||||
# - SELECTED_ION_CURRENT_CHROMATOGRAM
|
||||
# - SELECTED_REACTION_MONITORING_CHROMATOGRAM
|
||||
```
|
||||
|
||||
### Metadata
|
||||
|
||||
```python
|
||||
# Set native ID
|
||||
chrom.setNativeID("TIC")
|
||||
|
||||
# Set name
|
||||
chrom.setName("Total Ion Current")
|
||||
|
||||
# Access
|
||||
native_id = chrom.getNativeID()
|
||||
name = chrom.getName()
|
||||
```
|
||||
|
||||
### Precursor and Product Information (for SRM/MRM)
|
||||
|
||||
```python
|
||||
# For targeted experiments
|
||||
precursor = oms.Precursor()
|
||||
precursor.setMZ(456.7)
|
||||
chrom.setPrecursor(precursor)
|
||||
|
||||
product = oms.Product()
|
||||
product.setMZ(789.4)
|
||||
chrom.setProduct(product)
|
||||
```
|
||||
|
||||
## Peak1D and ChromatogramPeak
|
||||
|
||||
Individual peak data points.
|
||||
|
||||
### Peak1D (for mass spectra)
|
||||
|
||||
```python
|
||||
# Create individual peak
|
||||
peak = oms.Peak1D()
|
||||
peak.setMZ(456.789)
|
||||
peak.setIntensity(10000)
|
||||
|
||||
# Access
|
||||
mz = peak.getMZ()
|
||||
intensity = peak.getIntensity()
|
||||
|
||||
# Set position and intensity
|
||||
peak.setPosition([456.789])
|
||||
peak.setIntensity(10000)
|
||||
```
|
||||
|
||||
### ChromatogramPeak (for chromatograms)
|
||||
|
||||
```python
|
||||
# Create chromatogram peak
|
||||
chrom_peak = oms.ChromatogramPeak()
|
||||
chrom_peak.setRT(125.5)
|
||||
chrom_peak.setIntensity(5000)
|
||||
|
||||
# Access
|
||||
rt = chrom_peak.getRT()
|
||||
intensity = chrom_peak.getIntensity()
|
||||
```
|
||||
|
||||
## FeatureMap and Feature
|
||||
|
||||
For quantification results.
|
||||
|
||||
### Feature
|
||||
|
||||
Represents a detected LC-MS feature (peptide or metabolite signal).
|
||||
|
||||
```python
|
||||
# Create feature
|
||||
feature = oms.Feature()
|
||||
|
||||
# Set properties
|
||||
feature.setMZ(456.789)
|
||||
feature.setRT(123.45)
|
||||
feature.setIntensity(1000000)
|
||||
feature.setCharge(2)
|
||||
feature.setWidth(15.0) # RT width in seconds
|
||||
|
||||
# Set quality score
|
||||
feature.setOverallQuality(0.95)
|
||||
|
||||
# Access
|
||||
mz = feature.getMZ()
|
||||
rt = feature.getRT()
|
||||
intensity = feature.getIntensity()
|
||||
charge = feature.getCharge()
|
||||
```
|
||||
|
||||
### FeatureMap
|
||||
|
||||
Collection of features.
|
||||
|
||||
```python
|
||||
# Create feature map
|
||||
feature_map = oms.FeatureMap()
|
||||
|
||||
# Add features
|
||||
feature1 = oms.Feature()
|
||||
feature1.setMZ(456.789)
|
||||
feature1.setRT(123.45)
|
||||
feature1.setIntensity(1000000)
|
||||
|
||||
feature_map.push_back(feature1)
|
||||
|
||||
# Get size
|
||||
n_features = feature_map.size()
|
||||
|
||||
# Iterate
|
||||
for feature in feature_map:
|
||||
print(f"m/z: {feature.getMZ():.4f}, RT: {feature.getRT():.2f}")
|
||||
|
||||
# Access by index
|
||||
first_feature = feature_map[0]
|
||||
|
||||
# Clear
|
||||
feature_map.clear()
|
||||
```
|
||||
|
||||
## PeptideIdentification and ProteinIdentification
|
||||
|
||||
For identification results.
|
||||
|
||||
### PeptideIdentification
|
||||
|
||||
```python
|
||||
# Create peptide identification
|
||||
pep_id = oms.PeptideIdentification()
|
||||
pep_id.setRT(123.45)
|
||||
pep_id.setMZ(456.789)
|
||||
|
||||
# Create peptide hit
|
||||
hit = oms.PeptideHit()
|
||||
hit.setSequence(oms.AASequence.fromString("PEPTIDE"))
|
||||
hit.setCharge(2)
|
||||
hit.setScore(25.5)
|
||||
hit.setRank(1)
|
||||
|
||||
# Add to identification
|
||||
pep_id.setHits([hit])
|
||||
pep_id.setHigherScoreBetter(True)
|
||||
pep_id.setScoreType("XCorr")
|
||||
|
||||
# Access
|
||||
hits = pep_id.getHits()
|
||||
for hit in hits:
|
||||
seq = hit.getSequence().toString()
|
||||
score = hit.getScore()
|
||||
print(f"Sequence: {seq}, Score: {score}")
|
||||
```
|
||||
|
||||
### ProteinIdentification
|
||||
|
||||
```python
|
||||
# Create protein identification
|
||||
prot_id = oms.ProteinIdentification()
|
||||
|
||||
# Create protein hit
|
||||
prot_hit = oms.ProteinHit()
|
||||
prot_hit.setAccession("P12345")
|
||||
prot_hit.setSequence("MKTAYIAKQRQISFVK...")
|
||||
prot_hit.setScore(100.5)
|
||||
|
||||
# Add to identification
|
||||
prot_id.setHits([prot_hit])
|
||||
prot_id.setScoreType("Mascot Score")
|
||||
prot_id.setHigherScoreBetter(True)
|
||||
|
||||
# Search parameters
|
||||
search_params = oms.ProteinIdentification.SearchParameters()
|
||||
search_params.db = "uniprot_human.fasta"
|
||||
search_params.enzyme = "Trypsin"
|
||||
prot_id.setSearchParameters(search_params)
|
||||
```
|
||||
|
||||
## ConsensusMap and ConsensusFeature
|
||||
|
||||
For linking features across multiple samples.
|
||||
|
||||
### ConsensusFeature
|
||||
|
||||
```python
|
||||
# Create consensus feature
|
||||
cons_feature = oms.ConsensusFeature()
|
||||
cons_feature.setMZ(456.789)
|
||||
cons_feature.setRT(123.45)
|
||||
cons_feature.setIntensity(5000000) # Combined intensity
|
||||
|
||||
# Access linked features
|
||||
for handle in cons_feature.getFeatureList():
|
||||
map_index = handle.getMapIndex()
|
||||
feature_index = handle.getIndex()
|
||||
intensity = handle.getIntensity()
|
||||
```
|
||||
|
||||
### ConsensusMap
|
||||
|
||||
```python
|
||||
# Create consensus map
|
||||
consensus_map = oms.ConsensusMap()
|
||||
|
||||
# Add consensus features
|
||||
consensus_map.push_back(cons_feature)
|
||||
|
||||
# Iterate
|
||||
for cons_feat in consensus_map:
|
||||
mz = cons_feat.getMZ()
|
||||
rt = cons_feat.getRT()
|
||||
n_features = cons_feat.size() # Number of linked features
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use numpy arrays** for peak data when possible - much faster than individual peak access
|
||||
2. **Sort spectra** by position (m/z) before searching or filtering
|
||||
3. **Update ranges** after modifying MSExperiment: `exp.updateRanges()`
|
||||
4. **Check MS level** before processing - different algorithms for MS1 vs MS2
|
||||
5. **Validate precursor info** for MS2 spectra - ensure charge and m/z are set
|
||||
6. **Use appropriate containers** - MSExperiment for raw data, FeatureMap for quantification
|
||||
7. **Clear metadata carefully** - use `clear(False)` to preserve metadata when clearing peaks
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Create MS2 Spectrum with Precursor
|
||||
|
||||
```python
|
||||
spectrum = oms.MSSpectrum()
|
||||
spectrum.setRT(205.2)
|
||||
spectrum.setMSLevel(2)
|
||||
spectrum.set_peaks(([100, 200, 300], [1000, 5000, 3000]))
|
||||
|
||||
precursor = oms.Precursor()
|
||||
precursor.setMZ(450.5)
|
||||
precursor.setCharge(2)
|
||||
spectrum.setPrecursors([precursor])
|
||||
```
|
||||
|
||||
### Extract MS1 Spectra from Experiment
|
||||
|
||||
```python
|
||||
ms1_exp = oms.MSExperiment()
|
||||
for spectrum in exp:
|
||||
if spectrum.getMSLevel() == 1:
|
||||
ms1_exp.addSpectrum(spectrum)
|
||||
```
|
||||
|
||||
### Calculate Total Ion Current (TIC)
|
||||
|
||||
```python
|
||||
tic_values = []
|
||||
rt_values = []
|
||||
for spectrum in exp:
|
||||
if spectrum.getMSLevel() == 1:
|
||||
mz, intensity = spectrum.get_peaks()
|
||||
tic = np.sum(intensity)
|
||||
tic_values.append(tic)
|
||||
rt_values.append(spectrum.getRT())
|
||||
```
|
||||
|
||||
### Find Spectrum Closest to RT
|
||||
|
||||
```python
|
||||
target_rt = 125.5
|
||||
closest_spectrum = None
|
||||
min_diff = float('inf')
|
||||
|
||||
for spectrum in exp:
|
||||
diff = abs(spectrum.getRT() - target_rt)
|
||||
if diff < min_diff:
|
||||
min_diff = diff
|
||||
closest_spectrum = spectrum
|
||||
```
|
||||
@@ -1,621 +0,0 @@
|
||||
---
|
||||
name: reportlab
|
||||
description: "PDF generation toolkit. Create invoices, reports, certificates, forms, charts, tables, barcodes, QR codes, Canvas/Platypus APIs, for professional document automation."
|
||||
---
|
||||
|
||||
# ReportLab PDF Generation
|
||||
|
||||
## Overview
|
||||
|
||||
ReportLab is a powerful Python library for programmatic PDF generation. Create anything from simple documents to complex reports with tables, charts, images, and interactive forms.
|
||||
|
||||
**Two main approaches:**
|
||||
- **Canvas API** (low-level): Direct drawing with coordinate-based positioning - use for precise layouts
|
||||
- **Platypus** (high-level): Flowing document layout with automatic page breaks - use for multi-page documents
|
||||
|
||||
**Core capabilities:**
|
||||
- Text with rich formatting and custom fonts
|
||||
- Tables with complex styling and cell spanning
|
||||
- Charts (bar, line, pie, area, scatter)
|
||||
- Barcodes and QR codes (Code128, EAN, QR, etc.)
|
||||
- Images with transparency
|
||||
- PDF features (links, bookmarks, forms, encryption)
|
||||
|
||||
## Choosing the Right Approach
|
||||
|
||||
### Use Canvas API when:
|
||||
- Creating labels, business cards, certificates
|
||||
- Precise positioning is critical (x, y coordinates)
|
||||
- Single-page documents or simple layouts
|
||||
- Drawing graphics, shapes, and custom designs
|
||||
- Adding barcodes or QR codes at specific locations
|
||||
|
||||
### Use Platypus when:
|
||||
- Creating multi-page documents (reports, articles, books)
|
||||
- Content should flow automatically across pages
|
||||
- Need headers/footers that repeat on each page
|
||||
- Working with paragraphs that can split across pages
|
||||
- Building complex documents with table of contents
|
||||
|
||||
### Use Both when:
|
||||
- Complex reports need both flowing content AND precise positioning
|
||||
- Adding headers/footers to Platypus documents (use `onPage` callback with Canvas)
|
||||
- Embedding custom graphics (Canvas) within flowing documents (Platypus)
|
||||
|
||||
## Quick Start Examples
|
||||
|
||||
### Simple Canvas Document
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
c = canvas.Canvas("output.pdf", pagesize=letter)
|
||||
width, height = letter
|
||||
|
||||
# Draw text
|
||||
c.setFont("Helvetica-Bold", 24)
|
||||
c.drawString(inch, height - inch, "Hello ReportLab!")
|
||||
|
||||
# Draw a rectangle
|
||||
c.setFillColorRGB(0.2, 0.4, 0.8)
|
||||
c.rect(inch, 5*inch, 4*inch, 2*inch, fill=1)
|
||||
|
||||
# Save
|
||||
c.showPage()
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Simple Platypus Document
|
||||
|
||||
```python
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
|
||||
from reportlab.lib.styles import getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
doc = SimpleDocTemplate("output.pdf", pagesize=letter)
|
||||
story = []
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Add content
|
||||
story.append(Paragraph("Document Title", styles['Title']))
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
story.append(Paragraph("This is body text with <b>bold</b> and <i>italic</i>.", styles['BodyText']))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Creating Tables
|
||||
|
||||
Tables work with both Canvas (via Drawing) and Platypus (as Flowables):
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Table, TableStyle
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Define data
|
||||
data = [
|
||||
['Product', 'Q1', 'Q2', 'Q3', 'Q4'],
|
||||
['Widget A', '100', '150', '130', '180'],
|
||||
['Widget B', '80', '120', '110', '160'],
|
||||
]
|
||||
|
||||
# Create table
|
||||
table = Table(data, colWidths=[2*inch, 1*inch, 1*inch, 1*inch, 1*inch])
|
||||
|
||||
# Apply styling
|
||||
style = TableStyle([
|
||||
# Header row
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.darkblue),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
|
||||
# Data rows
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]),
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
||||
])
|
||||
|
||||
table.setStyle(style)
|
||||
|
||||
# Add to Platypus story
|
||||
story.append(table)
|
||||
|
||||
# Or draw on Canvas
|
||||
table.wrapOn(c, width, height)
|
||||
table.drawOn(c, x, y)
|
||||
```
|
||||
|
||||
**Detailed table reference:** See `references/tables_reference.md` for cell spanning, borders, alignment, and advanced styling.
|
||||
|
||||
### Creating Charts
|
||||
|
||||
Charts use the graphics framework and can be added to both Canvas and Platypus:
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.charts.barcharts import VerticalBarChart
|
||||
from reportlab.lib import colors
|
||||
|
||||
# Create drawing
|
||||
drawing = Drawing(400, 200)
|
||||
|
||||
# Create chart
|
||||
chart = VerticalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 125
|
||||
|
||||
# Set data
|
||||
chart.data = [[100, 150, 130, 180, 140]]
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5']
|
||||
|
||||
# Style
|
||||
chart.bars[0].fillColor = colors.blue
|
||||
chart.valueAxis.valueMin = 0
|
||||
chart.valueAxis.valueMax = 200
|
||||
|
||||
# Add to drawing
|
||||
drawing.add(chart)
|
||||
|
||||
# Use in Platypus
|
||||
story.append(drawing)
|
||||
|
||||
# Or render directly to PDF
|
||||
from reportlab.graphics import renderPDF
|
||||
renderPDF.drawToFile(drawing, 'chart.pdf', 'Chart Title')
|
||||
```
|
||||
|
||||
**Available chart types:** Bar (vertical/horizontal), Line, Pie, Area, Scatter
|
||||
**Detailed charts reference:** See `references/charts_reference.md` for all chart types, styling, legends, and customization.
|
||||
|
||||
### Adding Barcodes and QR Codes
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import code128
|
||||
from reportlab.graphics.barcode.qr import QrCodeWidget
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics import renderPDF
|
||||
|
||||
# Code128 barcode (general purpose)
|
||||
barcode = code128.Code128("ABC123456789", barHeight=0.5*inch)
|
||||
|
||||
# On Canvas
|
||||
barcode.drawOn(c, x, y)
|
||||
|
||||
# QR Code
|
||||
qr = QrCodeWidget("https://example.com")
|
||||
qr.barWidth = 2*inch
|
||||
qr.barHeight = 2*inch
|
||||
|
||||
# Wrap in Drawing for Platypus
|
||||
d = Drawing()
|
||||
d.add(qr)
|
||||
story.append(d)
|
||||
```
|
||||
|
||||
**Supported formats:** Code128, Code39, EAN-13, EAN-8, UPC-A, ISBN, QR, Data Matrix, and 20+ more
|
||||
**Detailed barcode reference:** See `references/barcodes_reference.md` for all formats and usage examples.
|
||||
|
||||
### Working with Text and Fonts
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Paragraph
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
from reportlab.lib.enums import TA_JUSTIFY
|
||||
|
||||
# Create custom style
|
||||
custom_style = ParagraphStyle(
|
||||
'CustomStyle',
|
||||
fontSize=12,
|
||||
leading=14, # Line spacing
|
||||
alignment=TA_JUSTIFY,
|
||||
spaceAfter=10,
|
||||
textColor=colors.black,
|
||||
)
|
||||
|
||||
# Paragraph with inline formatting
|
||||
text = """
|
||||
This paragraph has <b>bold</b>, <i>italic</i>, and <u>underlined</u> text.
|
||||
You can also use <font color="blue">colors</font> and <font size="14">different sizes</font>.
|
||||
Chemical formula: H<sub>2</sub>O, Einstein: E=mc<sup>2</sup>
|
||||
"""
|
||||
|
||||
para = Paragraph(text, custom_style)
|
||||
story.append(para)
|
||||
```
|
||||
|
||||
**Using custom fonts:**
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
|
||||
# Register TrueType font
|
||||
pdfmetrics.registerFont(TTFont('CustomFont', 'CustomFont.ttf'))
|
||||
|
||||
# Use in Canvas
|
||||
c.setFont('CustomFont', 12)
|
||||
|
||||
# Use in Paragraph style
|
||||
style = ParagraphStyle('Custom', fontName='CustomFont', fontSize=12)
|
||||
```
|
||||
|
||||
**Detailed text reference:** See `references/text_and_fonts.md` for paragraph styles, font families, Asian languages, Greek letters, and formatting.
|
||||
|
||||
### Adding Images
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Image
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# In Platypus
|
||||
img = Image('photo.jpg', width=4*inch, height=3*inch)
|
||||
story.append(img)
|
||||
|
||||
# Maintain aspect ratio
|
||||
img = Image('photo.jpg', width=4*inch, height=3*inch, kind='proportional')
|
||||
|
||||
# In Canvas
|
||||
c.drawImage('photo.jpg', x, y, width=4*inch, height=3*inch)
|
||||
|
||||
# With transparency (mask white background)
|
||||
c.drawImage('logo.png', x, y, mask=[255,255,255,255,255,255])
|
||||
```
|
||||
|
||||
### Creating Forms
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.lib.colors import black, white, lightgrey
|
||||
|
||||
c = canvas.Canvas("form.pdf")
|
||||
|
||||
# Text field
|
||||
c.acroForm.textfield(
|
||||
name="name",
|
||||
tooltip="Enter your name",
|
||||
x=100, y=700,
|
||||
width=200, height=20,
|
||||
borderColor=black,
|
||||
fillColor=lightgrey,
|
||||
forceBorder=True
|
||||
)
|
||||
|
||||
# Checkbox
|
||||
c.acroForm.checkbox(
|
||||
name="agree",
|
||||
x=100, y=650,
|
||||
size=20,
|
||||
buttonStyle='check',
|
||||
checked=False
|
||||
)
|
||||
|
||||
# Dropdown
|
||||
c.acroForm.choice(
|
||||
name="country",
|
||||
x=100, y=600,
|
||||
width=150, height=20,
|
||||
options=[("United States", "US"), ("Canada", "CA")],
|
||||
forceBorder=True
|
||||
)
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
**Detailed PDF features reference:** See `references/pdf_features.md` for forms, links, bookmarks, encryption, and metadata.
|
||||
|
||||
### Headers and Footers
|
||||
|
||||
For Platypus documents, use page callbacks:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame
|
||||
|
||||
def add_header_footer(canvas, doc):
|
||||
"""Called on each page"""
|
||||
canvas.saveState()
|
||||
|
||||
# Header
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.drawString(inch, height - 0.5*inch, "Document Title")
|
||||
|
||||
# Footer
|
||||
canvas.drawRightString(width - inch, 0.5*inch, f"Page {doc.page}")
|
||||
|
||||
canvas.restoreState()
|
||||
|
||||
# Set up document
|
||||
doc = BaseDocTemplate("output.pdf")
|
||||
frame = Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height, id='normal')
|
||||
template = PageTemplate(id='normal', frames=[frame], onPage=add_header_footer)
|
||||
doc.addPageTemplates([template])
|
||||
|
||||
# Build with story
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Helper Scripts
|
||||
|
||||
This skill includes helper scripts for common tasks:
|
||||
|
||||
### Quick Document Generator
|
||||
|
||||
Use `scripts/quick_document.py` for rapid document creation:
|
||||
|
||||
```python
|
||||
from scripts.quick_document import create_simple_document, create_styled_table
|
||||
|
||||
# Simple document from content blocks
|
||||
content = [
|
||||
{'type': 'heading', 'content': 'Introduction'},
|
||||
{'type': 'paragraph', 'content': 'Your text here...'},
|
||||
{'type': 'bullet', 'content': 'Bullet point'},
|
||||
]
|
||||
|
||||
create_simple_document("output.pdf", "My Document", content_blocks=content)
|
||||
|
||||
# Styled tables with presets
|
||||
data = [['Header1', 'Header2'], ['Data1', 'Data2']]
|
||||
table = create_styled_table(data, style_name='striped') # 'default', 'striped', 'minimal', 'report'
|
||||
```
|
||||
|
||||
## Template Examples
|
||||
|
||||
Complete working examples in `assets/`:
|
||||
|
||||
### Invoice Template
|
||||
|
||||
`assets/invoice_template.py` - Professional invoice with:
|
||||
- Company and client information
|
||||
- Itemized table with calculations
|
||||
- Tax and totals
|
||||
- Terms and notes
|
||||
- Logo placement
|
||||
|
||||
```python
|
||||
from assets.invoice_template import create_invoice
|
||||
|
||||
create_invoice(
|
||||
filename="invoice.pdf",
|
||||
invoice_number="INV-2024-001",
|
||||
invoice_date="January 15, 2024",
|
||||
due_date="February 15, 2024",
|
||||
company_info={'name': 'Acme Corp', 'address': '...', 'phone': '...', 'email': '...'},
|
||||
client_info={'name': 'Client Name', ...},
|
||||
items=[
|
||||
{'description': 'Service', 'quantity': 1, 'unit_price': 500.00},
|
||||
...
|
||||
],
|
||||
tax_rate=0.08,
|
||||
notes="Thank you for your business!",
|
||||
)
|
||||
```
|
||||
|
||||
### Report Template
|
||||
|
||||
`assets/report_template.py` - Multi-page business report with:
|
||||
- Cover page
|
||||
- Table of contents
|
||||
- Multiple sections with subsections
|
||||
- Charts and tables
|
||||
- Headers and footers
|
||||
|
||||
```python
|
||||
from assets.report_template import create_report
|
||||
|
||||
report_data = {
|
||||
'title': 'Quarterly Report',
|
||||
'subtitle': 'Q4 2023',
|
||||
'author': 'Analytics Team',
|
||||
'sections': [
|
||||
{
|
||||
'title': 'Executive Summary',
|
||||
'content': 'Report content...',
|
||||
'table_data': {...},
|
||||
'chart_data': {...}
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
create_report("report.pdf", report_data)
|
||||
```
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
Comprehensive API references organized by feature:
|
||||
|
||||
- **`references/canvas_api.md`** - Low-level Canvas: drawing primitives, coordinates, transformations, state management, images, paths
|
||||
- **`references/platypus_guide.md`** - High-level Platypus: document templates, frames, flowables, page layouts, TOC
|
||||
- **`references/text_and_fonts.md`** - Text formatting: paragraph styles, inline markup, custom fonts, Asian languages, bullets, sequences
|
||||
- **`references/tables_reference.md`** - Tables: creation, styling, cell spanning, borders, alignment, colors, gradients
|
||||
- **`references/charts_reference.md`** - Charts: all chart types, data handling, axes, legends, colors, rendering
|
||||
- **`references/barcodes_reference.md`** - Barcodes: Code128, QR codes, EAN, UPC, postal codes, and 20+ formats
|
||||
- **`references/pdf_features.md`** - PDF features: links, bookmarks, forms, encryption, metadata, page transitions
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Coordinate System (Canvas)
|
||||
- Origin (0, 0) is **lower-left corner** (not top-left)
|
||||
- Y-axis points **upward**
|
||||
- Units are in **points** (72 points = 1 inch)
|
||||
- Always specify page size explicitly
|
||||
|
||||
```python
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
width, height = letter
|
||||
margin = inch
|
||||
|
||||
# Top of page
|
||||
y_top = height - margin
|
||||
|
||||
# Bottom of page
|
||||
y_bottom = margin
|
||||
```
|
||||
|
||||
### Choosing Page Size
|
||||
|
||||
```python
|
||||
from reportlab.lib.pagesizes import letter, A4, landscape
|
||||
|
||||
# US Letter (8.5" x 11")
|
||||
pagesize=letter
|
||||
|
||||
# ISO A4 (210mm x 297mm)
|
||||
pagesize=A4
|
||||
|
||||
# Landscape
|
||||
pagesize=landscape(letter)
|
||||
|
||||
# Custom
|
||||
pagesize=(6*inch, 9*inch)
|
||||
```
|
||||
|
||||
### Performance Tips
|
||||
|
||||
1. **Use `drawImage()` over `drawInlineImage()`** - caches images for reuse
|
||||
2. **Enable compression for large files:** `canvas.Canvas("file.pdf", pageCompression=1)`
|
||||
3. **Reuse styles** - create once, use throughout document
|
||||
4. **Use Forms/XObjects** for repeated graphics
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Centering text on Canvas:**
|
||||
```python
|
||||
text = "Centered Text"
|
||||
text_width = c.stringWidth(text, "Helvetica", 12)
|
||||
x = (width - text_width) / 2
|
||||
c.drawString(x, y, text)
|
||||
|
||||
# Or use built-in
|
||||
c.drawCentredString(width/2, y, text)
|
||||
```
|
||||
|
||||
**Page breaks in Platypus:**
|
||||
```python
|
||||
from reportlab.platypus import PageBreak
|
||||
|
||||
story.append(PageBreak())
|
||||
```
|
||||
|
||||
**Keep content together (no split):**
|
||||
```python
|
||||
from reportlab.platypus import KeepTogether
|
||||
|
||||
story.append(KeepTogether([
|
||||
heading,
|
||||
paragraph1,
|
||||
paragraph2,
|
||||
]))
|
||||
```
|
||||
|
||||
**Alternate row colors:**
|
||||
```python
|
||||
style = TableStyle([
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]),
|
||||
])
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Text overlaps or disappears:**
|
||||
- Check Y-coordinates - remember origin is bottom-left
|
||||
- Ensure text fits within page bounds
|
||||
- Verify `leading` (line spacing) is greater than `fontSize`
|
||||
|
||||
**Table doesn't fit on page:**
|
||||
- Reduce column widths
|
||||
- Decrease font size
|
||||
- Use landscape orientation
|
||||
- Enable table splitting with `repeatRows`
|
||||
|
||||
**Barcode not scanning:**
|
||||
- Increase `barHeight` (try 0.5 inch minimum)
|
||||
- Set `quiet=1` for quiet zones
|
||||
- Test print quality (300+ DPI recommended)
|
||||
- Validate data format for barcode type
|
||||
|
||||
**Font not found:**
|
||||
- Register TrueType fonts with `pdfmetrics.registerFont()`
|
||||
- Use font family name exactly as registered
|
||||
- Check font file path is correct
|
||||
|
||||
**Images have white background:**
|
||||
- Use `mask` parameter to make white transparent
|
||||
- Provide RGB range to mask: `mask=[255,255,255,255,255,255]`
|
||||
- Or use PNG with alpha channel
|
||||
|
||||
## Example Workflows
|
||||
|
||||
### Creating an Invoice
|
||||
|
||||
1. Start with invoice template from `assets/invoice_template.py`
|
||||
2. Customize company info, logo path
|
||||
3. Add items with descriptions, quantities, prices
|
||||
4. Set tax rate if applicable
|
||||
5. Add notes and payment terms
|
||||
6. Generate PDF
|
||||
|
||||
### Creating a Report
|
||||
|
||||
1. Start with report template from `assets/report_template.py`
|
||||
2. Define sections with titles and content
|
||||
3. Add tables for data using `create_styled_table()`
|
||||
4. Add charts using graphics framework
|
||||
5. Build with `doc.multiBuild(story)` for TOC
|
||||
|
||||
### Creating a Certificate
|
||||
|
||||
1. Use Canvas API for precise positioning
|
||||
2. Load custom fonts for elegant typography
|
||||
3. Add border graphics or image background
|
||||
4. Position text elements (name, date, achievement)
|
||||
5. Optional: Add QR code for verification
|
||||
|
||||
### Creating Labels with Barcodes
|
||||
|
||||
1. Use Canvas with custom page size (label dimensions)
|
||||
2. Calculate grid positions for multiple labels per page
|
||||
3. Draw label content (text, images)
|
||||
4. Add barcode at specific position
|
||||
5. Use `showPage()` between labels or grids
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install reportlab
|
||||
|
||||
# For image support
|
||||
pip install pillow
|
||||
|
||||
# For charts
|
||||
pip install reportlab[renderPM]
|
||||
|
||||
# For barcode support (included in reportlab)
|
||||
# QR codes require: pip install qrcode
|
||||
```
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Generating PDF documents programmatically
|
||||
- Creating invoices, receipts, or billing documents
|
||||
- Building reports with tables and charts
|
||||
- Generating certificates, badges, or credentials
|
||||
- Creating shipping labels or product labels with barcodes
|
||||
- Designing forms or fillable PDFs
|
||||
- Producing multi-page documents with consistent formatting
|
||||
- Converting data to PDF format for archival or distribution
|
||||
- Creating custom layouts that require precise positioning
|
||||
|
||||
This skill provides comprehensive guidance for all ReportLab capabilities, from simple documents to complex multi-page reports with charts, tables, and interactive elements.
|
||||
@@ -1,256 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Invoice Template - Complete example of a professional invoice
|
||||
|
||||
This template demonstrates:
|
||||
- Company header with logo placement
|
||||
- Client information
|
||||
- Invoice details table
|
||||
- Calculations (subtotal, tax, total)
|
||||
- Professional styling
|
||||
- Terms and conditions footer
|
||||
"""
|
||||
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib import colors
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image
|
||||
)
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.enums import TA_LEFT, TA_RIGHT, TA_CENTER
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def create_invoice(
|
||||
filename,
|
||||
invoice_number,
|
||||
invoice_date,
|
||||
due_date,
|
||||
company_info,
|
||||
client_info,
|
||||
items,
|
||||
tax_rate=0.0,
|
||||
notes="",
|
||||
terms="Payment due within 30 days.",
|
||||
logo_path=None
|
||||
):
|
||||
"""
|
||||
Create a professional invoice PDF.
|
||||
|
||||
Args:
|
||||
filename: Output PDF filename
|
||||
invoice_number: Invoice number (e.g., "INV-2024-001")
|
||||
invoice_date: Date of invoice (datetime or string)
|
||||
due_date: Payment due date (datetime or string)
|
||||
company_info: Dict with company details
|
||||
{'name': 'Company Name', 'address': 'Address', 'phone': 'Phone', 'email': 'Email'}
|
||||
client_info: Dict with client details (same structure as company_info)
|
||||
items: List of dicts with item details
|
||||
[{'description': 'Item', 'quantity': 1, 'unit_price': 100.00}, ...]
|
||||
tax_rate: Tax rate as decimal (e.g., 0.08 for 8%)
|
||||
notes: Additional notes to client
|
||||
terms: Payment terms
|
||||
logo_path: Path to company logo image (optional)
|
||||
"""
|
||||
# Create document
|
||||
doc = SimpleDocTemplate(filename, pagesize=letter,
|
||||
rightMargin=0.5*inch, leftMargin=0.5*inch,
|
||||
topMargin=0.5*inch, bottomMargin=0.5*inch)
|
||||
|
||||
# Container for elements
|
||||
story = []
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Create custom styles
|
||||
title_style = ParagraphStyle(
|
||||
'InvoiceTitle',
|
||||
parent=styles['Heading1'],
|
||||
fontSize=24,
|
||||
textColor=colors.HexColor('#2C3E50'),
|
||||
spaceAfter=12,
|
||||
)
|
||||
|
||||
header_style = ParagraphStyle(
|
||||
'Header',
|
||||
parent=styles['Normal'],
|
||||
fontSize=10,
|
||||
textColor=colors.HexColor('#34495E'),
|
||||
)
|
||||
|
||||
# --- HEADER SECTION ---
|
||||
header_data = []
|
||||
|
||||
# Company info (left side)
|
||||
company_text = f"""
|
||||
<b><font size="14">{company_info['name']}</font></b><br/>
|
||||
{company_info.get('address', '')}<br/>
|
||||
Phone: {company_info.get('phone', '')}<br/>
|
||||
Email: {company_info.get('email', '')}
|
||||
"""
|
||||
|
||||
# Invoice title and number (right side)
|
||||
invoice_text = f"""
|
||||
<b><font size="16" color="#2C3E50">INVOICE</font></b><br/>
|
||||
<font size="10">Invoice #: {invoice_number}</font><br/>
|
||||
<font size="10">Date: {invoice_date}</font><br/>
|
||||
<font size="10">Due Date: {due_date}</font>
|
||||
"""
|
||||
|
||||
if logo_path:
|
||||
logo = Image(logo_path, width=1.5*inch, height=1*inch)
|
||||
header_data = [[logo, Paragraph(company_text, header_style), Paragraph(invoice_text, header_style)]]
|
||||
header_table = Table(header_data, colWidths=[1.5*inch, 3*inch, 2.5*inch])
|
||||
else:
|
||||
header_data = [[Paragraph(company_text, header_style), Paragraph(invoice_text, header_style)]]
|
||||
header_table = Table(header_data, colWidths=[4.5*inch, 2.5*inch])
|
||||
|
||||
header_table.setStyle(TableStyle([
|
||||
('VALIGN', (0, 0), (-1, -1), 'TOP'),
|
||||
('ALIGN', (-1, 0), (-1, -1), 'RIGHT'),
|
||||
]))
|
||||
|
||||
story.append(header_table)
|
||||
story.append(Spacer(1, 0.3*inch))
|
||||
|
||||
# --- CLIENT INFORMATION ---
|
||||
client_label = Paragraph("<b>Bill To:</b>", header_style)
|
||||
client_text = f"""
|
||||
<b>{client_info['name']}</b><br/>
|
||||
{client_info.get('address', '')}<br/>
|
||||
Phone: {client_info.get('phone', '')}<br/>
|
||||
Email: {client_info.get('email', '')}
|
||||
"""
|
||||
client_para = Paragraph(client_text, header_style)
|
||||
|
||||
client_table = Table([[client_label, client_para]], colWidths=[1*inch, 6*inch])
|
||||
story.append(client_table)
|
||||
story.append(Spacer(1, 0.3*inch))
|
||||
|
||||
# --- ITEMS TABLE ---
|
||||
# Table header
|
||||
items_data = [['Description', 'Quantity', 'Unit Price', 'Amount']]
|
||||
|
||||
# Calculate items
|
||||
subtotal = 0
|
||||
for item in items:
|
||||
desc = item['description']
|
||||
qty = item['quantity']
|
||||
price = item['unit_price']
|
||||
amount = qty * price
|
||||
subtotal += amount
|
||||
|
||||
items_data.append([
|
||||
desc,
|
||||
str(qty),
|
||||
f"${price:,.2f}",
|
||||
f"${amount:,.2f}"
|
||||
])
|
||||
|
||||
# Create items table
|
||||
items_table = Table(items_data, colWidths=[3.5*inch, 1*inch, 1.5*inch, 1*inch])
|
||||
|
||||
items_table.setStyle(TableStyle([
|
||||
# Header row
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495E')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 11),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
|
||||
# Data rows
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.white),
|
||||
('ALIGN', (1, 1), (-1, -1), 'RIGHT'),
|
||||
('ALIGN', (0, 1), (0, -1), 'LEFT'),
|
||||
('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
|
||||
('FONTSIZE', (0, 1), (-1, -1), 10),
|
||||
('TOPPADDING', (0, 1), (-1, -1), 6),
|
||||
('BOTTOMPADDING', (0, 1), (-1, -1), 6),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
|
||||
]))
|
||||
|
||||
story.append(items_table)
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
|
||||
# --- TOTALS SECTION ---
|
||||
tax_amount = subtotal * tax_rate
|
||||
total = subtotal + tax_amount
|
||||
|
||||
totals_data = [
|
||||
['Subtotal:', f"${subtotal:,.2f}"],
|
||||
]
|
||||
|
||||
if tax_rate > 0:
|
||||
totals_data.append([f'Tax ({tax_rate*100:.1f}%):', f"${tax_amount:,.2f}"])
|
||||
|
||||
totals_data.append(['<b>Total:</b>', f"<b>${total:,.2f}</b>"])
|
||||
|
||||
totals_table = Table(totals_data, colWidths=[5*inch, 2*inch])
|
||||
totals_table.setStyle(TableStyle([
|
||||
('ALIGN', (0, 0), (-1, -1), 'RIGHT'),
|
||||
('FONTNAME', (0, 0), (-1, -2), 'Helvetica'),
|
||||
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 11),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
('LINEABOVE', (1, -1), (1, -1), 2, colors.HexColor('#34495E')),
|
||||
]))
|
||||
|
||||
story.append(totals_table)
|
||||
story.append(Spacer(1, 0.4*inch))
|
||||
|
||||
# --- NOTES ---
|
||||
if notes:
|
||||
notes_style = ParagraphStyle('Notes', parent=styles['Normal'], fontSize=9)
|
||||
story.append(Paragraph(f"<b>Notes:</b><br/>{notes}", notes_style))
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
|
||||
# --- TERMS ---
|
||||
terms_style = ParagraphStyle('Terms', parent=styles['Normal'],
|
||||
fontSize=9, textColor=colors.grey)
|
||||
story.append(Paragraph(f"<b>Payment Terms:</b><br/>{terms}", terms_style))
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
return filename
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Sample data
|
||||
company = {
|
||||
'name': 'Acme Corporation',
|
||||
'address': '123 Business St, Suite 100\nNew York, NY 10001',
|
||||
'phone': '(555) 123-4567',
|
||||
'email': 'info@acme.com'
|
||||
}
|
||||
|
||||
client = {
|
||||
'name': 'John Doe',
|
||||
'address': '456 Client Ave\nLos Angeles, CA 90001',
|
||||
'phone': '(555) 987-6543',
|
||||
'email': 'john@example.com'
|
||||
}
|
||||
|
||||
items = [
|
||||
{'description': 'Web Design Services', 'quantity': 1, 'unit_price': 2500.00},
|
||||
{'description': 'Content Writing (10 pages)', 'quantity': 10, 'unit_price': 50.00},
|
||||
{'description': 'SEO Optimization', 'quantity': 1, 'unit_price': 750.00},
|
||||
{'description': 'Hosting Setup', 'quantity': 1, 'unit_price': 200.00},
|
||||
]
|
||||
|
||||
create_invoice(
|
||||
filename="sample_invoice.pdf",
|
||||
invoice_number="INV-2024-001",
|
||||
invoice_date="January 15, 2024",
|
||||
due_date="February 15, 2024",
|
||||
company_info=company,
|
||||
client_info=client,
|
||||
items=items,
|
||||
tax_rate=0.08,
|
||||
notes="Thank you for your business! We appreciate your prompt payment.",
|
||||
terms="Payment due within 30 days. Late payments subject to 1.5% monthly fee.",
|
||||
logo_path=None # Set to your logo path if available
|
||||
)
|
||||
|
||||
print("Invoice created: sample_invoice.pdf")
|
||||
@@ -1,343 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Report Template - Complete example of a professional multi-page report
|
||||
|
||||
This template demonstrates:
|
||||
- Cover page
|
||||
- Table of contents
|
||||
- Multiple sections with headers
|
||||
- Charts and graphs integration
|
||||
- Tables with data
|
||||
- Headers and footers
|
||||
- Professional styling
|
||||
"""
|
||||
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib import colors
|
||||
from reportlab.platypus import (
|
||||
BaseDocTemplate, PageTemplate, Frame, Paragraph, Spacer,
|
||||
Table, TableStyle, PageBreak, KeepTogether, TableOfContents
|
||||
)
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.enums import TA_LEFT, TA_RIGHT, TA_CENTER, TA_JUSTIFY
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.charts.barcharts import VerticalBarChart
|
||||
from reportlab.graphics.charts.linecharts import HorizontalLineChart
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def header_footer(canvas, doc):
|
||||
"""Draw header and footer on each page (except cover)"""
|
||||
canvas.saveState()
|
||||
|
||||
# Skip header/footer on cover page (page 1)
|
||||
if doc.page > 1:
|
||||
# Header
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.setFillColor(colors.grey)
|
||||
canvas.drawString(inch, letter[1] - 0.5*inch, "Quarterly Business Report")
|
||||
canvas.line(inch, letter[1] - 0.55*inch, letter[0] - inch, letter[1] - 0.55*inch)
|
||||
|
||||
# Footer
|
||||
canvas.drawString(inch, 0.5*inch, f"Generated: {datetime.now().strftime('%B %d, %Y')}")
|
||||
canvas.drawRightString(letter[0] - inch, 0.5*inch, f"Page {doc.page - 1}")
|
||||
|
||||
canvas.restoreState()
|
||||
|
||||
|
||||
def create_report(filename, report_data):
|
||||
"""
|
||||
Create a comprehensive business report.
|
||||
|
||||
Args:
|
||||
filename: Output PDF filename
|
||||
report_data: Dict containing report information
|
||||
{
|
||||
'title': 'Report Title',
|
||||
'subtitle': 'Report Subtitle',
|
||||
'author': 'Author Name',
|
||||
'date': 'Date',
|
||||
'sections': [
|
||||
{
|
||||
'title': 'Section Title',
|
||||
'content': 'Section content...',
|
||||
'subsections': [...],
|
||||
'table': {...},
|
||||
'chart': {...}
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Create document with custom page template
|
||||
doc = BaseDocTemplate(filename, pagesize=letter,
|
||||
rightMargin=72, leftMargin=72,
|
||||
topMargin=inch, bottomMargin=inch)
|
||||
|
||||
# Define frame for content
|
||||
frame = Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height - 0.5*inch, id='normal')
|
||||
|
||||
# Create page template with header/footer
|
||||
template = PageTemplate(id='normal', frames=[frame], onPage=header_footer)
|
||||
doc.addPageTemplates([template])
|
||||
|
||||
# Get styles
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Custom styles
|
||||
title_style = ParagraphStyle(
|
||||
'ReportTitle',
|
||||
parent=styles['Title'],
|
||||
fontSize=28,
|
||||
textColor=colors.HexColor('#2C3E50'),
|
||||
spaceAfter=20,
|
||||
alignment=TA_CENTER,
|
||||
)
|
||||
|
||||
subtitle_style = ParagraphStyle(
|
||||
'ReportSubtitle',
|
||||
parent=styles['Normal'],
|
||||
fontSize=14,
|
||||
textColor=colors.grey,
|
||||
alignment=TA_CENTER,
|
||||
spaceAfter=30,
|
||||
)
|
||||
|
||||
heading1_style = ParagraphStyle(
|
||||
'CustomHeading1',
|
||||
parent=styles['Heading1'],
|
||||
fontSize=18,
|
||||
textColor=colors.HexColor('#2C3E50'),
|
||||
spaceAfter=12,
|
||||
spaceBefore=12,
|
||||
)
|
||||
|
||||
heading2_style = ParagraphStyle(
|
||||
'CustomHeading2',
|
||||
parent=styles['Heading2'],
|
||||
fontSize=14,
|
||||
textColor=colors.HexColor('#34495E'),
|
||||
spaceAfter=10,
|
||||
spaceBefore=10,
|
||||
)
|
||||
|
||||
body_style = ParagraphStyle(
|
||||
'ReportBody',
|
||||
parent=styles['BodyText'],
|
||||
fontSize=11,
|
||||
alignment=TA_JUSTIFY,
|
||||
spaceAfter=12,
|
||||
leading=14,
|
||||
)
|
||||
|
||||
# Build story
|
||||
story = []
|
||||
|
||||
# --- COVER PAGE ---
|
||||
story.append(Spacer(1, 2*inch))
|
||||
story.append(Paragraph(report_data['title'], title_style))
|
||||
story.append(Paragraph(report_data.get('subtitle', ''), subtitle_style))
|
||||
story.append(Spacer(1, inch))
|
||||
|
||||
# Cover info table
|
||||
cover_info = [
|
||||
['Prepared by:', report_data.get('author', '')],
|
||||
['Date:', report_data.get('date', datetime.now().strftime('%B %d, %Y'))],
|
||||
['Period:', report_data.get('period', 'Q4 2023')],
|
||||
]
|
||||
|
||||
cover_table = Table(cover_info, colWidths=[2*inch, 4*inch])
|
||||
cover_table.setStyle(TableStyle([
|
||||
('ALIGN', (0, 0), (0, -1), 'RIGHT'),
|
||||
('ALIGN', (1, 0), (1, -1), 'LEFT'),
|
||||
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, -1), 11),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
]))
|
||||
|
||||
story.append(cover_table)
|
||||
story.append(PageBreak())
|
||||
|
||||
# --- TABLE OF CONTENTS ---
|
||||
toc = TableOfContents()
|
||||
toc.levelStyles = [
|
||||
ParagraphStyle(name='TOCHeading1', fontSize=14, leftIndent=20, spaceBefore=10, spaceAfter=5),
|
||||
ParagraphStyle(name='TOCHeading2', fontSize=12, leftIndent=40, spaceBefore=3, spaceAfter=3),
|
||||
]
|
||||
|
||||
story.append(Paragraph("Table of Contents", heading1_style))
|
||||
story.append(toc)
|
||||
story.append(PageBreak())
|
||||
|
||||
# --- SECTIONS ---
|
||||
for section in report_data.get('sections', []):
|
||||
# Section heading
|
||||
section_title = section['title']
|
||||
story.append(Paragraph(f'<a name="{section_title}"/>{section_title}', heading1_style))
|
||||
|
||||
# Add to TOC
|
||||
toc.addEntry(0, section_title, doc.page)
|
||||
|
||||
# Section content
|
||||
if 'content' in section:
|
||||
for para in section['content'].split('\n\n'):
|
||||
if para.strip():
|
||||
story.append(Paragraph(para.strip(), body_style))
|
||||
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
|
||||
# Subsections
|
||||
for subsection in section.get('subsections', []):
|
||||
story.append(Paragraph(subsection['title'], heading2_style))
|
||||
|
||||
if 'content' in subsection:
|
||||
story.append(Paragraph(subsection['content'], body_style))
|
||||
|
||||
story.append(Spacer(1, 0.1*inch))
|
||||
|
||||
# Add table if provided
|
||||
if 'table_data' in section:
|
||||
table = create_section_table(section['table_data'])
|
||||
story.append(table)
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
|
||||
# Add chart if provided
|
||||
if 'chart_data' in section:
|
||||
chart = create_section_chart(section['chart_data'])
|
||||
story.append(chart)
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
|
||||
story.append(Spacer(1, 0.3*inch))
|
||||
|
||||
# Build PDF (twice for TOC to populate)
|
||||
doc.multiBuild(story)
|
||||
return filename
|
||||
|
||||
|
||||
def create_section_table(table_data):
|
||||
"""Create a styled table for report sections"""
|
||||
data = table_data['data']
|
||||
table = Table(data, colWidths=table_data.get('colWidths'))
|
||||
|
||||
table.setStyle(TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#34495E')),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 11),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.white),
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
|
||||
('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
|
||||
('FONTSIZE', (0, 1), (-1, -1), 10),
|
||||
]))
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_section_chart(chart_data):
|
||||
"""Create a chart for report sections"""
|
||||
chart_type = chart_data.get('type', 'bar')
|
||||
drawing = Drawing(400, 200)
|
||||
|
||||
if chart_type == 'bar':
|
||||
chart = VerticalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 30
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
chart.data = chart_data['data']
|
||||
chart.categoryAxis.categoryNames = chart_data.get('categories', [])
|
||||
chart.valueAxis.valueMin = 0
|
||||
|
||||
# Style bars
|
||||
for i in range(len(chart_data['data'])):
|
||||
chart.bars[i].fillColor = colors.HexColor(['#3498db', '#e74c3c', '#2ecc71'][i % 3])
|
||||
|
||||
elif chart_type == 'line':
|
||||
chart = HorizontalLineChart()
|
||||
chart.x = 50
|
||||
chart.y = 30
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
chart.data = chart_data['data']
|
||||
chart.categoryAxis.categoryNames = chart_data.get('categories', [])
|
||||
|
||||
# Style lines
|
||||
for i in range(len(chart_data['data'])):
|
||||
chart.lines[i].strokeColor = colors.HexColor(['#3498db', '#e74c3c', '#2ecc71'][i % 3])
|
||||
chart.lines[i].strokeWidth = 2
|
||||
|
||||
drawing.add(chart)
|
||||
return drawing
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
report = {
|
||||
'title': 'Quarterly Business Report',
|
||||
'subtitle': 'Q4 2023 Performance Analysis',
|
||||
'author': 'Analytics Team',
|
||||
'date': 'January 15, 2024',
|
||||
'period': 'October - December 2023',
|
||||
'sections': [
|
||||
{
|
||||
'title': 'Executive Summary',
|
||||
'content': """
|
||||
This report provides a comprehensive analysis of our Q4 2023 performance.
|
||||
Overall, the quarter showed strong growth across all key metrics, with
|
||||
revenue increasing by 25% year-over-year and customer satisfaction
|
||||
scores reaching an all-time high of 4.8/5.0.
|
||||
|
||||
Key highlights include the successful launch of three new products,
|
||||
expansion into two new markets, and the completion of our digital
|
||||
transformation initiative.
|
||||
""",
|
||||
'subsections': [
|
||||
{
|
||||
'title': 'Key Achievements',
|
||||
'content': 'Successfully launched Product X with 10,000 units sold in first month.'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'title': 'Financial Performance',
|
||||
'content': """
|
||||
The financial results for Q4 exceeded expectations across all categories.
|
||||
Revenue growth was driven primarily by strong product sales and increased
|
||||
market share in key regions.
|
||||
""",
|
||||
'table_data': {
|
||||
'data': [
|
||||
['Metric', 'Q3 2023', 'Q4 2023', 'Change'],
|
||||
['Revenue', '$2.5M', '$3.1M', '+24%'],
|
||||
['Profit', '$500K', '$680K', '+36%'],
|
||||
['Expenses', '$2.0M', '$2.4M', '+20%'],
|
||||
],
|
||||
'colWidths': [2*inch, 1.5*inch, 1.5*inch, 1*inch]
|
||||
},
|
||||
'chart_data': {
|
||||
'type': 'bar',
|
||||
'data': [[2.5, 3.1], [0.5, 0.68], [2.0, 2.4]],
|
||||
'categories': ['Q3', 'Q4']
|
||||
}
|
||||
},
|
||||
{
|
||||
'title': 'Market Analysis',
|
||||
'content': """
|
||||
Market conditions remained favorable throughout the quarter, with
|
||||
strong consumer confidence and increasing demand for our products.
|
||||
""",
|
||||
'chart_data': {
|
||||
'type': 'line',
|
||||
'data': [[100, 120, 115, 140, 135, 150]],
|
||||
'categories': ['Oct', 'Nov', 'Dec', 'Oct', 'Nov', 'Dec']
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
create_report("sample_report.pdf", report)
|
||||
print("Report created: sample_report.pdf")
|
||||
@@ -1,504 +0,0 @@
|
||||
# Barcodes Reference
|
||||
|
||||
Comprehensive guide to creating barcodes and QR codes in ReportLab.
|
||||
|
||||
## Available Barcode Types
|
||||
|
||||
ReportLab supports a wide range of 1D and 2D barcode formats.
|
||||
|
||||
### 1D Barcodes (Linear)
|
||||
|
||||
- **Code128** - Compact, encodes full ASCII
|
||||
- **Code39** (Standard39) - Alphanumeric, widely supported
|
||||
- **Code93** (Standard93) - Compressed Code39
|
||||
- **EAN-13** - European Article Number (retail)
|
||||
- **EAN-8** - Short form of EAN
|
||||
- **EAN-5** - 5-digit add-on (pricing)
|
||||
- **UPC-A** - Universal Product Code (North America)
|
||||
- **ISBN** - International Standard Book Number
|
||||
- **Code11** - Telecommunications
|
||||
- **Codabar** - Blood banks, FedEx, libraries
|
||||
- **I2of5** (Interleaved 2 of 5) - Warehouse/distribution
|
||||
- **MSI** - Inventory control
|
||||
- **POSTNET** - US Postal Service
|
||||
- **USPS_4State** - US Postal Service
|
||||
- **FIM** (A, B, C, D) - Facing Identification Mark (mail sorting)
|
||||
|
||||
### 2D Barcodes
|
||||
|
||||
- **QR** - QR Code (widely used for URLs, contact info)
|
||||
- **ECC200DataMatrix** - Data Matrix format
|
||||
|
||||
## Using Barcodes with Canvas
|
||||
|
||||
### Code128 (Recommended for General Use)
|
||||
|
||||
Code128 is versatile and compact - encodes full ASCII character set with mandatory checksum.
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.graphics.barcode import code128
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
c = canvas.Canvas("barcode.pdf")
|
||||
|
||||
# Create barcode
|
||||
barcode = code128.Code128("HELLO123")
|
||||
|
||||
# Draw on canvas
|
||||
barcode.drawOn(c, 1*inch, 5*inch)
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Code128 Options
|
||||
|
||||
```python
|
||||
barcode = code128.Code128(
|
||||
value="ABC123", # Required: data to encode
|
||||
barWidth=0.01*inch, # Width of narrowest bar
|
||||
barHeight=0.5*inch, # Height of bars
|
||||
quiet=1, # Add quiet zones (margins)
|
||||
lquiet=None, # Left quiet zone width
|
||||
rquiet=None, # Right quiet zone width
|
||||
stop=1, # Show stop symbol
|
||||
)
|
||||
|
||||
# Draw with specific size
|
||||
barcode.drawOn(canvas, x, y)
|
||||
|
||||
# Get dimensions
|
||||
width = barcode.width
|
||||
height = barcode.height
|
||||
```
|
||||
|
||||
### Code39 (Standard39)
|
||||
|
||||
Supports: 0-9, A-Z (uppercase), space, and special chars (-.$/+%*).
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import code39
|
||||
|
||||
barcode = code39.Standard39(
|
||||
value="HELLO",
|
||||
barWidth=0.01*inch,
|
||||
barHeight=0.5*inch,
|
||||
quiet=1,
|
||||
checksum=0, # 0 or 1
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
### Extended Code39
|
||||
|
||||
Encodes full ASCII (pairs of Code39 characters).
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import code39
|
||||
|
||||
barcode = code39.Extended39(
|
||||
value="Hello World!", # Can include lowercase and symbols
|
||||
barWidth=0.01*inch,
|
||||
barHeight=0.5*inch,
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
### Code93
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import code93
|
||||
|
||||
# Standard93 - uppercase, digits, some symbols
|
||||
barcode = code93.Standard93(
|
||||
value="HELLO93",
|
||||
barWidth=0.01*inch,
|
||||
barHeight=0.5*inch,
|
||||
)
|
||||
|
||||
# Extended93 - full ASCII
|
||||
barcode = code93.Extended93(
|
||||
value="Hello 93!",
|
||||
barWidth=0.01*inch,
|
||||
barHeight=0.5*inch,
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
### EAN-13 (European Article Number)
|
||||
|
||||
13-digit barcode for retail products.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import eanbc
|
||||
|
||||
# Must be exactly 12 digits (13th is calculated checksum)
|
||||
barcode = eanbc.Ean13BarcodeWidget(
|
||||
value="123456789012"
|
||||
)
|
||||
|
||||
# Draw
|
||||
from reportlab.graphics import renderPDF
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
|
||||
d = Drawing()
|
||||
d.add(barcode)
|
||||
renderPDF.draw(d, canvas, x, y)
|
||||
```
|
||||
|
||||
### EAN-8
|
||||
|
||||
Short form, 8 digits.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import eanbc
|
||||
|
||||
# Must be exactly 7 digits (8th is calculated)
|
||||
barcode = eanbc.Ean8BarcodeWidget(
|
||||
value="1234567"
|
||||
)
|
||||
```
|
||||
|
||||
### UPC-A
|
||||
|
||||
12-digit barcode used in North America.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import usps
|
||||
|
||||
# 11 digits (12th is checksum)
|
||||
barcode = usps.UPCA(
|
||||
value="01234567890"
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
### ISBN (Books)
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode.widgets import ISBNBarcodeWidget
|
||||
|
||||
# 10 or 13 digit ISBN
|
||||
barcode = ISBNBarcodeWidget(
|
||||
value="978-0-123456-78-9"
|
||||
)
|
||||
|
||||
# With pricing (EAN-5 add-on)
|
||||
barcode = ISBNBarcodeWidget(
|
||||
value="978-0-123456-78-9",
|
||||
price=True,
|
||||
)
|
||||
```
|
||||
|
||||
### QR Codes
|
||||
|
||||
Most versatile 2D barcode - can encode URLs, text, contact info, etc.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode.qr import QrCodeWidget
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics import renderPDF
|
||||
|
||||
# Create QR code
|
||||
qr = QrCodeWidget("https://example.com")
|
||||
|
||||
# Size in pixels (QR codes are square)
|
||||
qr.barWidth = 100 # Width in points
|
||||
qr.barHeight = 100 # Height in points
|
||||
|
||||
# Error correction level
|
||||
# L = 7% recovery, M = 15%, Q = 25%, H = 30%
|
||||
qr.qrVersion = 1 # Auto-size (1-40, or None for auto)
|
||||
qr.errorLevel = 'M' # L, M, Q, H
|
||||
|
||||
# Draw
|
||||
d = Drawing()
|
||||
d.add(qr)
|
||||
renderPDF.draw(d, canvas, x, y)
|
||||
```
|
||||
|
||||
### QR Code - More Options
|
||||
|
||||
```python
|
||||
# URL QR Code
|
||||
qr = QrCodeWidget("https://example.com")
|
||||
|
||||
# Contact information (vCard)
|
||||
vcard_data = """BEGIN:VCARD
|
||||
VERSION:3.0
|
||||
FN:John Doe
|
||||
TEL:+1-555-1234
|
||||
EMAIL:john@example.com
|
||||
END:VCARD"""
|
||||
qr = QrCodeWidget(vcard_data)
|
||||
|
||||
# WiFi credentials
|
||||
wifi_data = "WIFI:T:WPA;S:NetworkName;P:Password;;"
|
||||
qr = QrCodeWidget(wifi_data)
|
||||
|
||||
# Plain text
|
||||
qr = QrCodeWidget("Any text here")
|
||||
```
|
||||
|
||||
### Data Matrix (ECC200)
|
||||
|
||||
Compact 2D barcode for small items.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode.datamatrix import DataMatrixWidget
|
||||
|
||||
barcode = DataMatrixWidget(
|
||||
value="DATA123"
|
||||
)
|
||||
|
||||
d = Drawing()
|
||||
d.add(barcode)
|
||||
renderPDF.draw(d, canvas, x, y)
|
||||
```
|
||||
|
||||
### Postal Barcodes
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import usps
|
||||
|
||||
# POSTNET (older format)
|
||||
barcode = usps.POSTNET(
|
||||
value="55555-1234", # ZIP or ZIP+4
|
||||
)
|
||||
|
||||
# USPS 4-State (newer)
|
||||
barcode = usps.USPS_4State(
|
||||
value="12345678901234567890", # 20-digit routing code
|
||||
routing="12345678901"
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
### FIM (Facing Identification Mark)
|
||||
|
||||
Used for mail sorting.
|
||||
|
||||
```python
|
||||
from reportlab.graphics.barcode import usps
|
||||
|
||||
# FIM-A, FIM-B, FIM-C, or FIM-D
|
||||
barcode = usps.FIM(
|
||||
value="A" # A, B, C, or D
|
||||
)
|
||||
|
||||
barcode.drawOn(canvas, x, y)
|
||||
```
|
||||
|
||||
## Using Barcodes with Platypus
|
||||
|
||||
For flowing documents, wrap barcodes in Flowables.
|
||||
|
||||
### Simple Approach - Drawing Flowable
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.barcode.qr import QrCodeWidget
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Create drawing
|
||||
d = Drawing(2*inch, 2*inch)
|
||||
|
||||
# Create barcode
|
||||
qr = QrCodeWidget("https://example.com")
|
||||
qr.barWidth = 2*inch
|
||||
qr.barHeight = 2*inch
|
||||
qr.x = 0
|
||||
qr.y = 0
|
||||
|
||||
d.add(qr)
|
||||
|
||||
# Add to story
|
||||
story.append(d)
|
||||
```
|
||||
|
||||
### Custom Flowable Wrapper
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Flowable
|
||||
from reportlab.graphics.barcode import code128
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
class BarcodeFlowable(Flowable):
|
||||
def __init__(self, code, barcode_type='code128', width=2*inch, height=0.5*inch):
|
||||
Flowable.__init__(self)
|
||||
self.code = code
|
||||
self.barcode_type = barcode_type
|
||||
self.width_val = width
|
||||
self.height_val = height
|
||||
|
||||
# Create barcode
|
||||
if barcode_type == 'code128':
|
||||
self.barcode = code128.Code128(code, barWidth=width/100, barHeight=height)
|
||||
# Add other types as needed
|
||||
|
||||
def draw(self):
|
||||
self.barcode.drawOn(self.canv, 0, 0)
|
||||
|
||||
def wrap(self, availWidth, availHeight):
|
||||
return (self.barcode.width, self.barcode.height)
|
||||
|
||||
# Use in story
|
||||
story.append(BarcodeFlowable("PRODUCT123"))
|
||||
```
|
||||
|
||||
## Complete Examples
|
||||
|
||||
### Product Label with Barcode
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.graphics.barcode import code128
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
def create_product_label(filename, product_code, product_name):
|
||||
c = canvas.Canvas(filename, pagesize=(4*inch, 2*inch))
|
||||
|
||||
# Product name
|
||||
c.setFont("Helvetica-Bold", 14)
|
||||
c.drawCentredString(2*inch, 1.5*inch, product_name)
|
||||
|
||||
# Barcode
|
||||
barcode = code128.Code128(product_code)
|
||||
barcode_width = barcode.width
|
||||
barcode_height = barcode.height
|
||||
|
||||
# Center barcode
|
||||
x = (4*inch - barcode_width) / 2
|
||||
y = 0.5*inch
|
||||
|
||||
barcode.drawOn(c, x, y)
|
||||
|
||||
# Code text
|
||||
c.setFont("Courier", 10)
|
||||
c.drawCentredString(2*inch, 0.3*inch, product_code)
|
||||
|
||||
c.save()
|
||||
|
||||
create_product_label("label.pdf", "ABC123456789", "Premium Widget")
|
||||
```
|
||||
|
||||
### QR Code Contact Card
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.graphics.barcode.qr import QrCodeWidget
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics import renderPDF
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
def create_contact_card(filename, name, phone, email):
|
||||
c = canvas.Canvas(filename, pagesize=(3.5*inch, 2*inch))
|
||||
|
||||
# Contact info
|
||||
c.setFont("Helvetica-Bold", 12)
|
||||
c.drawString(0.5*inch, 1.5*inch, name)
|
||||
c.setFont("Helvetica", 10)
|
||||
c.drawString(0.5*inch, 1.3*inch, phone)
|
||||
c.drawString(0.5*inch, 1.1*inch, email)
|
||||
|
||||
# Create vCard data
|
||||
vcard = f"""BEGIN:VCARD
|
||||
VERSION:3.0
|
||||
FN:{name}
|
||||
TEL:{phone}
|
||||
EMAIL:{email}
|
||||
END:VCARD"""
|
||||
|
||||
# QR code
|
||||
qr = QrCodeWidget(vcard)
|
||||
qr.barWidth = 1.5*inch
|
||||
qr.barHeight = 1.5*inch
|
||||
|
||||
d = Drawing()
|
||||
d.add(qr)
|
||||
|
||||
renderPDF.draw(d, c, 1.8*inch, 0.2*inch)
|
||||
|
||||
c.save()
|
||||
|
||||
create_contact_card("contact.pdf", "John Doe", "+1-555-1234", "john@example.com")
|
||||
```
|
||||
|
||||
### Shipping Label with Multiple Barcodes
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.graphics.barcode import code128
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
def create_shipping_label(filename, tracking_code, zip_code):
|
||||
c = canvas.Canvas(filename, pagesize=(6*inch, 4*inch))
|
||||
|
||||
# Title
|
||||
c.setFont("Helvetica-Bold", 16)
|
||||
c.drawString(0.5*inch, 3.5*inch, "SHIPPING LABEL")
|
||||
|
||||
# Tracking barcode
|
||||
c.setFont("Helvetica", 10)
|
||||
c.drawString(0.5*inch, 2.8*inch, "Tracking Number:")
|
||||
|
||||
tracking_barcode = code128.Code128(tracking_code, barHeight=0.5*inch)
|
||||
tracking_barcode.drawOn(c, 0.5*inch, 2*inch)
|
||||
|
||||
c.setFont("Courier", 9)
|
||||
c.drawString(0.5*inch, 1.8*inch, tracking_code)
|
||||
|
||||
# Additional info can be added
|
||||
|
||||
c.save()
|
||||
|
||||
create_shipping_label("shipping.pdf", "1Z999AA10123456784", "12345")
|
||||
```
|
||||
|
||||
## Barcode Selection Guide
|
||||
|
||||
**Choose Code128 when:**
|
||||
- General purpose encoding
|
||||
- Need to encode numbers and letters
|
||||
- Want compact size
|
||||
- Widely supported
|
||||
|
||||
**Choose Code39 when:**
|
||||
- Older systems require it
|
||||
- Don't need lowercase letters
|
||||
- Want maximum compatibility
|
||||
|
||||
**Choose QR Code when:**
|
||||
- Need to encode URLs
|
||||
- Want mobile device scanning
|
||||
- Need high data capacity
|
||||
- Want error correction
|
||||
|
||||
**Choose EAN/UPC when:**
|
||||
- Retail product identification
|
||||
- Need industry-standard format
|
||||
- Global distribution
|
||||
|
||||
**Choose Data Matrix when:**
|
||||
- Very limited space
|
||||
- Small items (PCB, electronics)
|
||||
- Need 2D compact format
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test scanning** early with actual barcode scanners/readers
|
||||
2. **Add quiet zones** (white space) around barcodes - set `quiet=1`
|
||||
3. **Choose appropriate height** - taller barcodes are easier to scan
|
||||
4. **Include human-readable text** below barcode for manual entry
|
||||
5. **Use Code128** as default for general purpose - it's compact and versatile
|
||||
6. **For URLs, use QR codes** - much easier for mobile users
|
||||
7. **Check barcode standards** for your industry (retail uses EAN/UPC)
|
||||
8. **Test print quality** - low DPI can make barcodes unscannable
|
||||
9. **Validate data** before encoding - wrong check digits cause issues
|
||||
10. **Consider error correction** for QR codes - use 'M' or 'H' for important data
|
||||
@@ -1,241 +0,0 @@
|
||||
# Canvas API Reference
|
||||
|
||||
The Canvas API provides low-level, precise control over PDF generation using coordinate-based drawing.
|
||||
|
||||
## Coordinate System
|
||||
|
||||
- Origin (0, 0) is at the **lower-left corner** (not top-left like web graphics)
|
||||
- X-axis points right, Y-axis points upward
|
||||
- Units are in points (72 points = 1 inch)
|
||||
- Default page size is A4; explicitly specify page size for consistency
|
||||
|
||||
## Basic Setup
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.lib.pagesizes import letter, A4
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Create canvas
|
||||
c = canvas.Canvas("output.pdf", pagesize=letter)
|
||||
|
||||
# Get page dimensions
|
||||
width, height = letter
|
||||
|
||||
# Draw content
|
||||
c.drawString(100, 100, "Hello World")
|
||||
|
||||
# Finish page and save
|
||||
c.showPage() # Complete current page
|
||||
c.save() # Write PDF to disk
|
||||
```
|
||||
|
||||
## Text Drawing
|
||||
|
||||
### Basic String Methods
|
||||
```python
|
||||
# Basic text placement
|
||||
c.drawString(x, y, text) # Left-aligned at x, y
|
||||
c.drawRightString(x, y, text) # Right-aligned at x, y
|
||||
c.drawCentredString(x, y, text) # Center-aligned at x, y
|
||||
|
||||
# Font control
|
||||
c.setFont(fontname, size) # e.g., "Helvetica", 12
|
||||
c.setFillColor(color) # Text color
|
||||
```
|
||||
|
||||
### Text Objects (Advanced)
|
||||
For complex text operations with multiple lines and precise control:
|
||||
|
||||
```python
|
||||
t = c.beginText(x, y)
|
||||
t.setFont("Times-Roman", 14)
|
||||
t.textLine("First line")
|
||||
t.textLine("Second line")
|
||||
t.setTextOrigin(x, y) # Reset position
|
||||
c.drawText(t)
|
||||
```
|
||||
|
||||
## Drawing Primitives
|
||||
|
||||
### Lines
|
||||
```python
|
||||
c.line(x1, y1, x2, y2) # Single line
|
||||
c.lines([(x1,y1,x2,y2), (x3,y3,x4,y4)]) # Multiple lines
|
||||
c.grid(xlist, ylist) # Grid from coordinate lists
|
||||
```
|
||||
|
||||
### Shapes
|
||||
```python
|
||||
c.rect(x, y, width, height, stroke=1, fill=0)
|
||||
c.roundRect(x, y, width, height, radius, stroke=1, fill=0)
|
||||
c.circle(x_ctr, y_ctr, r, stroke=1, fill=0)
|
||||
c.ellipse(x1, y1, x2, y2, stroke=1, fill=0)
|
||||
c.wedge(x, y, radius, startAng, extent, stroke=1, fill=0)
|
||||
```
|
||||
|
||||
### Bezier Curves
|
||||
```python
|
||||
c.bezier(x1, y1, x2, y2, x3, y3, x4, y4)
|
||||
```
|
||||
|
||||
## Path Objects
|
||||
|
||||
For complex shapes, use path objects:
|
||||
|
||||
```python
|
||||
p = c.beginPath()
|
||||
p.moveTo(x, y) # Move without drawing
|
||||
p.lineTo(x, y) # Draw line to point
|
||||
p.curveTo(x1, y1, x2, y2, x3, y3) # Bezier curve
|
||||
p.arc(x1, y1, x2, y2, startAng, extent)
|
||||
p.arcTo(x1, y1, x2, y2, startAng, extent)
|
||||
p.close() # Close path to start point
|
||||
|
||||
# Draw the path
|
||||
c.drawPath(p, stroke=1, fill=0)
|
||||
```
|
||||
|
||||
## Colors
|
||||
|
||||
### RGB (Screen Display)
|
||||
```python
|
||||
from reportlab.lib.colors import red, blue, Color
|
||||
|
||||
c.setFillColorRGB(r, g, b) # r, g, b are 0-1
|
||||
c.setStrokeColorRGB(r, g, b)
|
||||
c.setFillColor(red) # Named colors
|
||||
c.setStrokeColor(blue)
|
||||
|
||||
# Custom with alpha transparency
|
||||
c.setFillColor(Color(0.5, 0, 0, alpha=0.5))
|
||||
```
|
||||
|
||||
### CMYK (Professional Printing)
|
||||
```python
|
||||
from reportlab.lib.colors import CMYKColor, PCMYKColor
|
||||
|
||||
c.setFillColorCMYK(c, m, y, k) # 0-1 range
|
||||
c.setStrokeColorCMYK(c, m, y, k)
|
||||
|
||||
# Integer percentages (0-100)
|
||||
c.setFillColor(PCMYKColor(100, 50, 0, 0))
|
||||
```
|
||||
|
||||
## Line Styling
|
||||
|
||||
```python
|
||||
c.setLineWidth(width) # Thickness in points
|
||||
c.setLineCap(mode) # 0=butt, 1=round, 2=square
|
||||
c.setLineJoin(mode) # 0=miter, 1=round, 2=bevel
|
||||
c.setDash(array, phase) # e.g., [3, 3] for dotted line
|
||||
```
|
||||
|
||||
## Coordinate Transformations
|
||||
|
||||
**IMPORTANT:** Transformations are incremental and cumulative.
|
||||
|
||||
```python
|
||||
# Translation (move origin)
|
||||
c.translate(dx, dy)
|
||||
|
||||
# Rotation (in degrees, counterclockwise)
|
||||
c.rotate(theta)
|
||||
|
||||
# Scaling
|
||||
c.scale(xscale, yscale)
|
||||
|
||||
# Skewing
|
||||
c.skew(alpha, beta)
|
||||
```
|
||||
|
||||
### State Management
|
||||
```python
|
||||
# Save current graphics state
|
||||
c.saveState()
|
||||
|
||||
# ... apply transformations and draw ...
|
||||
|
||||
# Restore previous state
|
||||
c.restoreState()
|
||||
```
|
||||
|
||||
**Note:** State cannot be preserved across `showPage()` calls.
|
||||
|
||||
## Images
|
||||
|
||||
```python
|
||||
from reportlab.lib.utils import ImageReader
|
||||
|
||||
# Preferred method (with caching)
|
||||
c.drawImage(image_source, x, y, width=None, height=None,
|
||||
mask=None, preserveAspectRatio=False)
|
||||
|
||||
# image_source can be:
|
||||
# - Filename string
|
||||
# - PIL Image object
|
||||
# - ImageReader object
|
||||
|
||||
# For transparency, specify RGB mask range
|
||||
c.drawImage("logo.png", 100, 500, mask=[255, 255, 255, 255, 255, 255])
|
||||
|
||||
# Inline (inefficient, no caching)
|
||||
c.drawInlineImage(image_source, x, y, width=None, height=None)
|
||||
```
|
||||
|
||||
## Page Management
|
||||
|
||||
```python
|
||||
# Complete current page
|
||||
c.showPage()
|
||||
|
||||
# Set page size for next page
|
||||
c.setPageSize(size) # e.g., letter, A4
|
||||
|
||||
# Page compression (smaller files, slower generation)
|
||||
c = canvas.Canvas("output.pdf", pageCompression=1)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Margins and Layout
|
||||
```python
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib.pagesizes import letter
|
||||
|
||||
width, height = letter
|
||||
margin = inch
|
||||
|
||||
# Draw within margins
|
||||
content_width = width - 2*margin
|
||||
content_height = height - 2*margin
|
||||
|
||||
# Text at top margin
|
||||
c.drawString(margin, height - margin, "Header")
|
||||
|
||||
# Text at bottom margin
|
||||
c.drawString(margin, margin, "Footer")
|
||||
```
|
||||
|
||||
### Headers and Footers
|
||||
```python
|
||||
def draw_header_footer(c, width, height):
|
||||
c.saveState()
|
||||
c.setFont("Helvetica", 9)
|
||||
c.drawString(inch, height - 0.5*inch, "Company Name")
|
||||
c.drawRightString(width - inch, 0.5*inch, f"Page {c.getPageNumber()}")
|
||||
c.restoreState()
|
||||
|
||||
# Call on each page
|
||||
draw_header_footer(c, width, height)
|
||||
c.showPage()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always specify page size** - Different platforms have different defaults
|
||||
2. **Use variables for measurements** - `margin = inch` instead of hardcoded values
|
||||
3. **Match saveState/restoreState** - Always balance these calls
|
||||
4. **Apply transformations externally** for engineering drawings to prevent line width scaling
|
||||
5. **Use drawImage over drawInlineImage** for better performance with repeated images
|
||||
6. **Draw from bottom-up** - Remember Y-axis points upward
|
||||
@@ -1,624 +0,0 @@
|
||||
# Charts and Graphics Reference
|
||||
|
||||
Comprehensive guide to creating charts and data visualizations in ReportLab.
|
||||
|
||||
## Graphics Architecture
|
||||
|
||||
ReportLab's graphics system provides platform-independent drawing:
|
||||
|
||||
- **Drawings** - Container for shapes and charts
|
||||
- **Shapes** - Primitives (rectangles, circles, lines, polygons, paths)
|
||||
- **Renderers** - Convert to PDF, PostScript, SVG, or bitmaps (PNG, GIF, JPG)
|
||||
- **Coordinate System** - Y-axis points upward (like PDF, unlike web graphics)
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.charts.barcharts import VerticalBarChart
|
||||
from reportlab.graphics import renderPDF
|
||||
|
||||
# Create drawing (canvas for chart)
|
||||
drawing = Drawing(400, 200)
|
||||
|
||||
# Create chart
|
||||
chart = VerticalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 125
|
||||
chart.data = [[100, 150, 130, 180]]
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Add chart to drawing
|
||||
drawing.add(chart)
|
||||
|
||||
# Render to PDF
|
||||
renderPDF.drawToFile(drawing, 'chart.pdf', 'Chart Title')
|
||||
|
||||
# Or add as flowable to Platypus document
|
||||
story.append(drawing)
|
||||
```
|
||||
|
||||
## Available Chart Types
|
||||
|
||||
### Bar Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.barcharts import (
|
||||
VerticalBarChart,
|
||||
HorizontalBarChart,
|
||||
)
|
||||
|
||||
# Vertical bar chart
|
||||
chart = VerticalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Single series
|
||||
chart.data = [[100, 150, 130, 180, 140]]
|
||||
|
||||
# Multiple series (grouped bars)
|
||||
chart.data = [
|
||||
[100, 150, 130, 180], # Series 1
|
||||
[80, 120, 110, 160], # Series 2
|
||||
]
|
||||
|
||||
# Categories
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Colors for each series
|
||||
chart.bars[0].fillColor = colors.blue
|
||||
chart.bars[1].fillColor = colors.red
|
||||
|
||||
# Bar spacing
|
||||
chart.barWidth = 10
|
||||
chart.groupSpacing = 10
|
||||
chart.barSpacing = 2
|
||||
```
|
||||
|
||||
### Stacked Bar Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.barcharts import VerticalBarChart
|
||||
|
||||
chart = VerticalBarChart()
|
||||
# ... set position and size ...
|
||||
|
||||
chart.data = [
|
||||
[100, 150, 130, 180], # Bottom layer
|
||||
[50, 70, 60, 90], # Top layer
|
||||
]
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Enable stacking
|
||||
chart.barLabelFormat = 'values'
|
||||
chart.valueAxis.visible = 1
|
||||
```
|
||||
|
||||
### Horizontal Bar Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.barcharts import HorizontalBarChart
|
||||
|
||||
chart = HorizontalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
chart.data = [[100, 150, 130, 180]]
|
||||
chart.categoryAxis.categoryNames = ['Product A', 'Product B', 'Product C', 'Product D']
|
||||
|
||||
# Horizontal charts use valueAxis horizontally
|
||||
chart.valueAxis.valueMin = 0
|
||||
chart.valueAxis.valueMax = 200
|
||||
```
|
||||
|
||||
### Line Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.linecharts import HorizontalLineChart
|
||||
|
||||
chart = HorizontalLineChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Multiple lines
|
||||
chart.data = [
|
||||
[100, 150, 130, 180, 140], # Line 1
|
||||
[80, 120, 110, 160, 130], # Line 2
|
||||
]
|
||||
|
||||
chart.categoryAxis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr', 'May']
|
||||
|
||||
# Line styling
|
||||
chart.lines[0].strokeColor = colors.blue
|
||||
chart.lines[0].strokeWidth = 2
|
||||
chart.lines[1].strokeColor = colors.red
|
||||
chart.lines[1].strokeWidth = 2
|
||||
|
||||
# Show/hide points
|
||||
chart.lines[0].symbol = None # No symbols
|
||||
# Or use symbols from makeMarker()
|
||||
```
|
||||
|
||||
### Line Plots (X-Y Plots)
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.lineplots import LinePlot
|
||||
|
||||
chart = LinePlot()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Data as (x, y) tuples
|
||||
chart.data = [
|
||||
[(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)], # y = x^2
|
||||
[(0, 0), (1, 2), (2, 4), (3, 6), (4, 8)], # y = 2x
|
||||
]
|
||||
|
||||
# Both axes are value axes (not category)
|
||||
chart.xValueAxis.valueMin = 0
|
||||
chart.xValueAxis.valueMax = 5
|
||||
chart.yValueAxis.valueMin = 0
|
||||
chart.yValueAxis.valueMax = 20
|
||||
|
||||
# Line styling
|
||||
chart.lines[0].strokeColor = colors.blue
|
||||
chart.lines[1].strokeColor = colors.red
|
||||
```
|
||||
|
||||
### Pie Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.piecharts import Pie
|
||||
|
||||
chart = Pie()
|
||||
chart.x = 100
|
||||
chart.y = 50
|
||||
chart.width = 200
|
||||
chart.height = 200
|
||||
|
||||
chart.data = [25, 35, 20, 20]
|
||||
chart.labels = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Slice colors
|
||||
chart.slices[0].fillColor = colors.blue
|
||||
chart.slices[1].fillColor = colors.red
|
||||
chart.slices[2].fillColor = colors.green
|
||||
chart.slices[3].fillColor = colors.yellow
|
||||
|
||||
# Pop out a slice
|
||||
chart.slices[1].popout = 10
|
||||
|
||||
# Label positioning
|
||||
chart.slices.strokeColor = colors.white
|
||||
chart.slices.strokeWidth = 2
|
||||
```
|
||||
|
||||
### Pie Chart with Side Labels
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.piecharts import Pie
|
||||
|
||||
chart = Pie()
|
||||
# ... set position, data, labels ...
|
||||
|
||||
# Side label mode (labels in columns beside pie)
|
||||
chart.sideLabels = 1
|
||||
chart.sideLabelsOffset = 0.1 # Distance from pie
|
||||
|
||||
# Simple labels (not fancy layout)
|
||||
chart.simpleLabels = 1
|
||||
```
|
||||
|
||||
### Area Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.areacharts import HorizontalAreaChart
|
||||
|
||||
chart = HorizontalAreaChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Areas stack on top of each other
|
||||
chart.data = [
|
||||
[100, 150, 130, 180], # Bottom area
|
||||
[50, 70, 60, 90], # Top area
|
||||
]
|
||||
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Area colors
|
||||
chart.strands[0].fillColor = colors.lightblue
|
||||
chart.strands[1].fillColor = colors.pink
|
||||
```
|
||||
|
||||
### Scatter Charts
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.lineplots import ScatterPlot
|
||||
|
||||
chart = ScatterPlot()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Data points
|
||||
chart.data = [
|
||||
[(1, 2), (2, 3), (3, 5), (4, 4), (5, 6)], # Series 1
|
||||
[(1, 1), (2, 2), (3, 3), (4, 3), (5, 4)], # Series 2
|
||||
]
|
||||
|
||||
# Hide lines, show points only
|
||||
chart.lines[0].strokeColor = None
|
||||
chart.lines[1].strokeColor = None
|
||||
|
||||
# Marker symbols
|
||||
from reportlab.graphics.widgets.markers import makeMarker
|
||||
chart.lines[0].symbol = makeMarker('Circle')
|
||||
chart.lines[1].symbol = makeMarker('Square')
|
||||
```
|
||||
|
||||
## Axes Configuration
|
||||
|
||||
### Category Axis (XCategoryAxis)
|
||||
|
||||
For categorical data (labels, not numbers):
|
||||
|
||||
```python
|
||||
# Access via chart
|
||||
axis = chart.categoryAxis
|
||||
|
||||
# Labels
|
||||
axis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr']
|
||||
|
||||
# Label angle (for long labels)
|
||||
axis.labels.angle = 45
|
||||
axis.labels.dx = 0
|
||||
axis.labels.dy = -5
|
||||
|
||||
# Label formatting
|
||||
axis.labels.fontSize = 10
|
||||
axis.labels.fontName = 'Helvetica'
|
||||
|
||||
# Visibility
|
||||
axis.visible = 1
|
||||
```
|
||||
|
||||
### Value Axis (YValueAxis)
|
||||
|
||||
For numeric data:
|
||||
|
||||
```python
|
||||
# Access via chart
|
||||
axis = chart.valueAxis
|
||||
|
||||
# Range
|
||||
axis.valueMin = 0
|
||||
axis.valueMax = 200
|
||||
axis.valueStep = 50 # Tick interval
|
||||
|
||||
# Or auto-configure
|
||||
axis.valueSteps = [0, 50, 100, 150, 200] # Explicit steps
|
||||
|
||||
# Label formatting
|
||||
axis.labels.fontSize = 10
|
||||
axis.labelTextFormat = '%d%%' # Add percentage sign
|
||||
|
||||
# Grid lines
|
||||
axis.strokeWidth = 1
|
||||
axis.strokeColor = colors.black
|
||||
```
|
||||
|
||||
## Styling and Customization
|
||||
|
||||
### Colors
|
||||
|
||||
```python
|
||||
from reportlab.lib import colors
|
||||
|
||||
# Named colors
|
||||
colors.blue, colors.red, colors.green, colors.yellow
|
||||
|
||||
# RGB
|
||||
colors.Color(0.5, 0.5, 0.5) # Grey
|
||||
|
||||
# With alpha
|
||||
colors.Color(1, 0, 0, alpha=0.5) # Semi-transparent red
|
||||
|
||||
# Hex colors
|
||||
colors.HexColor('#FF5733')
|
||||
```
|
||||
|
||||
### Line Styling
|
||||
|
||||
```python
|
||||
# For line charts
|
||||
chart.lines[0].strokeColor = colors.blue
|
||||
chart.lines[0].strokeWidth = 2
|
||||
chart.lines[0].strokeDashArray = [2, 2] # Dashed line
|
||||
```
|
||||
|
||||
### Bar Labels
|
||||
|
||||
```python
|
||||
# Show values on bars
|
||||
chart.barLabels.nudge = 5 # Offset from bar top
|
||||
chart.barLabels.fontSize = 8
|
||||
chart.barLabelFormat = '%d' # Number format
|
||||
|
||||
# For negative values
|
||||
chart.barLabels.dy = -5 # Position below bar
|
||||
```
|
||||
|
||||
## Legends
|
||||
|
||||
Charts can have associated legends:
|
||||
|
||||
```python
|
||||
from reportlab.graphics.charts.legends import Legend
|
||||
|
||||
# Create legend
|
||||
legend = Legend()
|
||||
legend.x = 350
|
||||
legend.y = 150
|
||||
legend.columnMaximum = 10
|
||||
|
||||
# Link to chart (share colors)
|
||||
legend.colorNamePairs = [
|
||||
(chart.bars[0].fillColor, 'Series 1'),
|
||||
(chart.bars[1].fillColor, 'Series 2'),
|
||||
]
|
||||
|
||||
# Add to drawing
|
||||
drawing.add(legend)
|
||||
```
|
||||
|
||||
## Drawing Shapes
|
||||
|
||||
### Basic Shapes
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import (
|
||||
Drawing, Rect, Circle, Ellipse, Line, Polygon, String
|
||||
)
|
||||
from reportlab.lib import colors
|
||||
|
||||
drawing = Drawing(400, 200)
|
||||
|
||||
# Rectangle
|
||||
rect = Rect(50, 50, 100, 50)
|
||||
rect.fillColor = colors.blue
|
||||
rect.strokeColor = colors.black
|
||||
rect.strokeWidth = 1
|
||||
drawing.add(rect)
|
||||
|
||||
# Circle
|
||||
circle = Circle(200, 100, 30)
|
||||
circle.fillColor = colors.red
|
||||
drawing.add(circle)
|
||||
|
||||
# Line
|
||||
line = Line(50, 150, 350, 150)
|
||||
line.strokeColor = colors.black
|
||||
line.strokeWidth = 2
|
||||
drawing.add(line)
|
||||
|
||||
# Text
|
||||
text = String(50, 175, "Label Text")
|
||||
text.fontSize = 12
|
||||
text.fontName = 'Helvetica'
|
||||
drawing.add(text)
|
||||
```
|
||||
|
||||
### Paths (Complex Shapes)
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Path
|
||||
|
||||
path = Path()
|
||||
path.moveTo(50, 50)
|
||||
path.lineTo(100, 100)
|
||||
path.curveTo(120, 120, 140, 100, 150, 50)
|
||||
path.closePath()
|
||||
|
||||
path.fillColor = colors.lightblue
|
||||
path.strokeColor = colors.blue
|
||||
path.strokeWidth = 2
|
||||
|
||||
drawing.add(path)
|
||||
```
|
||||
|
||||
## Rendering Options
|
||||
|
||||
### Render to PDF
|
||||
|
||||
```python
|
||||
from reportlab.graphics import renderPDF
|
||||
|
||||
# Direct to file
|
||||
renderPDF.drawToFile(drawing, 'output.pdf', 'Chart Title')
|
||||
|
||||
# As flowable in Platypus
|
||||
story.append(drawing)
|
||||
```
|
||||
|
||||
### Render to Image
|
||||
|
||||
```python
|
||||
from reportlab.graphics import renderPM
|
||||
|
||||
# PNG
|
||||
renderPM.drawToFile(drawing, 'chart.png', fmt='PNG')
|
||||
|
||||
# GIF
|
||||
renderPM.drawToFile(drawing, 'chart.gif', fmt='GIF')
|
||||
|
||||
# JPG
|
||||
renderPM.drawToFile(drawing, 'chart.jpg', fmt='JPG')
|
||||
|
||||
# With specific DPI
|
||||
renderPM.drawToFile(drawing, 'chart.png', fmt='PNG', dpi=150)
|
||||
```
|
||||
|
||||
### Render to SVG
|
||||
|
||||
```python
|
||||
from reportlab.graphics import renderSVG
|
||||
|
||||
renderSVG.drawToFile(drawing, 'chart.svg')
|
||||
```
|
||||
|
||||
## Advanced Customization
|
||||
|
||||
### Inspect Properties
|
||||
|
||||
```python
|
||||
# List all properties
|
||||
print(chart.getProperties())
|
||||
|
||||
# Dump properties (for debugging)
|
||||
chart.dumpProperties()
|
||||
|
||||
# Set multiple properties
|
||||
chart.setProperties({
|
||||
'width': 400,
|
||||
'height': 200,
|
||||
'data': [[100, 150, 130]],
|
||||
})
|
||||
```
|
||||
|
||||
### Custom Colors for Series
|
||||
|
||||
```python
|
||||
# Define color scheme
|
||||
from reportlab.lib.colors import PCMYKColor
|
||||
|
||||
colors_list = [
|
||||
PCMYKColor(100, 67, 0, 23), # Blue
|
||||
PCMYKColor(0, 100, 100, 0), # Red
|
||||
PCMYKColor(66, 13, 0, 22), # Green
|
||||
]
|
||||
|
||||
# Apply to chart
|
||||
for i, color in enumerate(colors_list):
|
||||
chart.bars[i].fillColor = color
|
||||
```
|
||||
|
||||
## Complete Examples
|
||||
|
||||
### Sales Report Bar Chart
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.charts.barcharts import VerticalBarChart
|
||||
from reportlab.graphics.charts.legends import Legend
|
||||
from reportlab.lib import colors
|
||||
|
||||
drawing = Drawing(400, 250)
|
||||
|
||||
# Create chart
|
||||
chart = VerticalBarChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 300
|
||||
chart.height = 150
|
||||
|
||||
# Data
|
||||
chart.data = [
|
||||
[120, 150, 180, 200], # 2023
|
||||
[100, 130, 160, 190], # 2022
|
||||
]
|
||||
chart.categoryAxis.categoryNames = ['Q1', 'Q2', 'Q3', 'Q4']
|
||||
|
||||
# Styling
|
||||
chart.bars[0].fillColor = colors.HexColor('#3498db')
|
||||
chart.bars[1].fillColor = colors.HexColor('#e74c3c')
|
||||
chart.valueAxis.valueMin = 0
|
||||
chart.valueAxis.valueMax = 250
|
||||
chart.categoryAxis.labels.fontSize = 10
|
||||
chart.valueAxis.labels.fontSize = 10
|
||||
|
||||
# Add legend
|
||||
legend = Legend()
|
||||
legend.x = 325
|
||||
legend.y = 200
|
||||
legend.columnMaximum = 2
|
||||
legend.colorNamePairs = [
|
||||
(chart.bars[0].fillColor, '2023'),
|
||||
(chart.bars[1].fillColor, '2022'),
|
||||
]
|
||||
|
||||
drawing.add(chart)
|
||||
drawing.add(legend)
|
||||
|
||||
# Add to story or save
|
||||
story.append(drawing)
|
||||
```
|
||||
|
||||
### Multi-Line Trend Chart
|
||||
|
||||
```python
|
||||
from reportlab.graphics.shapes import Drawing
|
||||
from reportlab.graphics.charts.linecharts import HorizontalLineChart
|
||||
from reportlab.lib import colors
|
||||
|
||||
drawing = Drawing(400, 250)
|
||||
|
||||
chart = HorizontalLineChart()
|
||||
chart.x = 50
|
||||
chart.y = 50
|
||||
chart.width = 320
|
||||
chart.height = 170
|
||||
|
||||
# Data
|
||||
chart.data = [
|
||||
[10, 15, 12, 18, 20, 25], # Product A
|
||||
[8, 10, 14, 16, 18, 22], # Product B
|
||||
[12, 11, 13, 15, 17, 19], # Product C
|
||||
]
|
||||
|
||||
chart.categoryAxis.categoryNames = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
|
||||
|
||||
# Line styling
|
||||
chart.lines[0].strokeColor = colors.blue
|
||||
chart.lines[0].strokeWidth = 2
|
||||
chart.lines[1].strokeColor = colors.red
|
||||
chart.lines[1].strokeWidth = 2
|
||||
chart.lines[2].strokeColor = colors.green
|
||||
chart.lines[2].strokeWidth = 2
|
||||
|
||||
# Axes
|
||||
chart.valueAxis.valueMin = 0
|
||||
chart.valueAxis.valueMax = 30
|
||||
chart.categoryAxis.labels.angle = 0
|
||||
chart.categoryAxis.labels.fontSize = 9
|
||||
chart.valueAxis.labels.fontSize = 9
|
||||
|
||||
drawing.add(chart)
|
||||
story.append(drawing)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Set explicit dimensions** for Drawing to ensure consistent sizing
|
||||
2. **Position charts** with enough margin (x, y at least 30-50 from edge)
|
||||
3. **Use consistent color schemes** throughout document
|
||||
4. **Set valueMin and valueMax** explicitly for consistent scales
|
||||
5. **Test with realistic data** to ensure labels fit and don't overlap
|
||||
6. **Add legends** for multi-series charts
|
||||
7. **Angle category labels** if they're long (45° works well)
|
||||
8. **Keep it simple** - fewer data series are easier to read
|
||||
9. **Use appropriate chart types** - bars for comparisons, lines for trends, pies for proportions
|
||||
10. **Consider colorblind-friendly palettes** - avoid red/green combinations
|
||||
@@ -1,561 +0,0 @@
|
||||
# PDF Features Reference
|
||||
|
||||
Advanced PDF capabilities: links, bookmarks, forms, encryption, and metadata.
|
||||
|
||||
## Document Metadata
|
||||
|
||||
Set PDF document properties viewable in PDF readers.
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("output.pdf")
|
||||
|
||||
# Set metadata
|
||||
c.setAuthor("John Doe")
|
||||
c.setTitle("Annual Report 2024")
|
||||
c.setSubject("Financial Analysis")
|
||||
c.setKeywords("finance, annual, report, 2024")
|
||||
c.setCreator("MyApp v1.0")
|
||||
|
||||
# ... draw content ...
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
With Platypus:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import SimpleDocTemplate
|
||||
|
||||
doc = SimpleDocTemplate(
|
||||
"output.pdf",
|
||||
title="Annual Report 2024",
|
||||
author="John Doe",
|
||||
subject="Financial Analysis",
|
||||
)
|
||||
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Bookmarks and Destinations
|
||||
|
||||
Create internal navigation structure.
|
||||
|
||||
### Simple Bookmarks
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("output.pdf")
|
||||
|
||||
# Create bookmark for current page
|
||||
c.bookmarkPage("intro") # Internal key
|
||||
c.addOutlineEntry("Introduction", "intro", level=0)
|
||||
|
||||
c.showPage()
|
||||
|
||||
# Another bookmark
|
||||
c.bookmarkPage("chapter1")
|
||||
c.addOutlineEntry("Chapter 1", "chapter1", level=0)
|
||||
|
||||
# Sub-sections
|
||||
c.bookmarkPage("section1_1")
|
||||
c.addOutlineEntry("Section 1.1", "section1_1", level=1) # Nested
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Bookmark Levels
|
||||
|
||||
```python
|
||||
# Create hierarchical outline
|
||||
c.bookmarkPage("ch1")
|
||||
c.addOutlineEntry("Chapter 1", "ch1", level=0)
|
||||
|
||||
c.bookmarkPage("ch1_s1")
|
||||
c.addOutlineEntry("Section 1.1", "ch1_s1", level=1)
|
||||
|
||||
c.bookmarkPage("ch1_s1_1")
|
||||
c.addOutlineEntry("Subsection 1.1.1", "ch1_s1_1", level=2)
|
||||
|
||||
c.bookmarkPage("ch2")
|
||||
c.addOutlineEntry("Chapter 2", "ch2", level=0)
|
||||
```
|
||||
|
||||
### Destination Fit Modes
|
||||
|
||||
Control how the page displays when navigating:
|
||||
|
||||
```python
|
||||
# bookmarkPage with fit mode
|
||||
c.bookmarkPage(
|
||||
key="chapter1",
|
||||
fit="Fit" # Fit entire page in window
|
||||
)
|
||||
|
||||
# Or use bookmarkHorizontalAbsolute
|
||||
c.bookmarkHorizontalAbsolute(key="section", top=500)
|
||||
|
||||
# Available fit modes:
|
||||
# "Fit" - Fit whole page
|
||||
# "FitH" - Fit horizontally
|
||||
# "FitV" - Fit vertically
|
||||
# "FitR" - Fit rectangle
|
||||
# "XYZ" - Specific position and zoom
|
||||
```
|
||||
|
||||
## Hyperlinks
|
||||
|
||||
### External Links
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
c = canvas.Canvas("output.pdf")
|
||||
|
||||
# Draw link rectangle
|
||||
c.linkURL(
|
||||
"https://www.example.com",
|
||||
rect=(1*inch, 5*inch, 3*inch, 5.5*inch), # (x1, y1, x2, y2)
|
||||
relative=0, # 0 for absolute positioning
|
||||
thickness=1,
|
||||
color=(0, 0, 1), # Blue
|
||||
dashArray=None
|
||||
)
|
||||
|
||||
# Draw text over link area
|
||||
c.setFillColorRGB(0, 0, 1) # Blue text
|
||||
c.drawString(1*inch, 5.2*inch, "Click here to visit example.com")
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Internal Links
|
||||
|
||||
Link to bookmarked locations within the document:
|
||||
|
||||
```python
|
||||
# Create destination
|
||||
c.bookmarkPage("target_section")
|
||||
|
||||
# Later, create link to that destination
|
||||
c.linkRect(
|
||||
"Link Text",
|
||||
"target_section", # Bookmark key
|
||||
rect=(1*inch, 3*inch, 2*inch, 3.2*inch),
|
||||
relative=0
|
||||
)
|
||||
```
|
||||
|
||||
### Links in Paragraphs
|
||||
|
||||
For Platypus documents:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Paragraph
|
||||
|
||||
# External link
|
||||
text = '<link href="https://example.com" color="blue">Visit our website</link>'
|
||||
para = Paragraph(text, style)
|
||||
|
||||
# Internal link (to anchor)
|
||||
text = '<link href="#section1" color="blue">Go to Section 1</link>'
|
||||
para1 = Paragraph(text, style)
|
||||
|
||||
# Create anchor
|
||||
text = '<a name="section1"/>Section 1 Heading'
|
||||
para2 = Paragraph(text, heading_style)
|
||||
|
||||
story.append(para1)
|
||||
story.append(para2)
|
||||
```
|
||||
|
||||
## Interactive Forms
|
||||
|
||||
Create fillable PDF forms.
|
||||
|
||||
### Text Fields
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.pdfbase import pdfform
|
||||
from reportlab.lib.colors import black, white
|
||||
|
||||
c = canvas.Canvas("form.pdf")
|
||||
|
||||
# Create text field
|
||||
c.acroForm.textfield(
|
||||
name="name",
|
||||
tooltip="Enter your name",
|
||||
x=100,
|
||||
y=700,
|
||||
width=200,
|
||||
height=20,
|
||||
borderColor=black,
|
||||
fillColor=white,
|
||||
textColor=black,
|
||||
forceBorder=True,
|
||||
fontSize=12,
|
||||
maxlen=100, # Maximum character length
|
||||
)
|
||||
|
||||
# Label
|
||||
c.drawString(100, 725, "Name:")
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Checkboxes
|
||||
|
||||
```python
|
||||
# Create checkbox
|
||||
c.acroForm.checkbox(
|
||||
name="agree",
|
||||
tooltip="I agree to terms",
|
||||
x=100,
|
||||
y=650,
|
||||
size=20,
|
||||
buttonStyle='check', # 'check', 'circle', 'cross', 'diamond', 'square', 'star'
|
||||
borderColor=black,
|
||||
fillColor=white,
|
||||
textColor=black,
|
||||
forceBorder=True,
|
||||
checked=False, # Initial state
|
||||
)
|
||||
|
||||
c.drawString(130, 655, "I agree to the terms and conditions")
|
||||
```
|
||||
|
||||
### Radio Buttons
|
||||
|
||||
```python
|
||||
# Radio button group - only one can be selected
|
||||
c.acroForm.radio(
|
||||
name="payment", # Same name for group
|
||||
tooltip="Credit Card",
|
||||
value="credit", # Value when selected
|
||||
x=100,
|
||||
y=600,
|
||||
size=15,
|
||||
selected=False,
|
||||
)
|
||||
c.drawString(125, 603, "Credit Card")
|
||||
|
||||
c.acroForm.radio(
|
||||
name="payment", # Same name
|
||||
tooltip="PayPal",
|
||||
value="paypal",
|
||||
x=100,
|
||||
y=580,
|
||||
size=15,
|
||||
selected=False,
|
||||
)
|
||||
c.drawString(125, 583, "PayPal")
|
||||
```
|
||||
|
||||
### List Boxes
|
||||
|
||||
```python
|
||||
# Listbox with multiple options
|
||||
c.acroForm.listbox(
|
||||
name="country",
|
||||
tooltip="Select your country",
|
||||
value="US", # Default selected
|
||||
x=100,
|
||||
y=500,
|
||||
width=150,
|
||||
height=80,
|
||||
borderColor=black,
|
||||
fillColor=white,
|
||||
textColor=black,
|
||||
forceBorder=True,
|
||||
options=[
|
||||
("United States", "US"),
|
||||
("Canada", "CA"),
|
||||
("Mexico", "MX"),
|
||||
("Other", "OTHER"),
|
||||
], # List of (label, value) tuples
|
||||
multiple=False, # Allow multiple selections
|
||||
)
|
||||
```
|
||||
|
||||
### Choice (Dropdown)
|
||||
|
||||
```python
|
||||
# Dropdown menu
|
||||
c.acroForm.choice(
|
||||
name="state",
|
||||
tooltip="Select state",
|
||||
value="CA",
|
||||
x=100,
|
||||
y=450,
|
||||
width=150,
|
||||
height=20,
|
||||
borderColor=black,
|
||||
fillColor=white,
|
||||
textColor=black,
|
||||
forceBorder=True,
|
||||
options=[
|
||||
("California", "CA"),
|
||||
("New York", "NY"),
|
||||
("Texas", "TX"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Complete Form Example
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.lib.colors import black, white, lightgrey
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
def create_registration_form(filename):
|
||||
c = canvas.Canvas(filename, pagesize=letter)
|
||||
c.setFont("Helvetica-Bold", 16)
|
||||
c.drawString(inch, 10*inch, "Registration Form")
|
||||
|
||||
y = 9*inch
|
||||
c.setFont("Helvetica", 12)
|
||||
|
||||
# Name field
|
||||
c.drawString(inch, y, "Full Name:")
|
||||
c.acroForm.textfield(
|
||||
name="fullname",
|
||||
x=2*inch, y=y-5, width=4*inch, height=20,
|
||||
borderColor=black, fillColor=lightgrey, forceBorder=True
|
||||
)
|
||||
|
||||
# Email field
|
||||
y -= 0.5*inch
|
||||
c.drawString(inch, y, "Email:")
|
||||
c.acroForm.textfield(
|
||||
name="email",
|
||||
x=2*inch, y=y-5, width=4*inch, height=20,
|
||||
borderColor=black, fillColor=lightgrey, forceBorder=True
|
||||
)
|
||||
|
||||
# Age dropdown
|
||||
y -= 0.5*inch
|
||||
c.drawString(inch, y, "Age Group:")
|
||||
c.acroForm.choice(
|
||||
name="age_group",
|
||||
x=2*inch, y=y-5, width=2*inch, height=20,
|
||||
borderColor=black, fillColor=lightgrey, forceBorder=True,
|
||||
options=[("18-25", "18-25"), ("26-35", "26-35"),
|
||||
("36-50", "36-50"), ("51+", "51+")]
|
||||
)
|
||||
|
||||
# Newsletter checkbox
|
||||
y -= 0.5*inch
|
||||
c.acroForm.checkbox(
|
||||
name="newsletter",
|
||||
x=inch, y=y-5, size=15,
|
||||
buttonStyle='check', borderColor=black, forceBorder=True
|
||||
)
|
||||
c.drawString(inch + 25, y, "Subscribe to newsletter")
|
||||
|
||||
c.save()
|
||||
|
||||
create_registration_form("registration.pdf")
|
||||
```
|
||||
|
||||
## Encryption and Security
|
||||
|
||||
Protect PDFs with passwords and permissions.
|
||||
|
||||
### Basic Encryption
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("secure.pdf")
|
||||
|
||||
# Encrypt with user password
|
||||
c.encrypt(
|
||||
userPassword="user123", # Password to open
|
||||
ownerPassword="owner456", # Password to change permissions
|
||||
canPrint=1, # Allow printing
|
||||
canModify=0, # Disallow modifications
|
||||
canCopy=1, # Allow text copying
|
||||
canAnnotate=0, # Disallow annotations
|
||||
strength=128, # 40 or 128 bit encryption
|
||||
)
|
||||
|
||||
# ... draw content ...
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Permission Settings
|
||||
|
||||
```python
|
||||
c.encrypt(
|
||||
userPassword="user123",
|
||||
ownerPassword="owner456",
|
||||
canPrint=1, # 1 = allow, 0 = deny
|
||||
canModify=0, # Prevent content modification
|
||||
canCopy=1, # Allow text/graphics copying
|
||||
canAnnotate=0, # Prevent comments/annotations
|
||||
strength=128, # Use 128-bit encryption
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Encryption
|
||||
|
||||
```python
|
||||
from reportlab.lib.pdfencrypt import StandardEncryption
|
||||
|
||||
# Create encryption object
|
||||
encrypt = StandardEncryption(
|
||||
userPassword="user123",
|
||||
ownerPassword="owner456",
|
||||
canPrint=1,
|
||||
canModify=0,
|
||||
canCopy=1,
|
||||
canAnnotate=1,
|
||||
strength=128,
|
||||
)
|
||||
|
||||
# Use with canvas
|
||||
c = canvas.Canvas("secure.pdf")
|
||||
c._doc.encrypt = encrypt
|
||||
|
||||
# ... draw content ...
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
### Platypus with Encryption
|
||||
|
||||
```python
|
||||
from reportlab.platypus import SimpleDocTemplate
|
||||
|
||||
doc = SimpleDocTemplate("secure.pdf")
|
||||
|
||||
# Set encryption
|
||||
doc.encrypt = True
|
||||
doc.canPrint = 1
|
||||
doc.canModify = 0
|
||||
|
||||
# Or use encrypt() method
|
||||
doc.encrypt = encrypt_object
|
||||
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Page Transitions
|
||||
|
||||
Add visual effects for presentations.
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("presentation.pdf")
|
||||
|
||||
# Set transition for current page
|
||||
c.setPageTransition(
|
||||
effectname="Wipe", # Transition effect
|
||||
duration=1, # Duration in seconds
|
||||
direction=0 # Direction (effect-specific)
|
||||
)
|
||||
|
||||
# Available effects:
|
||||
# "Split", "Blinds", "Box", "Wipe", "Dissolve",
|
||||
# "Glitter", "R" (Replace), "Fly", "Push", "Cover",
|
||||
# "Uncover", "Fade"
|
||||
|
||||
# Direction values (effect-dependent):
|
||||
# 0, 90, 180, 270 for most directional effects
|
||||
|
||||
# Example: Slide with fade transition
|
||||
c.setFont("Helvetica-Bold", 24)
|
||||
c.drawString(100, 400, "Slide 1")
|
||||
c.setPageTransition("Fade", 0.5)
|
||||
c.showPage()
|
||||
|
||||
c.drawString(100, 400, "Slide 2")
|
||||
c.setPageTransition("Wipe", 1, 90)
|
||||
c.showPage()
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
## PDF/A Compliance
|
||||
|
||||
Create archival-quality PDFs.
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("pdfa.pdf")
|
||||
|
||||
# Enable PDF/A-1b compliance
|
||||
c.setPageCompression(0) # PDF/A requires uncompressed
|
||||
# Note: Full PDF/A requires additional XMP metadata
|
||||
# This is simplified - full compliance needs more setup
|
||||
|
||||
# ... draw content ...
|
||||
|
||||
c.save()
|
||||
```
|
||||
|
||||
## Compression
|
||||
|
||||
Control file size vs generation speed.
|
||||
|
||||
```python
|
||||
# Enable page compression
|
||||
c = canvas.Canvas("output.pdf", pageCompression=1)
|
||||
|
||||
# Compression reduces file size but slows generation
|
||||
# 0 = no compression (faster, larger files)
|
||||
# 1 = compression (slower, smaller files)
|
||||
```
|
||||
|
||||
## Forms and XObjects
|
||||
|
||||
Reusable graphics elements.
|
||||
|
||||
```python
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
c = canvas.Canvas("output.pdf")
|
||||
|
||||
# Begin form (reusable object)
|
||||
c.beginForm("logo")
|
||||
c.setFillColorRGB(0, 0, 1)
|
||||
c.rect(0, 0, 100, 50, fill=1)
|
||||
c.setFillColorRGB(1, 1, 1)
|
||||
c.drawString(10, 20, "LOGO")
|
||||
c.endForm()
|
||||
|
||||
# Use form multiple times
|
||||
c.doForm("logo") # At current position
|
||||
c.translate(200, 0)
|
||||
c.doForm("logo") # At translated position
|
||||
c.translate(200, 0)
|
||||
c.doForm("logo")
|
||||
|
||||
c.save()
|
||||
|
||||
# Benefits: Smaller file size, faster rendering
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always set metadata** for professional documents
|
||||
2. **Use bookmarks** for documents > 10 pages
|
||||
3. **Make links visually distinct** (blue, underlined)
|
||||
4. **Test forms** in multiple PDF readers (behavior varies)
|
||||
5. **Use strong encryption (128-bit)** for sensitive data
|
||||
6. **Set both user and owner passwords** for full security
|
||||
7. **Enable printing** unless specifically restricted
|
||||
8. **Test page transitions** - some readers don't support all effects
|
||||
9. **Use meaningful bookmark titles** for navigation
|
||||
10. **Consider PDF/A** for long-term archival needs
|
||||
11. **Validate form field names** - must be unique and valid identifiers
|
||||
12. **Add tooltips** to form fields for better UX
|
||||
@@ -1,343 +0,0 @@
|
||||
# Platypus Guide - High-Level Page Layout
|
||||
|
||||
Platypus ("Page Layout and Typography Using Scripts") provides high-level document layout for complex, flowing documents with minimal code.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
Platypus uses a layered design:
|
||||
|
||||
1. **DocTemplates** - Document container with page formatting rules
|
||||
2. **PageTemplates** - Specifications for different page layouts
|
||||
3. **Frames** - Regions where content flows
|
||||
4. **Flowables** - Content elements (paragraphs, tables, images, spacers)
|
||||
5. **Canvas** - Underlying rendering engine (usually hidden)
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak
|
||||
from reportlab.lib.styles import getSampleStyleSheet
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Create document
|
||||
doc = SimpleDocTemplate("output.pdf", pagesize=letter,
|
||||
rightMargin=72, leftMargin=72,
|
||||
topMargin=72, bottomMargin=18)
|
||||
|
||||
# Create story (list of flowables)
|
||||
story = []
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Add content
|
||||
story.append(Paragraph("Title", styles['Title']))
|
||||
story.append(Spacer(1, 0.2*inch))
|
||||
story.append(Paragraph("Body text here", styles['BodyText']))
|
||||
story.append(PageBreak())
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### DocTemplates
|
||||
|
||||
#### SimpleDocTemplate
|
||||
Most common template for standard documents:
|
||||
|
||||
```python
|
||||
doc = SimpleDocTemplate(
|
||||
filename,
|
||||
pagesize=letter,
|
||||
rightMargin=72, # 1 inch = 72 points
|
||||
leftMargin=72,
|
||||
topMargin=72,
|
||||
bottomMargin=18,
|
||||
title=None, # PDF metadata
|
||||
author=None,
|
||||
subject=None
|
||||
)
|
||||
```
|
||||
|
||||
#### BaseDocTemplate (Advanced)
|
||||
For complex documents with multiple page layouts:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame
|
||||
from reportlab.lib.pagesizes import letter
|
||||
|
||||
doc = BaseDocTemplate("output.pdf", pagesize=letter)
|
||||
|
||||
# Define frames (content regions)
|
||||
frame1 = Frame(doc.leftMargin, doc.bottomMargin,
|
||||
doc.width/2-6, doc.height, id='col1')
|
||||
frame2 = Frame(doc.leftMargin+doc.width/2+6, doc.bottomMargin,
|
||||
doc.width/2-6, doc.height, id='col2')
|
||||
|
||||
# Create page template
|
||||
template = PageTemplate(id='TwoCol', frames=[frame1, frame2])
|
||||
doc.addPageTemplates([template])
|
||||
|
||||
# Build with story
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
### Frames
|
||||
|
||||
Frames define regions where content flows:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Frame
|
||||
|
||||
frame = Frame(
|
||||
x1, y1, # Lower-left corner
|
||||
width, height, # Dimensions
|
||||
leftPadding=6, # Internal padding
|
||||
bottomPadding=6,
|
||||
rightPadding=6,
|
||||
topPadding=6,
|
||||
id=None, # Optional identifier
|
||||
showBoundary=0 # 1 to show frame border (debugging)
|
||||
)
|
||||
```
|
||||
|
||||
### PageTemplates
|
||||
|
||||
Define page layouts with frames and optional functions:
|
||||
|
||||
```python
|
||||
def header_footer(canvas, doc):
|
||||
"""Called on each page for headers/footers"""
|
||||
canvas.saveState()
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.drawString(inch, 0.75*inch, f"Page {doc.page}")
|
||||
canvas.restoreState()
|
||||
|
||||
template = PageTemplate(
|
||||
id='Normal',
|
||||
frames=[frame],
|
||||
onPage=header_footer, # Function called for each page
|
||||
onPageEnd=None,
|
||||
pagesize=letter
|
||||
)
|
||||
```
|
||||
|
||||
## Flowables
|
||||
|
||||
Flowables are content elements that flow through frames.
|
||||
|
||||
### Common Flowables
|
||||
|
||||
```python
|
||||
from reportlab.platypus import (
|
||||
Paragraph, Spacer, PageBreak, FrameBreak,
|
||||
Image, Table, KeepTogether, CondPageBreak
|
||||
)
|
||||
|
||||
# Spacer - vertical whitespace
|
||||
Spacer(width, height)
|
||||
|
||||
# Page break - force new page
|
||||
PageBreak()
|
||||
|
||||
# Frame break - move to next frame
|
||||
FrameBreak()
|
||||
|
||||
# Conditional page break - break if less than N space remaining
|
||||
CondPageBreak(height)
|
||||
|
||||
# Keep together - prevent splitting across pages
|
||||
KeepTogether([flowable1, flowable2, ...])
|
||||
```
|
||||
|
||||
### Paragraph Flowable
|
||||
See `text_and_fonts.md` for detailed Paragraph usage.
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Paragraph
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
|
||||
style = ParagraphStyle(
|
||||
'CustomStyle',
|
||||
fontSize=12,
|
||||
leading=14,
|
||||
alignment=0 # 0=left, 1=center, 2=right, 4=justify
|
||||
)
|
||||
|
||||
para = Paragraph("Text with <b>bold</b> and <i>italic</i>", style)
|
||||
story.append(para)
|
||||
```
|
||||
|
||||
### Image Flowable
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Image
|
||||
|
||||
# Auto-size to fit
|
||||
img = Image('photo.jpg')
|
||||
|
||||
# Fixed size
|
||||
img = Image('photo.jpg', width=2*inch, height=2*inch)
|
||||
|
||||
# Maintain aspect ratio with max width
|
||||
img = Image('photo.jpg', width=4*inch, height=3*inch,
|
||||
kind='proportional')
|
||||
|
||||
story.append(img)
|
||||
```
|
||||
|
||||
### Table Flowable
|
||||
See `tables_reference.md` for detailed Table usage.
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Table
|
||||
|
||||
data = [['Header1', 'Header2'],
|
||||
['Row1Col1', 'Row1Col2'],
|
||||
['Row2Col1', 'Row2Col2']]
|
||||
|
||||
table = Table(data, colWidths=[2*inch, 2*inch])
|
||||
story.append(table)
|
||||
```
|
||||
|
||||
## Page Layouts
|
||||
|
||||
### Single Column Document
|
||||
|
||||
```python
|
||||
doc = SimpleDocTemplate("output.pdf", pagesize=letter)
|
||||
story = []
|
||||
# Add flowables...
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
### Two-Column Layout
|
||||
|
||||
```python
|
||||
from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame
|
||||
|
||||
doc = BaseDocTemplate("output.pdf", pagesize=letter)
|
||||
width, height = letter
|
||||
margin = inch
|
||||
|
||||
# Two side-by-side frames
|
||||
frame1 = Frame(margin, margin, width/2 - 1.5*margin, height - 2*margin, id='col1')
|
||||
frame2 = Frame(width/2 + 0.5*margin, margin, width/2 - 1.5*margin, height - 2*margin, id='col2')
|
||||
|
||||
template = PageTemplate(id='TwoCol', frames=[frame1, frame2])
|
||||
doc.addPageTemplates([template])
|
||||
|
||||
story = []
|
||||
# Content flows left column first, then right column
|
||||
# Add flowables...
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
### Multiple Page Templates
|
||||
|
||||
```python
|
||||
from reportlab.platypus import NextPageTemplate
|
||||
|
||||
# Define templates
|
||||
cover_template = PageTemplate(id='Cover', frames=[cover_frame])
|
||||
body_template = PageTemplate(id='Body', frames=[body_frame])
|
||||
|
||||
doc.addPageTemplates([cover_template, body_template])
|
||||
|
||||
story = []
|
||||
# Cover page content
|
||||
story.append(Paragraph("Cover", title_style))
|
||||
story.append(NextPageTemplate('Body')) # Switch to body template
|
||||
story.append(PageBreak())
|
||||
|
||||
# Body content
|
||||
story.append(Paragraph("Chapter 1", heading_style))
|
||||
# ...
|
||||
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Headers and Footers
|
||||
|
||||
Headers and footers are added via `onPage` callback functions:
|
||||
|
||||
```python
|
||||
def header_footer(canvas, doc):
|
||||
"""Draw header and footer on each page"""
|
||||
canvas.saveState()
|
||||
|
||||
# Header
|
||||
canvas.setFont('Helvetica-Bold', 12)
|
||||
canvas.drawCentredString(letter[0]/2, letter[1] - 0.5*inch,
|
||||
"Document Title")
|
||||
|
||||
# Footer
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.drawString(inch, 0.75*inch, "Left Footer")
|
||||
canvas.drawRightString(letter[0] - inch, 0.75*inch,
|
||||
f"Page {doc.page}")
|
||||
|
||||
canvas.restoreState()
|
||||
|
||||
# Apply to template
|
||||
template = PageTemplate(id='Normal', frames=[frame], onPage=header_footer)
|
||||
```
|
||||
|
||||
## Table of Contents
|
||||
|
||||
```python
|
||||
from reportlab.platypus import TableOfContents
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
|
||||
# Create TOC
|
||||
toc = TableOfContents()
|
||||
toc.levelStyles = [
|
||||
ParagraphStyle(name='TOC1', fontSize=14, leftIndent=0),
|
||||
ParagraphStyle(name='TOC2', fontSize=12, leftIndent=20),
|
||||
]
|
||||
|
||||
story = []
|
||||
story.append(toc)
|
||||
story.append(PageBreak())
|
||||
|
||||
# Add entries
|
||||
story.append(Paragraph("Chapter 1<a name='ch1'/>", heading_style))
|
||||
toc.addEntry(0, "Chapter 1", doc.page, 'ch1')
|
||||
|
||||
# Must call build twice for TOC to populate
|
||||
doc.build(story)
|
||||
```
|
||||
|
||||
## Document Properties
|
||||
|
||||
```python
|
||||
from reportlab.lib.pagesizes import letter, A4
|
||||
from reportlab.lib.units import inch, cm, mm
|
||||
|
||||
# Page sizes
|
||||
letter # US Letter (8.5" x 11")
|
||||
A4 # ISO A4 (210mm x 297mm)
|
||||
landscape(letter) # Rotate to landscape
|
||||
|
||||
# Units
|
||||
inch # 72 points
|
||||
cm # 28.35 points
|
||||
mm # 2.835 points
|
||||
|
||||
# Custom page size
|
||||
custom_size = (6*inch, 9*inch)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use SimpleDocTemplate** for most documents - it handles common layouts
|
||||
2. **Build story list** completely before calling `doc.build(story)`
|
||||
3. **Use Spacer** for vertical spacing instead of empty Paragraphs
|
||||
4. **Group related content** with KeepTogether to prevent awkward page breaks
|
||||
5. **Test page breaks** early with realistic content amounts
|
||||
6. **Use styles consistently** - create style once, reuse throughout document
|
||||
7. **Set showBoundary=1** on Frames during development to visualize layout
|
||||
8. **Headers/footers go in onPage** callback, not in story
|
||||
9. **For long documents**, use BaseDocTemplate with multiple page templates
|
||||
10. **Build TOC documents twice** to properly populate table of contents
|
||||
@@ -1,442 +0,0 @@
|
||||
# Tables Reference
|
||||
|
||||
Comprehensive guide to creating and styling tables in ReportLab.
|
||||
|
||||
## Basic Table Creation
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Table, TableStyle
|
||||
from reportlab.lib import colors
|
||||
|
||||
# Simple data (list of lists or tuples)
|
||||
data = [
|
||||
['Header 1', 'Header 2', 'Header 3'],
|
||||
['Row 1, Col 1', 'Row 1, Col 2', 'Row 1, Col 3'],
|
||||
['Row 2, Col 1', 'Row 2, Col 2', 'Row 2, Col 3'],
|
||||
]
|
||||
|
||||
# Create table
|
||||
table = Table(data)
|
||||
|
||||
# Add to story
|
||||
story.append(table)
|
||||
```
|
||||
|
||||
## Table Constructor
|
||||
|
||||
```python
|
||||
table = Table(
|
||||
data, # Required: list of lists/tuples
|
||||
colWidths=None, # List of column widths or single value
|
||||
rowHeights=None, # List of row heights or single value
|
||||
style=None, # TableStyle object
|
||||
splitByRow=1, # Split across pages by rows (not columns)
|
||||
repeatRows=0, # Number of header rows to repeat
|
||||
repeatCols=0, # Number of header columns to repeat
|
||||
rowSplitRange=None, # Tuple (start, end) of splittable rows
|
||||
spaceBefore=None, # Space before table
|
||||
spaceAfter=None, # Space after table
|
||||
cornerRadii=None, # [TL, TR, BL, BR] for rounded corners
|
||||
)
|
||||
```
|
||||
|
||||
### Column Widths
|
||||
|
||||
```python
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Equal widths
|
||||
table = Table(data, colWidths=2*inch)
|
||||
|
||||
# Different widths per column
|
||||
table = Table(data, colWidths=[1.5*inch, 2*inch, 1*inch])
|
||||
|
||||
# Auto-calculate widths (default)
|
||||
table = Table(data)
|
||||
|
||||
# Percentage-based (of available width)
|
||||
table = Table(data, colWidths=[None, None, None]) # Equal auto-sizing
|
||||
```
|
||||
|
||||
## Cell Content Types
|
||||
|
||||
### Text and Newlines
|
||||
|
||||
```python
|
||||
# Newlines work in cells
|
||||
data = [
|
||||
['Line 1\nLine 2', 'Single line'],
|
||||
['Another\nmulti-line\ncell', 'Text'],
|
||||
]
|
||||
```
|
||||
|
||||
### Paragraph Objects
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Paragraph
|
||||
from reportlab.lib.styles import getSampleStyleSheet
|
||||
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
data = [
|
||||
[Paragraph("Formatted <b>bold</b> text", styles['Normal']),
|
||||
Paragraph("More <i>italic</i> text", styles['Normal'])],
|
||||
]
|
||||
|
||||
table = Table(data)
|
||||
```
|
||||
|
||||
### Images
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Image
|
||||
|
||||
data = [
|
||||
['Description', Image('logo.png', width=1*inch, height=1*inch)],
|
||||
['Product', Image('product.jpg', width=2*inch, height=1.5*inch)],
|
||||
]
|
||||
|
||||
table = Table(data)
|
||||
```
|
||||
|
||||
### Nested Tables
|
||||
|
||||
```python
|
||||
# Create inner table
|
||||
inner_data = [['A', 'B'], ['C', 'D']]
|
||||
inner_table = Table(inner_data)
|
||||
|
||||
# Use in outer table
|
||||
outer_data = [
|
||||
['Label', inner_table],
|
||||
['Other', 'Content'],
|
||||
]
|
||||
|
||||
outer_table = Table(outer_data)
|
||||
```
|
||||
|
||||
## TableStyle
|
||||
|
||||
Styles are applied using command lists:
|
||||
|
||||
```python
|
||||
from reportlab.platypus import TableStyle
|
||||
from reportlab.lib import colors
|
||||
|
||||
style = TableStyle([
|
||||
# Command format: ('COMMAND', (startcol, startrow), (endcol, endrow), *args)
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black), # Grid over all cells
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.grey), # Header background
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), # Header text color
|
||||
])
|
||||
|
||||
table = Table(data)
|
||||
table.setStyle(style)
|
||||
```
|
||||
|
||||
### Cell Coordinate System
|
||||
|
||||
- Columns and rows are 0-indexed: `(col, row)`
|
||||
- Negative indices count from end: `-1` is last column/row
|
||||
- `(0, 0)` is top-left cell
|
||||
- `(-1, -1)` is bottom-right cell
|
||||
|
||||
```python
|
||||
# Examples:
|
||||
(0, 0), (2, 0) # First three cells of header row
|
||||
(0, 1), (-1, -1) # All cells except header
|
||||
(0, 0), (-1, -1) # Entire table
|
||||
```
|
||||
|
||||
## Styling Commands
|
||||
|
||||
### Text Formatting
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Font name
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
|
||||
# Font size
|
||||
('FONTSIZE', (0, 0), (-1, -1), 10),
|
||||
|
||||
# Text color
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('TEXTCOLOR', (0, 1), (-1, -1), colors.black),
|
||||
|
||||
# Combined font command
|
||||
('FONT', (0, 0), (-1, 0), 'Helvetica-Bold', 12), # name, size
|
||||
])
|
||||
```
|
||||
|
||||
### Alignment
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Horizontal alignment: LEFT, CENTER, RIGHT, DECIMAL
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('ALIGN', (0, 1), (0, -1), 'LEFT'), # First column left
|
||||
('ALIGN', (1, 1), (-1, -1), 'RIGHT'), # Other columns right
|
||||
|
||||
# Vertical alignment: TOP, MIDDLE, BOTTOM
|
||||
('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
|
||||
('VALIGN', (0, 0), (-1, 0), 'BOTTOM'), # Header bottom-aligned
|
||||
])
|
||||
```
|
||||
|
||||
### Cell Padding
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Individual padding
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 12),
|
||||
('RIGHTPADDING', (0, 0), (-1, -1), 12),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
|
||||
# Or set all at once by setting each
|
||||
])
|
||||
```
|
||||
|
||||
### Background Colors
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Solid background
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.blue),
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.lightgrey),
|
||||
|
||||
# Alternating row colors
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightblue]),
|
||||
|
||||
# Alternating column colors
|
||||
('COLBACKGROUNDS', (0, 0), (-1, -1), [colors.white, colors.lightgrey]),
|
||||
])
|
||||
```
|
||||
|
||||
### Gradient Backgrounds
|
||||
|
||||
```python
|
||||
from reportlab.lib.colors import Color
|
||||
|
||||
style = TableStyle([
|
||||
# Vertical gradient (top to bottom)
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.blue),
|
||||
('VERTICALGRADIENT', (0, 0), (-1, 0),
|
||||
[colors.blue, colors.lightblue]),
|
||||
|
||||
# Horizontal gradient (left to right)
|
||||
('HORIZONTALGRADIENT', (0, 1), (-1, 1),
|
||||
[colors.red, colors.yellow]),
|
||||
])
|
||||
```
|
||||
|
||||
### Lines and Borders
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Complete grid
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
||||
|
||||
# Box/outline only
|
||||
('BOX', (0, 0), (-1, -1), 2, colors.black),
|
||||
('OUTLINE', (0, 0), (-1, -1), 2, colors.black), # Same as BOX
|
||||
|
||||
# Inner grid only
|
||||
('INNERGRID', (0, 0), (-1, -1), 0.5, colors.grey),
|
||||
|
||||
# Directional lines
|
||||
('LINEABOVE', (0, 0), (-1, 0), 2, colors.black), # Header border
|
||||
('LINEBELOW', (0, 0), (-1, 0), 1, colors.black), # Header bottom
|
||||
('LINEBEFORE', (0, 0), (0, -1), 1, colors.black), # Left border
|
||||
('LINEAFTER', (-1, 0), (-1, -1), 1, colors.black), # Right border
|
||||
|
||||
# Thickness and color
|
||||
('LINEABOVE', (0, 1), (-1, 1), 0.5, colors.grey), # Thin grey line
|
||||
])
|
||||
```
|
||||
|
||||
### Cell Spanning
|
||||
|
||||
```python
|
||||
data = [
|
||||
['Spanning Header', '', ''], # Span will merge these
|
||||
['A', 'B', 'C'],
|
||||
['D', 'E', 'F'],
|
||||
]
|
||||
|
||||
style = TableStyle([
|
||||
# Span 3 columns in first row
|
||||
('SPAN', (0, 0), (2, 0)),
|
||||
|
||||
# Center the spanning cell
|
||||
('ALIGN', (0, 0), (2, 0), 'CENTER'),
|
||||
])
|
||||
|
||||
table = Table(data)
|
||||
table.setStyle(style)
|
||||
```
|
||||
|
||||
**Important:** Cells that are spanned over must contain empty strings `''`.
|
||||
|
||||
### Advanced Spanning Examples
|
||||
|
||||
```python
|
||||
# Span multiple rows and columns
|
||||
data = [
|
||||
['A', 'B', 'B', 'C'],
|
||||
['A', 'D', 'E', 'F'],
|
||||
['A', 'G', 'H', 'I'],
|
||||
]
|
||||
|
||||
style = TableStyle([
|
||||
# Span rows in column 0
|
||||
('SPAN', (0, 0), (0, 2)), # Merge A cells vertically
|
||||
|
||||
# Span columns in row 0
|
||||
('SPAN', (1, 0), (2, 0)), # Merge B cells horizontally
|
||||
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
||||
])
|
||||
```
|
||||
|
||||
## Special Commands
|
||||
|
||||
### Rounded Corners
|
||||
|
||||
```python
|
||||
table = Table(data, cornerRadii=[5, 5, 5, 5]) # [TL, TR, BL, BR]
|
||||
|
||||
# Or in style
|
||||
style = TableStyle([
|
||||
('ROUNDEDCORNERS', [10, 10, 0, 0]), # Rounded top corners only
|
||||
])
|
||||
```
|
||||
|
||||
### No Split
|
||||
|
||||
Prevent table from splitting at specific locations:
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Don't split between rows 0 and 2
|
||||
('NOSPLIT', (0, 0), (-1, 2)),
|
||||
])
|
||||
```
|
||||
|
||||
### Split-Specific Styling
|
||||
|
||||
Apply styles only to first or last part when table splits:
|
||||
|
||||
```python
|
||||
style = TableStyle([
|
||||
# Style for first part after split
|
||||
('LINEBELOW', (0, 'splitfirst'), (-1, 'splitfirst'), 2, colors.red),
|
||||
|
||||
# Style for last part after split
|
||||
('LINEABOVE', (0, 'splitlast'), (-1, 'splitlast'), 2, colors.blue),
|
||||
])
|
||||
```
|
||||
|
||||
## Repeating Headers
|
||||
|
||||
```python
|
||||
# Repeat first row on each page
|
||||
table = Table(data, repeatRows=1)
|
||||
|
||||
# Repeat first 2 rows
|
||||
table = Table(data, repeatRows=2)
|
||||
```
|
||||
|
||||
## Complete Examples
|
||||
|
||||
### Styled Report Table
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Table, TableStyle
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
data = [
|
||||
['Product', 'Quantity', 'Unit Price', 'Total'],
|
||||
['Widget A', '10', '$5.00', '$50.00'],
|
||||
['Widget B', '5', '$12.00', '$60.00'],
|
||||
['Widget C', '20', '$3.00', '$60.00'],
|
||||
['', '', 'Subtotal:', '$170.00'],
|
||||
]
|
||||
|
||||
table = Table(data, colWidths=[2.5*inch, 1*inch, 1*inch, 1*inch])
|
||||
|
||||
style = TableStyle([
|
||||
# Header row
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.darkblue),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 12),
|
||||
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
|
||||
# Data rows
|
||||
('BACKGROUND', (0, 1), (-1, -2), colors.beige),
|
||||
('GRID', (0, 0), (-1, -2), 0.5, colors.grey),
|
||||
('ALIGN', (1, 1), (-1, -1), 'RIGHT'),
|
||||
('ALIGN', (0, 1), (0, -1), 'LEFT'),
|
||||
|
||||
# Total row
|
||||
('BACKGROUND', (0, -1), (-1, -1), colors.lightgrey),
|
||||
('LINEABOVE', (0, -1), (-1, -1), 2, colors.black),
|
||||
('FONTNAME', (2, -1), (-1, -1), 'Helvetica-Bold'),
|
||||
])
|
||||
|
||||
table.setStyle(style)
|
||||
```
|
||||
|
||||
### Alternating Row Colors
|
||||
|
||||
```python
|
||||
data = [
|
||||
['Name', 'Age', 'City'],
|
||||
['Alice', '30', 'New York'],
|
||||
['Bob', '25', 'Boston'],
|
||||
['Charlie', '35', 'Chicago'],
|
||||
['Diana', '28', 'Denver'],
|
||||
]
|
||||
|
||||
table = Table(data, colWidths=[2*inch, 1*inch, 1.5*inch])
|
||||
|
||||
style = TableStyle([
|
||||
# Header
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.darkslategray),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
|
||||
# Alternating rows (zebra striping)
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1),
|
||||
[colors.white, colors.lightgrey]),
|
||||
|
||||
# Borders
|
||||
('BOX', (0, 0), (-1, -1), 2, colors.black),
|
||||
('LINEBELOW', (0, 0), (-1, 0), 2, colors.black),
|
||||
|
||||
# Padding
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 12),
|
||||
('RIGHTPADDING', (0, 0), (-1, -1), 12),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 6),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
|
||||
])
|
||||
|
||||
table.setStyle(style)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Set colWidths explicitly** for consistent layout
|
||||
2. **Use repeatRows** for multi-page tables with headers
|
||||
3. **Apply padding** for better readability (especially LEFTPADDING and RIGHTPADDING)
|
||||
4. **Use ROWBACKGROUNDS** for alternating colors instead of styling each row
|
||||
5. **Put empty strings** in cells that will be spanned
|
||||
6. **Test page breaks** early with realistic data amounts
|
||||
7. **Use Paragraph objects** in cells for complex formatted text
|
||||
8. **Set VALIGN to MIDDLE** for better appearance with varying row heights
|
||||
9. **Keep tables simple** - complex nested tables are hard to maintain
|
||||
10. **Use consistent styling** - define once, apply to all tables
|
||||
@@ -1,394 +0,0 @@
|
||||
# Text and Fonts Reference
|
||||
|
||||
Comprehensive guide to text formatting, paragraph styles, and font handling in ReportLab.
|
||||
|
||||
## Text Encoding
|
||||
|
||||
**IMPORTANT:** All text input should be UTF-8 encoded or Python Unicode objects (since ReportLab 2.0).
|
||||
|
||||
```python
|
||||
# Correct - UTF-8 strings
|
||||
text = "Hello 世界 مرحبا"
|
||||
para = Paragraph(text, style)
|
||||
|
||||
# For legacy data, convert first
|
||||
import codecs
|
||||
decoded_text = codecs.decode(legacy_bytes, 'latin-1')
|
||||
```
|
||||
|
||||
## Paragraph Styles
|
||||
|
||||
### Creating Styles
|
||||
|
||||
```python
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY
|
||||
from reportlab.lib.colors import black, blue, red
|
||||
from reportlab.lib.units import inch
|
||||
|
||||
# Get default styles
|
||||
styles = getSampleStyleSheet()
|
||||
normal = styles['Normal']
|
||||
heading = styles['Heading1']
|
||||
|
||||
# Create custom style
|
||||
custom_style = ParagraphStyle(
|
||||
'CustomStyle',
|
||||
parent=normal, # Inherit from another style
|
||||
|
||||
# Font properties
|
||||
fontName='Helvetica',
|
||||
fontSize=12,
|
||||
leading=14, # Line spacing (should be > fontSize)
|
||||
|
||||
# Indentation (in points)
|
||||
leftIndent=0,
|
||||
rightIndent=0,
|
||||
firstLineIndent=0, # Positive = indent, negative = outdent
|
||||
|
||||
# Spacing
|
||||
spaceBefore=0,
|
||||
spaceAfter=0,
|
||||
|
||||
# Alignment
|
||||
alignment=TA_LEFT, # TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY
|
||||
|
||||
# Colors
|
||||
textColor=black,
|
||||
backColor=None, # Background color
|
||||
|
||||
# Borders
|
||||
borderWidth=0,
|
||||
borderColor=None,
|
||||
borderPadding=0,
|
||||
borderRadius=None,
|
||||
|
||||
# Bullets
|
||||
bulletFontName='Helvetica',
|
||||
bulletFontSize=12,
|
||||
bulletIndent=0,
|
||||
bulletText=None, # Text for bullets (e.g., '•')
|
||||
|
||||
# Advanced
|
||||
wordWrap=None, # 'CJK' for Asian languages
|
||||
allowWidows=1, # Allow widow lines
|
||||
allowOrphans=0, # Prevent orphan lines
|
||||
endDots=None, # Trailing dots for TOC entries
|
||||
splitLongWords=1,
|
||||
hyphenationLang=None, # 'en_US', etc. (requires pyphen)
|
||||
)
|
||||
|
||||
# Add to stylesheet
|
||||
styles.add(custom_style)
|
||||
```
|
||||
|
||||
### Built-in Styles
|
||||
|
||||
```python
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
# Common styles
|
||||
styles['Normal'] # Body text
|
||||
styles['BodyText'] # Similar to Normal
|
||||
styles['Heading1'] # Top-level heading
|
||||
styles['Heading2'] # Second-level heading
|
||||
styles['Heading3'] # Third-level heading
|
||||
styles['Title'] # Document title
|
||||
styles['Bullet'] # Bulleted list items
|
||||
styles['Definition'] # Definition text
|
||||
styles['Code'] # Code samples
|
||||
```
|
||||
|
||||
## Paragraph Formatting
|
||||
|
||||
### Basic Paragraph
|
||||
|
||||
```python
|
||||
from reportlab.platypus import Paragraph
|
||||
|
||||
para = Paragraph("This is a paragraph.", style)
|
||||
story.append(para)
|
||||
```
|
||||
|
||||
### Inline Formatting Tags
|
||||
|
||||
```python
|
||||
text = """
|
||||
<b>Bold text</b>
|
||||
<i>Italic text</i>
|
||||
<u>Underlined text</u>
|
||||
<strike>Strikethrough text</strike>
|
||||
<strong>Strong (bold) text</strong>
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Font Control
|
||||
|
||||
```python
|
||||
text = """
|
||||
<font face="Courier" size="14" color="blue">
|
||||
Custom font, size, and color
|
||||
</font>
|
||||
|
||||
<font color="#FF0000">Hex color codes work too</font>
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Superscripts and Subscripts
|
||||
|
||||
```python
|
||||
text = """
|
||||
H<sub>2</sub>O is water.
|
||||
E=mc<super>2</super> or E=mc<sup>2</sup>
|
||||
X<sub><i>i</i></sub> for subscripted variables
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Greek Letters
|
||||
|
||||
```python
|
||||
text = """
|
||||
<greek>alpha</greek>, <greek>beta</greek>, <greek>gamma</greek>
|
||||
<greek>epsilon</greek>, <greek>pi</greek>, <greek>omega</greek>
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Links
|
||||
|
||||
```python
|
||||
# External link
|
||||
text = '<link href="https://example.com" color="blue">Click here</link>'
|
||||
|
||||
# Internal link (to bookmark)
|
||||
text = '<link href="#section1" color="blue">Go to Section 1</link>'
|
||||
|
||||
# Anchor for internal links
|
||||
text = '<a name="section1"/>Section 1 Heading'
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Inline Images
|
||||
|
||||
```python
|
||||
text = """
|
||||
Here is an inline image: <img src="icon.png" width="12" height="12" valign="middle"/>
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
### Line Breaks
|
||||
|
||||
```python
|
||||
text = """
|
||||
First line<br/>
|
||||
Second line<br/>
|
||||
Third line
|
||||
"""
|
||||
|
||||
para = Paragraph(text, normal_style)
|
||||
```
|
||||
|
||||
## Font Handling
|
||||
|
||||
### Standard Fonts
|
||||
|
||||
ReportLab includes 14 standard PDF fonts (no embedding needed):
|
||||
|
||||
```python
|
||||
# Helvetica family
|
||||
'Helvetica'
|
||||
'Helvetica-Bold'
|
||||
'Helvetica-Oblique'
|
||||
'Helvetica-BoldOblique'
|
||||
|
||||
# Times family
|
||||
'Times-Roman'
|
||||
'Times-Bold'
|
||||
'Times-Italic'
|
||||
'Times-BoldItalic'
|
||||
|
||||
# Courier family
|
||||
'Courier'
|
||||
'Courier-Bold'
|
||||
'Courier-Oblique'
|
||||
'Courier-BoldOblique'
|
||||
|
||||
# Symbol and Dingbats
|
||||
'Symbol'
|
||||
'ZapfDingbats'
|
||||
```
|
||||
|
||||
### TrueType Fonts
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
|
||||
# Register single font
|
||||
pdfmetrics.registerFont(TTFont('CustomFont', 'CustomFont.ttf'))
|
||||
|
||||
# Use in Canvas
|
||||
canvas.setFont('CustomFont', 12)
|
||||
|
||||
# Use in Paragraph style
|
||||
style = ParagraphStyle('Custom', fontName='CustomFont', fontSize=12)
|
||||
```
|
||||
|
||||
### Font Families
|
||||
|
||||
Register related fonts as a family for bold/italic support:
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.ttfonts import TTFont
|
||||
from reportlab.lib.fonts import addMapping
|
||||
|
||||
# Register fonts
|
||||
pdfmetrics.registerFont(TTFont('Vera', 'Vera.ttf'))
|
||||
pdfmetrics.registerFont(TTFont('VeraBd', 'VeraBd.ttf'))
|
||||
pdfmetrics.registerFont(TTFont('VeraIt', 'VeraIt.ttf'))
|
||||
pdfmetrics.registerFont(TTFont('VeraBI', 'VeraBI.ttf'))
|
||||
|
||||
# Map family (normal, bold, italic, bold-italic)
|
||||
addMapping('Vera', 0, 0, 'Vera') # normal
|
||||
addMapping('Vera', 1, 0, 'VeraBd') # bold
|
||||
addMapping('Vera', 0, 1, 'VeraIt') # italic
|
||||
addMapping('Vera', 1, 1, 'VeraBI') # bold-italic
|
||||
|
||||
# Now <b> and <i> tags work with this family
|
||||
style = ParagraphStyle('VeraStyle', fontName='Vera', fontSize=12)
|
||||
para = Paragraph("Normal <b>Bold</b> <i>Italic</i> <b><i>Both</i></b>", style)
|
||||
```
|
||||
|
||||
### Font Search Paths
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase.ttfonts import TTFSearchPath
|
||||
|
||||
# Add custom font directory
|
||||
TTFSearchPath.append('/path/to/fonts/')
|
||||
|
||||
# Now fonts in this directory can be found by name
|
||||
pdfmetrics.registerFont(TTFont('MyFont', 'MyFont.ttf'))
|
||||
```
|
||||
|
||||
### Asian Language Support
|
||||
|
||||
#### Using Adobe Language Packs (no embedding)
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase import pdfmetrics
|
||||
from reportlab.pdfbase.cidfonts import UnicodeCIDFont
|
||||
|
||||
# Register CID fonts
|
||||
pdfmetrics.registerFont(UnicodeCIDFont('HeiseiMin-W3')) # Japanese
|
||||
pdfmetrics.registerFont(UnicodeCIDFont('STSong-Light')) # Chinese (Simplified)
|
||||
pdfmetrics.registerFont(UnicodeCIDFont('MSung-Light')) # Chinese (Traditional)
|
||||
pdfmetrics.registerFont(UnicodeCIDFont('HYSMyeongJo-Medium')) # Korean
|
||||
|
||||
# Use in styles
|
||||
style = ParagraphStyle('Japanese', fontName='HeiseiMin-W3', fontSize=12)
|
||||
para = Paragraph("日本語テキスト", style)
|
||||
```
|
||||
|
||||
#### Using TrueType Fonts with Asian Characters
|
||||
|
||||
```python
|
||||
# Register TrueType font with full Unicode support
|
||||
pdfmetrics.registerFont(TTFont('SimSun', 'simsun.ttc'))
|
||||
|
||||
style = ParagraphStyle('Chinese', fontName='SimSun', fontSize=12, wordWrap='CJK')
|
||||
para = Paragraph("中文文本", style)
|
||||
```
|
||||
|
||||
Note: Set `wordWrap='CJK'` for proper line breaking in Asian languages.
|
||||
|
||||
## Numbering and Sequences
|
||||
|
||||
Auto-numbering using `<seq>` tags:
|
||||
|
||||
```python
|
||||
# Simple numbering
|
||||
text = "<seq id='chapter'/> Introduction" # Outputs: 1 Introduction
|
||||
text = "<seq id='chapter'/> Methods" # Outputs: 2 Methods
|
||||
|
||||
# Reset counter
|
||||
text = "<seq id='figure' reset='yes'/>"
|
||||
|
||||
# Formatting templates
|
||||
text = "Figure <seq template='%(chapter)s-%(figure+)s' id='figure'/>"
|
||||
# Outputs: Figure 1-1, Figure 1-2, etc.
|
||||
|
||||
# Multi-level numbering
|
||||
text = "Section <seq template='%(chapter)s.%(section+)s' id='section'/>"
|
||||
```
|
||||
|
||||
## Bullets and Lists
|
||||
|
||||
### Using Bullet Style
|
||||
|
||||
```python
|
||||
bullet_style = ParagraphStyle(
|
||||
'Bullet',
|
||||
parent=normal_style,
|
||||
leftIndent=20,
|
||||
bulletIndent=10,
|
||||
bulletText='•', # Unicode bullet
|
||||
bulletFontName='Helvetica',
|
||||
)
|
||||
|
||||
story.append(Paragraph("First item", bullet_style))
|
||||
story.append(Paragraph("Second item", bullet_style))
|
||||
story.append(Paragraph("Third item", bullet_style))
|
||||
```
|
||||
|
||||
### Custom Bullet Characters
|
||||
|
||||
```python
|
||||
# Different bullet styles
|
||||
bulletText='•' # Filled circle
|
||||
bulletText='◦' # Open circle
|
||||
bulletText='▪' # Square
|
||||
bulletText='▸' # Triangle
|
||||
bulletText='→' # Arrow
|
||||
bulletText='1.' # Numbers
|
||||
bulletText='a)' # Letters
|
||||
```
|
||||
|
||||
## Text Measurement
|
||||
|
||||
```python
|
||||
from reportlab.pdfbase.pdfmetrics import stringWidth
|
||||
|
||||
# Measure string width
|
||||
width = stringWidth("Hello World", "Helvetica", 12)
|
||||
|
||||
# Check if text fits in available width
|
||||
max_width = 200
|
||||
if stringWidth(text, font_name, font_size) > max_width:
|
||||
# Text is too wide
|
||||
pass
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always use UTF-8** for text input
|
||||
2. **Set leading > fontSize** for readability (typically fontSize + 2)
|
||||
3. **Register font families** for proper bold/italic support
|
||||
4. **Escape HTML** if displaying user content: use `<` for < and `>` for >
|
||||
5. **Use getSampleStyleSheet()** as a starting point, don't create all styles from scratch
|
||||
6. **Test Asian fonts** early if supporting multi-language content
|
||||
7. **Set wordWrap='CJK'** for Chinese/Japanese/Korean text
|
||||
8. **Use stringWidth()** to check if text fits before rendering
|
||||
9. **Define styles once** at document start, reuse throughout
|
||||
10. **Enable hyphenation** for justified text: `hyphenationLang='en_US'` (requires pyphen package)
|
||||
@@ -1,229 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick Document Generator - Helper for creating simple ReportLab documents
|
||||
|
||||
This script provides utility functions for quickly creating common document types
|
||||
without writing boilerplate code.
|
||||
"""
|
||||
|
||||
from reportlab.lib.pagesizes import letter, A4
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
from reportlab.lib.units import inch
|
||||
from reportlab.lib import colors
|
||||
from reportlab.platypus import (
|
||||
SimpleDocTemplate, Paragraph, Spacer, PageBreak,
|
||||
Table, TableStyle, Image, KeepTogether
|
||||
)
|
||||
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT, TA_JUSTIFY
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def create_simple_document(filename, title, author="", content_blocks=None, pagesize=letter):
|
||||
"""
|
||||
Create a simple document with title and content blocks.
|
||||
|
||||
Args:
|
||||
filename: Output PDF filename
|
||||
title: Document title
|
||||
author: Document author (optional)
|
||||
content_blocks: List of dicts with 'type' and 'content' keys
|
||||
type can be: 'heading', 'paragraph', 'bullet', 'space'
|
||||
pagesize: Page size (default: letter)
|
||||
|
||||
Example content_blocks:
|
||||
[
|
||||
{'type': 'heading', 'content': 'Introduction'},
|
||||
{'type': 'paragraph', 'content': 'This is a paragraph.'},
|
||||
{'type': 'bullet', 'content': 'Bullet point item'},
|
||||
{'type': 'space', 'height': 0.2}, # height in inches
|
||||
]
|
||||
"""
|
||||
if content_blocks is None:
|
||||
content_blocks = []
|
||||
|
||||
# Create document
|
||||
doc = SimpleDocTemplate(
|
||||
filename,
|
||||
pagesize=pagesize,
|
||||
rightMargin=72,
|
||||
leftMargin=72,
|
||||
topMargin=72,
|
||||
bottomMargin=18,
|
||||
title=title,
|
||||
author=author
|
||||
)
|
||||
|
||||
# Get styles
|
||||
styles = getSampleStyleSheet()
|
||||
story = []
|
||||
|
||||
# Add title
|
||||
story.append(Paragraph(title, styles['Title']))
|
||||
story.append(Spacer(1, 0.3*inch))
|
||||
|
||||
# Process content blocks
|
||||
for block in content_blocks:
|
||||
block_type = block.get('type', 'paragraph')
|
||||
content = block.get('content', '')
|
||||
|
||||
if block_type == 'heading':
|
||||
story.append(Paragraph(content, styles['Heading1']))
|
||||
story.append(Spacer(1, 0.1*inch))
|
||||
|
||||
elif block_type == 'heading2':
|
||||
story.append(Paragraph(content, styles['Heading2']))
|
||||
story.append(Spacer(1, 0.1*inch))
|
||||
|
||||
elif block_type == 'paragraph':
|
||||
story.append(Paragraph(content, styles['BodyText']))
|
||||
story.append(Spacer(1, 0.1*inch))
|
||||
|
||||
elif block_type == 'bullet':
|
||||
story.append(Paragraph(content, styles['Bullet']))
|
||||
|
||||
elif block_type == 'space':
|
||||
height = block.get('height', 0.2)
|
||||
story.append(Spacer(1, height*inch))
|
||||
|
||||
elif block_type == 'pagebreak':
|
||||
story.append(PageBreak())
|
||||
|
||||
# Build PDF
|
||||
doc.build(story)
|
||||
return filename
|
||||
|
||||
|
||||
def create_styled_table(data, col_widths=None, style_name='default'):
|
||||
"""
|
||||
Create a styled table with common styling presets.
|
||||
|
||||
Args:
|
||||
data: List of lists containing table data
|
||||
col_widths: List of column widths (None for auto)
|
||||
style_name: 'default', 'striped', 'minimal', 'report'
|
||||
|
||||
Returns:
|
||||
Table object ready to add to story
|
||||
"""
|
||||
table = Table(data, colWidths=col_widths)
|
||||
|
||||
if style_name == 'striped':
|
||||
style = TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.darkblue),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 12),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]),
|
||||
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
|
||||
])
|
||||
|
||||
elif style_name == 'minimal':
|
||||
style = TableStyle([
|
||||
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('LINEABOVE', (0, 0), (-1, 0), 2, colors.black),
|
||||
('LINEBELOW', (0, 0), (-1, 0), 1, colors.black),
|
||||
('LINEBELOW', (0, -1), (-1, -1), 2, colors.black),
|
||||
])
|
||||
|
||||
elif style_name == 'report':
|
||||
style = TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.black),
|
||||
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 11),
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.beige),
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.grey),
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 12),
|
||||
('RIGHTPADDING', (0, 0), (-1, -1), 12),
|
||||
])
|
||||
|
||||
else: # default
|
||||
style = TableStyle([
|
||||
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
||||
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
||||
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
|
||||
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
|
||||
('FONTSIZE', (0, 0), (-1, 0), 12),
|
||||
('BOTTOMPADDING', (0, 0), (-1, 0), 12),
|
||||
('BACKGROUND', (0, 1), (-1, -1), colors.white),
|
||||
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
||||
])
|
||||
|
||||
table.setStyle(style)
|
||||
return table
|
||||
|
||||
|
||||
def add_header_footer(canvas, doc, header_text="", footer_text=""):
|
||||
"""
|
||||
Callback function to add headers and footers to each page.
|
||||
|
||||
Usage:
|
||||
from functools import partial
|
||||
callback = partial(add_header_footer, header_text="My Document", footer_text="Confidential")
|
||||
template = PageTemplate(id='normal', frames=[frame], onPage=callback)
|
||||
"""
|
||||
canvas.saveState()
|
||||
|
||||
# Header
|
||||
if header_text:
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.drawString(inch, doc.pagesize[1] - 0.5*inch, header_text)
|
||||
|
||||
# Footer
|
||||
if footer_text:
|
||||
canvas.setFont('Helvetica', 9)
|
||||
canvas.drawString(inch, 0.5*inch, footer_text)
|
||||
|
||||
# Page number
|
||||
canvas.drawRightString(doc.pagesize[0] - inch, 0.5*inch, f"Page {doc.page}")
|
||||
|
||||
canvas.restoreState()
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example 1: Simple document
|
||||
content = [
|
||||
{'type': 'heading', 'content': 'Introduction'},
|
||||
{'type': 'paragraph', 'content': 'This is a sample paragraph with some text.'},
|
||||
{'type': 'space', 'height': 0.2},
|
||||
{'type': 'heading', 'content': 'Main Content'},
|
||||
{'type': 'paragraph', 'content': 'More content here with <b>bold</b> and <i>italic</i> text.'},
|
||||
{'type': 'bullet', 'content': 'First bullet point'},
|
||||
{'type': 'bullet', 'content': 'Second bullet point'},
|
||||
]
|
||||
|
||||
create_simple_document(
|
||||
"example_document.pdf",
|
||||
"Sample Document",
|
||||
author="John Doe",
|
||||
content_blocks=content
|
||||
)
|
||||
|
||||
print("Created: example_document.pdf")
|
||||
|
||||
# Example 2: Document with styled table
|
||||
doc = SimpleDocTemplate("table_example.pdf", pagesize=letter)
|
||||
story = []
|
||||
styles = getSampleStyleSheet()
|
||||
|
||||
story.append(Paragraph("Sales Report", styles['Title']))
|
||||
story.append(Spacer(1, 0.3*inch))
|
||||
|
||||
# Create table
|
||||
data = [
|
||||
['Product', 'Q1', 'Q2', 'Q3', 'Q4'],
|
||||
['Widget A', '100', '150', '130', '180'],
|
||||
['Widget B', '80', '120', '110', '160'],
|
||||
['Widget C', '90', '110', '100', '140'],
|
||||
]
|
||||
|
||||
table = create_styled_table(data, col_widths=[2*inch, 1*inch, 1*inch, 1*inch, 1*inch], style_name='striped')
|
||||
story.append(table)
|
||||
|
||||
doc.build(story)
|
||||
print("Created: table_example.pdf")
|
||||
@@ -1,780 +0,0 @@
|
||||
---
|
||||
name: scikit-learn
|
||||
description: "ML toolkit. Classification, regression, clustering, PCA, preprocessing, pipelines, GridSearch, cross-validation, RandomForest, SVM, for general machine learning workflows."
|
||||
---
|
||||
|
||||
# Scikit-learn: Machine Learning in Python
|
||||
|
||||
## Overview
|
||||
|
||||
Scikit-learn is Python's premier machine learning library, offering simple and efficient tools for predictive data analysis. Apply this skill for classification, regression, clustering, dimensionality reduction, model selection, preprocessing, and hyperparameter optimization.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Building classification models (spam detection, image recognition, medical diagnosis)
|
||||
- Creating regression models (price prediction, forecasting, trend analysis)
|
||||
- Performing clustering analysis (customer segmentation, pattern discovery)
|
||||
- Reducing dimensionality (PCA, t-SNE for visualization)
|
||||
- Preprocessing data (scaling, encoding, imputation)
|
||||
- Evaluating model performance (cross-validation, metrics)
|
||||
- Tuning hyperparameters (grid search, random search)
|
||||
- Creating machine learning pipelines
|
||||
- Detecting anomalies or outliers
|
||||
- Implementing ensemble methods
|
||||
|
||||
## Core Machine Learning Workflow
|
||||
|
||||
### Standard ML Pipeline
|
||||
|
||||
Follow this general workflow for supervised learning tasks:
|
||||
|
||||
1. **Data Preparation**
|
||||
- Load and explore data
|
||||
- Split into train/test sets
|
||||
- Handle missing values
|
||||
- Encode categorical features
|
||||
- Scale/normalize features
|
||||
|
||||
2. **Model Selection**
|
||||
- Start with baseline model
|
||||
- Try more complex models
|
||||
- Use domain knowledge to guide selection
|
||||
|
||||
3. **Model Training**
|
||||
- Fit model on training data
|
||||
- Use pipelines to prevent data leakage
|
||||
- Apply cross-validation
|
||||
|
||||
4. **Model Evaluation**
|
||||
- Evaluate on test set
|
||||
- Use appropriate metrics
|
||||
- Analyze errors
|
||||
|
||||
5. **Model Optimization**
|
||||
- Tune hyperparameters
|
||||
- Feature engineering
|
||||
- Ensemble methods
|
||||
|
||||
6. **Deployment**
|
||||
- Save model using joblib
|
||||
- Create prediction pipeline
|
||||
- Monitor performance
|
||||
|
||||
### Classification Quick Start
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
# Create pipeline (prevents data leakage)
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
|
||||
])
|
||||
|
||||
# Split data (use stratify for imbalanced classes)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# Train
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# Evaluate
|
||||
y_pred = pipeline.predict(X_test)
|
||||
print(classification_report(y_test, y_pred))
|
||||
|
||||
# Cross-validation for robust evaluation
|
||||
from sklearn.model_selection import cross_val_score
|
||||
scores = cross_val_score(pipeline, X_train, y_train, cv=5)
|
||||
print(f"CV Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")
|
||||
```
|
||||
|
||||
### Regression Quick Start
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('regressor', RandomForestRegressor(n_estimators=100, random_state=42))
|
||||
])
|
||||
|
||||
# Split data
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Train
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# Evaluate
|
||||
y_pred = pipeline.predict(X_test)
|
||||
rmse = mean_squared_error(y_test, y_pred, squared=False)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
print(f"RMSE: {rmse:.3f}, R²: {r2:.3f}")
|
||||
```
|
||||
|
||||
## Algorithm Selection Guide
|
||||
|
||||
### Classification Algorithms
|
||||
|
||||
**Start with baseline**: LogisticRegression
|
||||
- Fast, interpretable, works well for linearly separable data
|
||||
- Good for high-dimensional data (text classification)
|
||||
|
||||
**General-purpose**: RandomForestClassifier
|
||||
- Handles non-linear relationships
|
||||
- Robust to outliers
|
||||
- Provides feature importance
|
||||
- Good default choice
|
||||
|
||||
**Best performance**: HistGradientBoostingClassifier
|
||||
- State-of-the-art for tabular data
|
||||
- Fast on large datasets (>10K samples)
|
||||
- Often wins Kaggle competitions
|
||||
|
||||
**Special cases**:
|
||||
- **Small datasets (<1K)**: SVC with RBF kernel
|
||||
- **Very large datasets (>100K)**: SGDClassifier or LinearSVC
|
||||
- **Interpretability critical**: LogisticRegression or DecisionTreeClassifier
|
||||
- **Probabilistic predictions**: GaussianNB or calibrated models
|
||||
- **Text classification**: LogisticRegression with TfidfVectorizer
|
||||
|
||||
### Regression Algorithms
|
||||
|
||||
**Start with baseline**: LinearRegression or Ridge
|
||||
- Fast, interpretable
|
||||
- Works well when relationships are linear
|
||||
|
||||
**General-purpose**: RandomForestRegressor
|
||||
- Handles non-linear relationships
|
||||
- Robust to outliers
|
||||
- Good default choice
|
||||
|
||||
**Best performance**: HistGradientBoostingRegressor
|
||||
- State-of-the-art for tabular data
|
||||
- Fast on large datasets
|
||||
|
||||
**Special cases**:
|
||||
- **Regularization needed**: Ridge (L2) or Lasso (L1 + feature selection)
|
||||
- **Very large datasets**: SGDRegressor
|
||||
- **Outliers present**: HuberRegressor or RANSAC
|
||||
|
||||
### Clustering Algorithms
|
||||
|
||||
**Known number of clusters**: KMeans
|
||||
- Fast and scalable
|
||||
- Assumes spherical clusters
|
||||
|
||||
**Unknown number of clusters**: DBSCAN or HDBSCAN
|
||||
- Handles arbitrary shapes
|
||||
- Automatic outlier detection
|
||||
|
||||
**Hierarchical relationships**: AgglomerativeClustering
|
||||
- Creates hierarchy of clusters
|
||||
- Good for visualization (dendrograms)
|
||||
|
||||
**Soft clustering (probabilities)**: GaussianMixture
|
||||
- Provides cluster probabilities
|
||||
- Handles elliptical clusters
|
||||
|
||||
### Dimensionality Reduction
|
||||
|
||||
**Preprocessing/feature extraction**: PCA
|
||||
- Fast and efficient
|
||||
- Linear transformation
|
||||
- ALWAYS standardize first
|
||||
|
||||
**Visualization only**: t-SNE or UMAP
|
||||
- Preserves local structure
|
||||
- Non-linear
|
||||
- DO NOT use for preprocessing
|
||||
|
||||
**Sparse data (text)**: TruncatedSVD
|
||||
- Works with sparse matrices
|
||||
- Latent Semantic Analysis
|
||||
|
||||
**Non-negative data**: NMF
|
||||
- Interpretable components
|
||||
- Topic modeling
|
||||
|
||||
## Working with Different Data Types
|
||||
|
||||
### Numeric Features
|
||||
|
||||
**Continuous features**:
|
||||
1. Check distribution
|
||||
2. Handle outliers (remove, clip, or use RobustScaler)
|
||||
3. Scale using StandardScaler (most algorithms) or MinMaxScaler (neural networks)
|
||||
|
||||
**Count data**:
|
||||
1. Consider log transformation or sqrt
|
||||
2. Scale after transformation
|
||||
|
||||
**Skewed data**:
|
||||
1. Use PowerTransformer (Yeo-Johnson or Box-Cox)
|
||||
2. Or QuantileTransformer for stronger normalization
|
||||
|
||||
### Categorical Features
|
||||
|
||||
**Low cardinality (<10 categories)**:
|
||||
```python
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
encoder = OneHotEncoder(drop='first', sparse_output=True)
|
||||
```
|
||||
|
||||
**High cardinality (>10 categories)**:
|
||||
```python
|
||||
from sklearn.preprocessing import TargetEncoder
|
||||
encoder = TargetEncoder()
|
||||
# Uses target statistics, prevents leakage with cross-fitting
|
||||
```
|
||||
|
||||
**Ordinal relationships**:
|
||||
```python
|
||||
from sklearn.preprocessing import OrdinalEncoder
|
||||
encoder = OrdinalEncoder(categories=[['small', 'medium', 'large']])
|
||||
```
|
||||
|
||||
### Text Data
|
||||
|
||||
```python
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
text_pipeline = Pipeline([
|
||||
('tfidf', TfidfVectorizer(max_features=1000, stop_words='english')),
|
||||
('classifier', MultinomialNB())
|
||||
])
|
||||
|
||||
text_pipeline.fit(X_train_text, y_train)
|
||||
```
|
||||
|
||||
### Mixed Data Types
|
||||
|
||||
```python
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
# Define feature types
|
||||
numeric_features = ['age', 'income', 'credit_score']
|
||||
categorical_features = ['country', 'occupation']
|
||||
|
||||
# Separate preprocessing pipelines
|
||||
numeric_transformer = Pipeline([
|
||||
('imputer', SimpleImputer(strategy='median')),
|
||||
('scaler', StandardScaler())
|
||||
])
|
||||
|
||||
categorical_transformer = Pipeline([
|
||||
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=True))
|
||||
])
|
||||
|
||||
# Combine with ColumnTransformer
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', numeric_transformer, numeric_features),
|
||||
('cat', categorical_transformer, categorical_features)
|
||||
])
|
||||
|
||||
# Complete pipeline
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
pipeline = Pipeline([
|
||||
('preprocessor', preprocessor),
|
||||
('classifier', RandomForestClassifier())
|
||||
])
|
||||
|
||||
pipeline.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Model Evaluation
|
||||
|
||||
### Classification Metrics
|
||||
|
||||
**Balanced datasets**: Use accuracy or F1-score
|
||||
|
||||
**Imbalanced datasets**: Use balanced_accuracy, F1-weighted, or ROC-AUC
|
||||
```python
|
||||
from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score
|
||||
|
||||
balanced_acc = balanced_accuracy_score(y_true, y_pred)
|
||||
f1 = f1_score(y_true, y_pred, average='weighted')
|
||||
|
||||
# ROC-AUC requires probabilities
|
||||
y_proba = model.predict_proba(X_test)
|
||||
auc = roc_auc_score(y_true, y_proba, multi_class='ovr')
|
||||
```
|
||||
|
||||
**Cost-sensitive**: Define custom scorer or adjust decision threshold
|
||||
|
||||
**Comprehensive report**:
|
||||
```python
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
|
||||
print(classification_report(y_true, y_pred))
|
||||
print(confusion_matrix(y_true, y_pred))
|
||||
```
|
||||
|
||||
### Regression Metrics
|
||||
|
||||
**Standard use**: RMSE and R²
|
||||
```python
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
rmse = mean_squared_error(y_true, y_pred, squared=False)
|
||||
r2 = r2_score(y_true, y_pred)
|
||||
```
|
||||
|
||||
**Outliers present**: Use MAE (robust to outliers)
|
||||
```python
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
mae = mean_absolute_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
**Percentage errors matter**: Use MAPE
|
||||
```python
|
||||
from sklearn.metrics import mean_absolute_percentage_error
|
||||
mape = mean_absolute_percentage_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
### Cross-Validation
|
||||
|
||||
**Standard approach** (5-10 folds):
|
||||
```python
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
||||
scores = cross_val_score(model, X, y, cv=5, scoring='accuracy')
|
||||
print(f"CV Score: {scores.mean():.3f} (+/- {scores.std():.3f})")
|
||||
```
|
||||
|
||||
**Imbalanced classes** (use stratification):
|
||||
```python
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
|
||||
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
||||
scores = cross_val_score(model, X, y, cv=cv)
|
||||
```
|
||||
|
||||
**Time series** (respect temporal order):
|
||||
```python
|
||||
from sklearn.model_selection import TimeSeriesSplit
|
||||
|
||||
cv = TimeSeriesSplit(n_splits=5)
|
||||
scores = cross_val_score(model, X, y, cv=cv)
|
||||
```
|
||||
|
||||
**Multiple metrics**:
|
||||
```python
|
||||
from sklearn.model_selection import cross_validate
|
||||
|
||||
scoring = ['accuracy', 'precision_weighted', 'recall_weighted', 'f1_weighted']
|
||||
results = cross_validate(model, X, y, cv=5, scoring=scoring)
|
||||
|
||||
for metric in scoring:
|
||||
scores = results[f'test_{metric}']
|
||||
print(f"{metric}: {scores.mean():.3f}")
|
||||
```
|
||||
|
||||
## Hyperparameter Tuning
|
||||
|
||||
### Grid Search (Exhaustive)
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
param_grid = {
|
||||
'n_estimators': [100, 200, 500],
|
||||
'max_depth': [10, 20, 30, None],
|
||||
'min_samples_split': [2, 5, 10]
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_grid,
|
||||
cv=5,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1, # Use all CPU cores
|
||||
verbose=1
|
||||
)
|
||||
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
print(f"Best parameters: {grid_search.best_params_}")
|
||||
print(f"Best CV score: {grid_search.best_score_:.3f}")
|
||||
|
||||
# Use best model
|
||||
best_model = grid_search.best_estimator_
|
||||
test_score = best_model.score(X_test, y_test)
|
||||
```
|
||||
|
||||
### Random Search (Faster)
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from scipy.stats import randint, uniform
|
||||
|
||||
param_distributions = {
|
||||
'n_estimators': randint(100, 1000),
|
||||
'max_depth': randint(5, 50),
|
||||
'min_samples_split': randint(2, 20),
|
||||
'max_features': uniform(0.1, 0.9)
|
||||
}
|
||||
|
||||
random_search = RandomizedSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_distributions,
|
||||
n_iter=100, # Number of combinations to try
|
||||
cv=5,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
random_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### Pipeline Hyperparameter Tuning
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('svm', SVC())
|
||||
])
|
||||
|
||||
# Use double underscore for nested parameters
|
||||
param_grid = {
|
||||
'svm__C': [0.1, 1, 10, 100],
|
||||
'svm__kernel': ['rbf', 'linear'],
|
||||
'svm__gamma': ['scale', 'auto', 0.001, 0.01]
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(pipeline, param_grid, cv=5, n_jobs=-1)
|
||||
grid_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Feature Engineering and Selection
|
||||
|
||||
### Feature Importance
|
||||
|
||||
```python
|
||||
# Tree-based models have built-in feature importance
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
model = RandomForestClassifier(n_estimators=100)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
importances = model.feature_importances_
|
||||
feature_importance_df = pd.DataFrame({
|
||||
'feature': feature_names,
|
||||
'importance': importances
|
||||
}).sort_values('importance', ascending=False)
|
||||
|
||||
# Permutation importance (works for any model)
|
||||
from sklearn.inspection import permutation_importance
|
||||
|
||||
result = permutation_importance(
|
||||
model, X_test, y_test,
|
||||
n_repeats=10,
|
||||
random_state=42,
|
||||
n_jobs=-1
|
||||
)
|
||||
|
||||
importance_df = pd.DataFrame({
|
||||
'feature': feature_names,
|
||||
'importance': result.importances_mean,
|
||||
'std': result.importances_std
|
||||
}).sort_values('importance', ascending=False)
|
||||
```
|
||||
|
||||
### Feature Selection Methods
|
||||
|
||||
**Univariate selection**:
|
||||
```python
|
||||
from sklearn.feature_selection import SelectKBest, f_classif
|
||||
|
||||
selector = SelectKBest(f_classif, k=10)
|
||||
X_selected = selector.fit_transform(X, y)
|
||||
selected_features = selector.get_support(indices=True)
|
||||
```
|
||||
|
||||
**Recursive Feature Elimination**:
|
||||
```python
|
||||
from sklearn.feature_selection import RFECV
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
selector = RFECV(
|
||||
RandomForestClassifier(n_estimators=100),
|
||||
step=1,
|
||||
cv=5,
|
||||
n_jobs=-1
|
||||
)
|
||||
X_selected = selector.fit_transform(X, y)
|
||||
print(f"Optimal features: {selector.n_features_}")
|
||||
```
|
||||
|
||||
**Model-based selection**:
|
||||
```python
|
||||
from sklearn.feature_selection import SelectFromModel
|
||||
|
||||
selector = SelectFromModel(
|
||||
RandomForestClassifier(n_estimators=100),
|
||||
threshold='median' # or '0.5*mean', or specific value
|
||||
)
|
||||
X_selected = selector.fit_transform(X, y)
|
||||
```
|
||||
|
||||
### Polynomial Features
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
from sklearn.linear_model import Ridge
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
pipeline = Pipeline([
|
||||
('poly', PolynomialFeatures(degree=2, include_bias=False)),
|
||||
('scaler', StandardScaler()),
|
||||
('ridge', Ridge())
|
||||
])
|
||||
|
||||
pipeline.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Common Patterns and Best Practices
|
||||
|
||||
### Always Use Pipelines
|
||||
|
||||
Pipelines prevent data leakage and ensure proper workflow:
|
||||
|
||||
✅ **Correct**:
|
||||
```python
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('model', LogisticRegression())
|
||||
])
|
||||
pipeline.fit(X_train, y_train)
|
||||
y_pred = pipeline.predict(X_test)
|
||||
```
|
||||
|
||||
❌ **Wrong** (data leakage):
|
||||
```python
|
||||
scaler = StandardScaler().fit(X) # Fit on all data!
|
||||
X_train, X_test = train_test_split(scaler.transform(X))
|
||||
```
|
||||
|
||||
### Stratify for Imbalanced Classes
|
||||
|
||||
```python
|
||||
# Always use stratify for classification with imbalanced classes
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, stratify=y, random_state=42
|
||||
)
|
||||
```
|
||||
|
||||
### Scale When Necessary
|
||||
|
||||
**Scale for**: SVM, Neural Networks, KNN, Linear Models with regularization, PCA, Gradient Descent
|
||||
|
||||
**Don't scale for**: Tree-based models (Random Forest, Gradient Boosting), Naive Bayes
|
||||
|
||||
### Handle Missing Values
|
||||
|
||||
```python
|
||||
from sklearn.impute import SimpleImputer
|
||||
|
||||
# Numeric: use median (robust to outliers)
|
||||
imputer = SimpleImputer(strategy='median')
|
||||
|
||||
# Categorical: use constant value or most_frequent
|
||||
imputer = SimpleImputer(strategy='constant', fill_value='missing')
|
||||
```
|
||||
|
||||
### Use Appropriate Metrics
|
||||
|
||||
- **Balanced classification**: accuracy, F1
|
||||
- **Imbalanced classification**: balanced_accuracy, F1-weighted, ROC-AUC
|
||||
- **Regression with outliers**: MAE instead of RMSE
|
||||
- **Cost-sensitive**: custom scorer
|
||||
|
||||
### Set Random States
|
||||
|
||||
```python
|
||||
# For reproducibility
|
||||
model = RandomForestClassifier(random_state=42)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, random_state=42
|
||||
)
|
||||
```
|
||||
|
||||
### Use Parallel Processing
|
||||
|
||||
```python
|
||||
# Use all CPU cores
|
||||
model = RandomForestClassifier(n_jobs=-1)
|
||||
grid_search = GridSearchCV(model, param_grid, n_jobs=-1)
|
||||
```
|
||||
|
||||
## Unsupervised Learning
|
||||
|
||||
### Clustering Workflow
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.metrics import silhouette_score
|
||||
|
||||
# Always scale for clustering
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
|
||||
# Elbow method to find optimal k
|
||||
inertias = []
|
||||
silhouette_scores = []
|
||||
K_range = range(2, 11)
|
||||
|
||||
for k in K_range:
|
||||
kmeans = KMeans(n_clusters=k, random_state=42)
|
||||
labels = kmeans.fit_predict(X_scaled)
|
||||
inertias.append(kmeans.inertia_)
|
||||
silhouette_scores.append(silhouette_score(X_scaled, labels))
|
||||
|
||||
# Plot and choose k
|
||||
import matplotlib.pyplot as plt
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
||||
ax1.plot(K_range, inertias, 'bo-')
|
||||
ax1.set_xlabel('k')
|
||||
ax1.set_ylabel('Inertia')
|
||||
ax2.plot(K_range, silhouette_scores, 'ro-')
|
||||
ax2.set_xlabel('k')
|
||||
ax2.set_ylabel('Silhouette Score')
|
||||
plt.show()
|
||||
|
||||
# Fit final model
|
||||
optimal_k = 5 # Based on elbow/silhouette
|
||||
kmeans = KMeans(n_clusters=optimal_k, random_state=42)
|
||||
labels = kmeans.fit_predict(X_scaled)
|
||||
```
|
||||
|
||||
### Dimensionality Reduction
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
# ALWAYS scale before PCA
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
|
||||
# Specify variance to retain
|
||||
pca = PCA(n_components=0.95) # Keep 95% of variance
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
|
||||
print(f"Original features: {X.shape[1]}")
|
||||
print(f"Reduced features: {pca.n_components_}")
|
||||
print(f"Variance explained: {pca.explained_variance_ratio_.sum():.3f}")
|
||||
|
||||
# Visualize explained variance
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(np.cumsum(pca.explained_variance_ratio_))
|
||||
plt.xlabel('Number of components')
|
||||
plt.ylabel('Cumulative explained variance')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
### Visualization with t-SNE
|
||||
|
||||
```python
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
# Reduce to 50 dimensions with PCA first (faster)
|
||||
pca = PCA(n_components=min(50, X.shape[1]))
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
|
||||
# Apply t-SNE (only for visualization!)
|
||||
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
|
||||
X_tsne = tsne.fit_transform(X_pca)
|
||||
|
||||
# Visualize
|
||||
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='viridis', alpha=0.6)
|
||||
plt.colorbar()
|
||||
plt.title('t-SNE Visualization')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
## Saving and Loading Models
|
||||
|
||||
```python
|
||||
import joblib
|
||||
|
||||
# Save model or pipeline
|
||||
joblib.dump(model, 'model.pkl')
|
||||
joblib.dump(pipeline, 'pipeline.pkl')
|
||||
|
||||
# Load
|
||||
loaded_model = joblib.load('model.pkl')
|
||||
loaded_pipeline = joblib.load('pipeline.pkl')
|
||||
|
||||
# Use loaded model
|
||||
predictions = loaded_model.predict(X_new)
|
||||
```
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
This skill includes comprehensive reference files:
|
||||
|
||||
- **`references/supervised_learning.md`**: Detailed coverage of all classification and regression algorithms, parameters, use cases, and selection guidelines
|
||||
- **`references/preprocessing.md`**: Complete guide to data preprocessing including scaling, encoding, imputation, transformations, and best practices
|
||||
- **`references/model_evaluation.md`**: In-depth coverage of cross-validation strategies, metrics, hyperparameter tuning, and validation techniques
|
||||
- **`references/unsupervised_learning.md`**: Comprehensive guide to clustering, dimensionality reduction, anomaly detection, and evaluation methods
|
||||
- **`references/pipelines_and_composition.md`**: Complete guide to Pipeline, ColumnTransformer, FeatureUnion, custom transformers, and composition patterns
|
||||
- **`references/quick_reference.md`**: Quick lookup guide with code snippets, common patterns, and decision trees for algorithm selection
|
||||
|
||||
Read these files when:
|
||||
- Need detailed parameter explanations for specific algorithms
|
||||
- Comparing multiple algorithms for a task
|
||||
- Understanding evaluation metrics in depth
|
||||
- Building complex preprocessing workflows
|
||||
- Troubleshooting common issues
|
||||
|
||||
Example search patterns:
|
||||
```python
|
||||
# To find information about specific algorithms
|
||||
grep -r "GradientBoosting" references/
|
||||
|
||||
# To find preprocessing techniques
|
||||
grep -r "OneHotEncoder" references/preprocessing.md
|
||||
|
||||
# To find evaluation metrics
|
||||
grep -r "f1_score" references/model_evaluation.md
|
||||
```
|
||||
|
||||
## Common Pitfalls to Avoid
|
||||
|
||||
1. **Data leakage**: Always use pipelines, fit only on training data
|
||||
2. **Not scaling**: Scale for distance-based algorithms (SVM, KNN, Neural Networks)
|
||||
3. **Wrong metrics**: Use appropriate metrics for imbalanced data
|
||||
4. **Not using cross-validation**: Single train-test split can be misleading
|
||||
5. **Forgetting stratification**: Stratify for imbalanced classification
|
||||
6. **Using t-SNE for preprocessing**: t-SNE is for visualization only!
|
||||
7. **Not setting random_state**: Results won't be reproducible
|
||||
8. **Ignoring class imbalance**: Use stratification, appropriate metrics, or resampling
|
||||
9. **PCA without scaling**: Components will be dominated by high-variance features
|
||||
10. **Testing on training data**: Always evaluate on held-out test set
|
||||
@@ -1,601 +0,0 @@
|
||||
# Model Evaluation and Selection in scikit-learn
|
||||
|
||||
## Overview
|
||||
Model evaluation assesses how well models generalize to unseen data. Scikit-learn provides three main APIs for evaluation:
|
||||
1. **Estimator score methods**: Built-in evaluation (accuracy for classifiers, R² for regressors)
|
||||
2. **Scoring parameter**: Used in cross-validation and hyperparameter tuning
|
||||
3. **Metric functions**: Specialized evaluation in `sklearn.metrics`
|
||||
|
||||
## Cross-Validation
|
||||
|
||||
Cross-validation evaluates model performance by splitting data into multiple train/test sets. This addresses overfitting: "a model that would just repeat the labels of the samples that it has just seen would have a perfect score but would fail to predict anything useful on yet-unseen data."
|
||||
|
||||
### Basic Cross-Validation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import cross_val_score
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
model = LogisticRegression()
|
||||
scores = cross_val_score(model, X, y, cv=5, scoring='accuracy')
|
||||
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")
|
||||
```
|
||||
|
||||
### Cross-Validation Strategies
|
||||
|
||||
#### For i.i.d. Data
|
||||
|
||||
**KFold**: Standard k-fold cross-validation
|
||||
- Splits data into k equal folds
|
||||
- Each fold used once as test set
|
||||
- `n_splits`: Number of folds (typically 5 or 10)
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import KFold
|
||||
cv = KFold(n_splits=5, shuffle=True, random_state=42)
|
||||
```
|
||||
|
||||
**RepeatedKFold**: Repeats KFold with different randomization
|
||||
- More robust estimation
|
||||
- Computationally expensive
|
||||
|
||||
**LeaveOneOut (LOO)**: Each sample is a test set
|
||||
- Maximum training data usage
|
||||
- Very computationally expensive
|
||||
- High variance in estimates
|
||||
- Use only for small datasets (<1000 samples)
|
||||
|
||||
**ShuffleSplit**: Random train/test splits
|
||||
- Flexible train/test sizes
|
||||
- Can control number of iterations
|
||||
- Good for quick evaluation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42)
|
||||
```
|
||||
|
||||
#### For Imbalanced Classes
|
||||
|
||||
**StratifiedKFold**: Preserves class proportions in each fold
|
||||
- Essential for imbalanced datasets
|
||||
- Default for classification in cross_val_score()
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
||||
```
|
||||
|
||||
**StratifiedShuffleSplit**: Stratified random splits
|
||||
|
||||
#### For Grouped Data
|
||||
|
||||
Use when samples are not independent (e.g., multiple measurements from same subject).
|
||||
|
||||
**GroupKFold**: Groups don't appear in both train and test
|
||||
```python
|
||||
from sklearn.model_selection import GroupKFold
|
||||
cv = GroupKFold(n_splits=5)
|
||||
scores = cross_val_score(model, X, y, groups=groups, cv=cv)
|
||||
```
|
||||
|
||||
**StratifiedGroupKFold**: Combines stratification with group separation
|
||||
|
||||
**LeaveOneGroupOut**: Each group becomes a test set
|
||||
|
||||
#### For Time Series
|
||||
|
||||
**TimeSeriesSplit**: Expanding window approach
|
||||
- Successive training sets are supersets
|
||||
- Respects temporal ordering
|
||||
- No data leakage from future to past
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import TimeSeriesSplit
|
||||
cv = TimeSeriesSplit(n_splits=5)
|
||||
for train_idx, test_idx in cv.split(X):
|
||||
# Train on indices 0 to t, test on t+1 to t+k
|
||||
pass
|
||||
```
|
||||
|
||||
### Cross-Validation Functions
|
||||
|
||||
**cross_val_score**: Returns array of scores
|
||||
```python
|
||||
scores = cross_val_score(model, X, y, cv=5, scoring='f1_weighted')
|
||||
```
|
||||
|
||||
**cross_validate**: Returns multiple metrics and timing
|
||||
```python
|
||||
results = cross_validate(
|
||||
model, X, y, cv=5,
|
||||
scoring=['accuracy', 'f1_weighted', 'roc_auc'],
|
||||
return_train_score=True,
|
||||
return_estimator=True # Returns fitted estimators
|
||||
)
|
||||
print(results['test_accuracy'])
|
||||
print(results['fit_time'])
|
||||
```
|
||||
|
||||
**cross_val_predict**: Returns predictions for model blending/visualization
|
||||
```python
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
y_pred = cross_val_predict(model, X, y, cv=5)
|
||||
# Use for confusion matrix, error analysis, etc.
|
||||
```
|
||||
|
||||
## Hyperparameter Tuning
|
||||
|
||||
### GridSearchCV
|
||||
Exhaustively searches all parameter combinations.
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
param_grid = {
|
||||
'n_estimators': [100, 200, 500],
|
||||
'max_depth': [10, 20, 30, None],
|
||||
'min_samples_split': [2, 5, 10],
|
||||
'min_samples_leaf': [1, 2, 4]
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_grid,
|
||||
cv=5,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1, # Use all CPU cores
|
||||
verbose=2
|
||||
)
|
||||
|
||||
grid_search.fit(X_train, y_train)
|
||||
print("Best parameters:", grid_search.best_params_)
|
||||
print("Best score:", grid_search.best_score_)
|
||||
|
||||
# Use best model
|
||||
best_model = grid_search.best_estimator_
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Small parameter spaces
|
||||
- When computational resources allow
|
||||
- When exhaustive search is desired
|
||||
|
||||
### RandomizedSearchCV
|
||||
Samples parameter combinations from distributions.
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from scipy.stats import randint, uniform
|
||||
|
||||
param_distributions = {
|
||||
'n_estimators': randint(100, 1000),
|
||||
'max_depth': randint(5, 50),
|
||||
'min_samples_split': randint(2, 20),
|
||||
'min_samples_leaf': randint(1, 10),
|
||||
'max_features': uniform(0.1, 0.9)
|
||||
}
|
||||
|
||||
random_search = RandomizedSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_distributions,
|
||||
n_iter=100, # Number of parameter settings sampled
|
||||
cv=5,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
random_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Large parameter spaces
|
||||
- When budget is limited
|
||||
- Often finds good parameters faster than GridSearchCV
|
||||
|
||||
**Advantage**: "Budget can be chosen independent of the number of parameters and possible values"
|
||||
|
||||
### Successive Halving
|
||||
|
||||
**HalvingGridSearchCV** and **HalvingRandomSearchCV**: Tournament-style selection
|
||||
|
||||
**How it works**:
|
||||
1. Start with many candidates, minimal resources
|
||||
2. Eliminate poor performers
|
||||
3. Increase resources for remaining candidates
|
||||
4. Repeat until best candidates found
|
||||
|
||||
**When to use**:
|
||||
- Large parameter spaces
|
||||
- Expensive model training
|
||||
- When many parameter combinations are clearly inferior
|
||||
|
||||
```python
|
||||
from sklearn.experimental import enable_halving_search_cv
|
||||
from sklearn.model_selection import HalvingGridSearchCV
|
||||
|
||||
halving_search = HalvingGridSearchCV(
|
||||
estimator,
|
||||
param_grid,
|
||||
factor=3, # Proportion of candidates eliminated each round
|
||||
cv=5
|
||||
)
|
||||
```
|
||||
|
||||
## Classification Metrics
|
||||
|
||||
### Accuracy-Based Metrics
|
||||
|
||||
**Accuracy**: Proportion of correct predictions
|
||||
```python
|
||||
from sklearn.metrics import accuracy_score
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
```
|
||||
|
||||
**When to use**: Balanced datasets only
|
||||
**When NOT to use**: Imbalanced datasets (misleading)
|
||||
|
||||
**Balanced Accuracy**: Average recall per class
|
||||
```python
|
||||
from sklearn.metrics import balanced_accuracy_score
|
||||
bal_acc = balanced_accuracy_score(y_true, y_pred)
|
||||
```
|
||||
|
||||
**When to use**: Imbalanced datasets, ensures all classes matter equally
|
||||
|
||||
### Precision, Recall, F-Score
|
||||
|
||||
**Precision**: Of predicted positives, how many are actually positive
|
||||
- Formula: TP / (TP + FP)
|
||||
- Answers: "How reliable are positive predictions?"
|
||||
|
||||
**Recall** (Sensitivity): Of actual positives, how many are predicted positive
|
||||
- Formula: TP / (TP + FN)
|
||||
- Answers: "How complete is positive detection?"
|
||||
|
||||
**F1-Score**: Harmonic mean of precision and recall
|
||||
- Formula: 2 * (precision * recall) / (precision + recall)
|
||||
- Balanced measure when both precision and recall are important
|
||||
|
||||
```python
|
||||
from sklearn.metrics import precision_recall_fscore_support, f1_score
|
||||
|
||||
precision, recall, f1, support = precision_recall_fscore_support(
|
||||
y_true, y_pred, average='weighted'
|
||||
)
|
||||
|
||||
# Or individually
|
||||
f1 = f1_score(y_true, y_pred, average='weighted')
|
||||
```
|
||||
|
||||
**Averaging strategies for multiclass**:
|
||||
- `binary`: Binary classification only
|
||||
- `micro`: Calculate globally (total TP, FP, FN)
|
||||
- `macro`: Calculate per class, unweighted mean (all classes equal)
|
||||
- `weighted`: Calculate per class, weighted by support (class frequency)
|
||||
- `samples`: For multilabel classification
|
||||
|
||||
**When to use**:
|
||||
- `macro`: When all classes equally important (even rare ones)
|
||||
- `weighted`: When class frequency matters
|
||||
- `micro`: When overall performance across all samples matters
|
||||
|
||||
### Confusion Matrix
|
||||
|
||||
Shows true positives, false positives, true negatives, false negatives.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Class 0', 'Class 1'])
|
||||
disp.plot()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
### ROC Curve and AUC
|
||||
|
||||
**ROC (Receiver Operating Characteristic)**: Plot of true positive rate vs false positive rate at different thresholds
|
||||
|
||||
**AUC (Area Under Curve)**: Measures overall ability to discriminate between classes
|
||||
- 1.0 = perfect classifier
|
||||
- 0.5 = random classifier
|
||||
- <0.5 = worse than random
|
||||
|
||||
```python
|
||||
from sklearn.metrics import roc_auc_score, roc_curve
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Requires probability predictions
|
||||
y_proba = model.predict_proba(X_test)[:, 1] # Probabilities for positive class
|
||||
|
||||
auc = roc_auc_score(y_true, y_proba)
|
||||
fpr, tpr, thresholds = roc_curve(y_true, y_proba)
|
||||
|
||||
plt.plot(fpr, tpr, label=f'AUC = {auc:.3f}')
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Multiclass ROC**: Use `multi_class='ovr'` (one-vs-rest) or `'ovo'` (one-vs-one)
|
||||
|
||||
```python
|
||||
auc = roc_auc_score(y_true, y_proba, multi_class='ovr')
|
||||
```
|
||||
|
||||
### Log Loss
|
||||
|
||||
Measures probability calibration quality.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import log_loss
|
||||
loss = log_loss(y_true, y_proba)
|
||||
```
|
||||
|
||||
**When to use**: When probability quality matters, not just class predictions
|
||||
**Lower is better**: Perfect predictions have log loss of 0
|
||||
|
||||
### Classification Report
|
||||
|
||||
Comprehensive summary of precision, recall, f1-score per class.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
print(classification_report(y_true, y_pred, target_names=['Class 0', 'Class 1']))
|
||||
```
|
||||
|
||||
## Regression Metrics
|
||||
|
||||
### Mean Squared Error (MSE)
|
||||
Average squared difference between predictions and true values.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import mean_squared_error
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
rmse = mean_squared_error(y_true, y_pred, squared=False) # Root MSE
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- Penalizes large errors heavily (squared term)
|
||||
- Same units as target² (use RMSE for same units as target)
|
||||
- Lower is better
|
||||
|
||||
### Mean Absolute Error (MAE)
|
||||
Average absolute difference between predictions and true values.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
mae = mean_absolute_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- More robust to outliers than MSE
|
||||
- Same units as target
|
||||
- More interpretable
|
||||
- Lower is better
|
||||
|
||||
**MSE vs MAE**: Use MAE when outliers shouldn't dominate the metric
|
||||
|
||||
### R² Score (Coefficient of Determination)
|
||||
Proportion of variance explained by the model.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import r2_score
|
||||
r2 = r2_score(y_true, y_pred)
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- 1.0 = perfect predictions
|
||||
- 0.0 = model as good as mean
|
||||
- <0.0 = model worse than mean (possible!)
|
||||
- Higher is better
|
||||
|
||||
**Note**: Can be negative for models that perform worse than predicting the mean.
|
||||
|
||||
### Mean Absolute Percentage Error (MAPE)
|
||||
Percentage-based error metric.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import mean_absolute_percentage_error
|
||||
mape = mean_absolute_percentage_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
**When to use**: When relative errors matter more than absolute errors
|
||||
**Warning**: Undefined when true values are zero
|
||||
|
||||
### Median Absolute Error
|
||||
Median of absolute errors (robust to outliers).
|
||||
|
||||
```python
|
||||
from sklearn.metrics import median_absolute_error
|
||||
med_ae = median_absolute_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
### Max Error
|
||||
Maximum residual error.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import max_error
|
||||
max_err = max_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
**When to use**: When worst-case performance matters
|
||||
|
||||
## Custom Scoring Functions
|
||||
|
||||
Create custom scorers for GridSearchCV and cross_val_score:
|
||||
|
||||
```python
|
||||
from sklearn.metrics import make_scorer, fbeta_score
|
||||
|
||||
# F2 score (weights recall higher than precision)
|
||||
f2_scorer = make_scorer(fbeta_score, beta=2)
|
||||
|
||||
# Custom function
|
||||
def custom_metric(y_true, y_pred):
|
||||
# Your custom logic
|
||||
return score
|
||||
|
||||
custom_scorer = make_scorer(custom_metric, greater_is_better=True)
|
||||
|
||||
# Use in cross-validation or grid search
|
||||
scores = cross_val_score(model, X, y, cv=5, scoring=custom_scorer)
|
||||
```
|
||||
|
||||
## Scoring Parameter Options
|
||||
|
||||
Common scoring strings for `scoring` parameter:
|
||||
|
||||
**Classification**:
|
||||
- `'accuracy'`, `'balanced_accuracy'`
|
||||
- `'precision'`, `'recall'`, `'f1'` (add `_macro`, `_micro`, `_weighted` for multiclass)
|
||||
- `'roc_auc'`, `'roc_auc_ovr'`, `'roc_auc_ovo'`
|
||||
- `'log_loss'` (lower is better, negate for maximization)
|
||||
- `'jaccard'` (Jaccard similarity)
|
||||
|
||||
**Regression**:
|
||||
- `'r2'`
|
||||
- `'neg_mean_squared_error'`, `'neg_root_mean_squared_error'`
|
||||
- `'neg_mean_absolute_error'`
|
||||
- `'neg_mean_absolute_percentage_error'`
|
||||
- `'neg_median_absolute_error'`
|
||||
|
||||
**Note**: Many metrics are negated (neg_*) so GridSearchCV can maximize them.
|
||||
|
||||
## Validation Strategies
|
||||
|
||||
### Train-Test Split
|
||||
Simple single split.
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y,
|
||||
test_size=0.2,
|
||||
random_state=42,
|
||||
stratify=y # For classification with imbalanced classes
|
||||
)
|
||||
```
|
||||
|
||||
**When to use**: Large datasets, quick evaluation
|
||||
**Parameters**:
|
||||
- `test_size`: Proportion for test (typically 0.2-0.3)
|
||||
- `stratify`: Preserves class proportions
|
||||
- `random_state`: Reproducibility
|
||||
|
||||
### Train-Validation-Test Split
|
||||
Three-way split for hyperparameter tuning.
|
||||
|
||||
```python
|
||||
# First split: train+val and test
|
||||
X_trainval, X_test, y_trainval, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Second split: train and validation
|
||||
X_train, X_val, y_train, y_val = train_test_split(
|
||||
X_trainval, y_trainval, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Or use GridSearchCV with train+val, then evaluate on test
|
||||
```
|
||||
|
||||
**When to use**: Model selection and final evaluation
|
||||
**Strategy**:
|
||||
1. Train: Model training
|
||||
2. Validation: Hyperparameter tuning
|
||||
3. Test: Final, unbiased evaluation (touch only once!)
|
||||
|
||||
### Learning Curves
|
||||
|
||||
Diagnose bias vs variance issues.
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import learning_curve
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
train_sizes, train_scores, val_scores = learning_curve(
|
||||
model, X, y,
|
||||
cv=5,
|
||||
train_sizes=np.linspace(0.1, 1.0, 10),
|
||||
scoring='accuracy',
|
||||
n_jobs=-1
|
||||
)
|
||||
|
||||
plt.plot(train_sizes, train_scores.mean(axis=1), label='Training score')
|
||||
plt.plot(train_sizes, val_scores.mean(axis=1), label='Validation score')
|
||||
plt.xlabel('Training set size')
|
||||
plt.ylabel('Score')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Large gap between train and validation: **Overfitting** (high variance)
|
||||
- Both scores low: **Underfitting** (high bias)
|
||||
- Scores converging but low: Need better features or more complex model
|
||||
- Validation score still improving: More data would help
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Metric Selection Guidelines
|
||||
|
||||
**Classification - Balanced classes**:
|
||||
- Accuracy or F1-score
|
||||
|
||||
**Classification - Imbalanced classes**:
|
||||
- Balanced accuracy
|
||||
- F1-score (weighted or macro)
|
||||
- ROC-AUC
|
||||
- Precision-Recall curve
|
||||
|
||||
**Classification - Cost-sensitive**:
|
||||
- Custom scorer with cost matrix
|
||||
- Adjust threshold on probabilities
|
||||
|
||||
**Regression - Typical use**:
|
||||
- RMSE (sensitive to outliers)
|
||||
- R² (proportion of variance explained)
|
||||
|
||||
**Regression - Outliers present**:
|
||||
- MAE (robust to outliers)
|
||||
- Median absolute error
|
||||
|
||||
**Regression - Percentage errors matter**:
|
||||
- MAPE
|
||||
|
||||
### Cross-Validation Guidelines
|
||||
|
||||
**Number of folds**:
|
||||
- 5-10 folds typical
|
||||
- More folds = more computation, less variance in estimate
|
||||
- LeaveOneOut only for small datasets
|
||||
|
||||
**Stratification**:
|
||||
- Always use for classification with imbalanced classes
|
||||
- Use StratifiedKFold by default for classification
|
||||
|
||||
**Grouping**:
|
||||
- Always use when samples are not independent
|
||||
- Time series: Always use TimeSeriesSplit
|
||||
|
||||
**Nested cross-validation**:
|
||||
- For unbiased performance estimate when doing hyperparameter tuning
|
||||
- Outer loop: Performance estimation
|
||||
- Inner loop: Hyperparameter selection
|
||||
|
||||
### Avoiding Common Pitfalls
|
||||
|
||||
1. **Data leakage**: Fit preprocessors only on training data within each CV fold (use Pipeline!)
|
||||
2. **Test set leakage**: Never use test set for model selection
|
||||
3. **Improper metric**: Use metrics appropriate for problem (balanced_accuracy for imbalanced data)
|
||||
4. **Multiple testing**: More models evaluated = higher chance of random good results
|
||||
5. **Temporal leakage**: For time series, use TimeSeriesSplit, not random splits
|
||||
6. **Target leakage**: Features shouldn't contain information not available at prediction time
|
||||
@@ -1,679 +0,0 @@
|
||||
# Pipelines and Composite Estimators in scikit-learn
|
||||
|
||||
## Overview
|
||||
Pipelines chain multiple estimators into a single unit, ensuring proper workflow sequencing and preventing data leakage. As the documentation states: "Pipeline can be used to chain multiple estimators into one. This is useful as there is often a fixed sequence of steps in processing the data, for example feature selection, normalization and classification."
|
||||
|
||||
## Pipeline Basics
|
||||
|
||||
### Creating Pipelines
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
# Method 1: List of (name, estimator) tuples
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('pca', PCA(n_components=10)),
|
||||
('classifier', LogisticRegression())
|
||||
])
|
||||
|
||||
# Method 2: Using make_pipeline (auto-generates names)
|
||||
from sklearn.pipeline import make_pipeline
|
||||
pipeline = make_pipeline(
|
||||
StandardScaler(),
|
||||
PCA(n_components=10),
|
||||
LogisticRegression()
|
||||
)
|
||||
```
|
||||
|
||||
### Using Pipelines
|
||||
|
||||
```python
|
||||
# Fit and predict like any estimator
|
||||
pipeline.fit(X_train, y_train)
|
||||
y_pred = pipeline.predict(X_test)
|
||||
score = pipeline.score(X_test, y_test)
|
||||
|
||||
# Access steps
|
||||
pipeline.named_steps['scaler']
|
||||
pipeline.steps[0] # Returns ('scaler', StandardScaler(...))
|
||||
pipeline[0] # Returns StandardScaler(...) object
|
||||
pipeline['scaler'] # Returns StandardScaler(...) object
|
||||
|
||||
# Get final estimator
|
||||
pipeline[-1] # Returns LogisticRegression(...) object
|
||||
```
|
||||
|
||||
### Pipeline Rules
|
||||
|
||||
**All steps except the last must be transformers** (have `fit()` and `transform()` methods).
|
||||
|
||||
**The final step** can be:
|
||||
- Predictor (classifier/regressor) with `fit()` and `predict()`
|
||||
- Transformer with `fit()` and `transform()`
|
||||
- Any estimator with at least `fit()`
|
||||
|
||||
### Pipeline Benefits
|
||||
|
||||
1. **Convenience**: Single `fit()` and `predict()` call
|
||||
2. **Prevents data leakage**: Ensures proper fit/transform on train/test
|
||||
3. **Joint parameter selection**: Tune all steps together with GridSearchCV
|
||||
4. **Reproducibility**: Encapsulates entire workflow
|
||||
|
||||
## Accessing and Setting Parameters
|
||||
|
||||
### Nested Parameters
|
||||
|
||||
Access step parameters using `stepname__parameter` syntax:
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('clf', LogisticRegression())
|
||||
])
|
||||
|
||||
# Grid search over pipeline parameters
|
||||
param_grid = {
|
||||
'scaler__with_mean': [True, False],
|
||||
'clf__C': [0.1, 1.0, 10.0],
|
||||
'clf__penalty': ['l1', 'l2']
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(pipeline, param_grid, cv=5)
|
||||
grid_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### Setting Parameters
|
||||
|
||||
```python
|
||||
# Set parameters
|
||||
pipeline.set_params(clf__C=10.0, scaler__with_std=False)
|
||||
|
||||
# Get parameters
|
||||
params = pipeline.get_params()
|
||||
```
|
||||
|
||||
## Caching Intermediate Results
|
||||
|
||||
Cache fitted transformers to avoid recomputation:
|
||||
|
||||
```python
|
||||
from tempfile import mkdtemp
|
||||
from shutil import rmtree
|
||||
|
||||
# Create cache directory
|
||||
cachedir = mkdtemp()
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('pca', PCA(n_components=10)),
|
||||
('clf', LogisticRegression())
|
||||
], memory=cachedir)
|
||||
|
||||
# When doing grid search, scaler and PCA only fit once per fold
|
||||
grid_search = GridSearchCV(pipeline, param_grid, cv=5)
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
# Clean up cache
|
||||
rmtree(cachedir)
|
||||
|
||||
# Or use joblib for persistent caching
|
||||
from joblib import Memory
|
||||
memory = Memory(location='./cache', verbose=0)
|
||||
pipeline = Pipeline([...], memory=memory)
|
||||
```
|
||||
|
||||
**When to use caching**:
|
||||
- Expensive transformations (PCA, feature selection)
|
||||
- Grid search over final estimator parameters only
|
||||
- Multiple experiments with same preprocessing
|
||||
|
||||
## ColumnTransformer
|
||||
|
||||
Apply different transformations to different columns (essential for heterogeneous data).
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
|
||||
# Define which transformations for which columns
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', StandardScaler(), ['age', 'income', 'credit_score']),
|
||||
('cat', OneHotEncoder(), ['country', 'occupation'])
|
||||
],
|
||||
remainder='drop' # What to do with remaining columns
|
||||
)
|
||||
|
||||
X_transformed = preprocessor.fit_transform(X)
|
||||
```
|
||||
|
||||
### Column Selection Methods
|
||||
|
||||
```python
|
||||
# Method 1: Column names (list of strings)
|
||||
('num', StandardScaler(), ['age', 'income'])
|
||||
|
||||
# Method 2: Column indices (list of integers)
|
||||
('num', StandardScaler(), [0, 1, 2])
|
||||
|
||||
# Method 3: Boolean mask
|
||||
('num', StandardScaler(), [True, True, False, True, False])
|
||||
|
||||
# Method 4: Slice
|
||||
('num', StandardScaler(), slice(0, 3))
|
||||
|
||||
# Method 5: make_column_selector (by dtype or pattern)
|
||||
from sklearn.compose import make_column_selector as selector
|
||||
|
||||
preprocessor = ColumnTransformer([
|
||||
('num', StandardScaler(), selector(dtype_include='number')),
|
||||
('cat', OneHotEncoder(), selector(dtype_include='object'))
|
||||
])
|
||||
|
||||
# Select by pattern
|
||||
selector(pattern='.*_score$') # All columns ending with '_score'
|
||||
```
|
||||
|
||||
### Remainder Parameter
|
||||
|
||||
Controls what happens to columns not specified:
|
||||
|
||||
```python
|
||||
# Drop remaining columns (default)
|
||||
remainder='drop'
|
||||
|
||||
# Pass through remaining columns unchanged
|
||||
remainder='passthrough'
|
||||
|
||||
# Apply transformer to remaining columns
|
||||
remainder=StandardScaler()
|
||||
```
|
||||
|
||||
### Full Pipeline with ColumnTransformer
|
||||
|
||||
```python
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
# Separate preprocessing for numeric and categorical
|
||||
numeric_features = ['age', 'income', 'credit_score']
|
||||
categorical_features = ['country', 'occupation', 'education']
|
||||
|
||||
numeric_transformer = Pipeline(steps=[
|
||||
('imputer', SimpleImputer(strategy='median')),
|
||||
('scaler', StandardScaler())
|
||||
])
|
||||
|
||||
categorical_transformer = Pipeline(steps=[
|
||||
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore'))
|
||||
])
|
||||
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', numeric_transformer, numeric_features),
|
||||
('cat', categorical_transformer, categorical_features)
|
||||
])
|
||||
|
||||
# Complete pipeline
|
||||
clf = Pipeline(steps=[
|
||||
('preprocessor', preprocessor),
|
||||
('classifier', RandomForestClassifier())
|
||||
])
|
||||
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
# Grid search over preprocessing and model parameters
|
||||
param_grid = {
|
||||
'preprocessor__num__imputer__strategy': ['mean', 'median'],
|
||||
'preprocessor__cat__onehot__max_categories': [10, 20, None],
|
||||
'classifier__n_estimators': [100, 200],
|
||||
'classifier__max_depth': [10, 20, None]
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(clf, param_grid, cv=5)
|
||||
grid_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## FeatureUnion
|
||||
|
||||
Combine multiple transformer outputs by concatenating features side-by-side.
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import FeatureUnion
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.feature_selection import SelectKBest
|
||||
|
||||
# Combine PCA and feature selection
|
||||
combined_features = FeatureUnion([
|
||||
('pca', PCA(n_components=10)),
|
||||
('univ_select', SelectKBest(k=5))
|
||||
])
|
||||
|
||||
X_features = combined_features.fit_transform(X, y)
|
||||
# Result: 15 features (10 from PCA + 5 from SelectKBest)
|
||||
|
||||
# In a pipeline
|
||||
pipeline = Pipeline([
|
||||
('features', combined_features),
|
||||
('classifier', LogisticRegression())
|
||||
])
|
||||
```
|
||||
|
||||
### FeatureUnion with Transformers on Different Data
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import FeatureUnion
|
||||
from sklearn.preprocessing import FunctionTransformer
|
||||
import numpy as np
|
||||
|
||||
def get_numeric_data(X):
|
||||
return X[:, :3] # First 3 columns
|
||||
|
||||
def get_text_data(X):
|
||||
return X[:, 3] # 4th column (text)
|
||||
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
combined = FeatureUnion([
|
||||
('numeric_features', Pipeline([
|
||||
('selector', FunctionTransformer(get_numeric_data)),
|
||||
('scaler', StandardScaler())
|
||||
])),
|
||||
('text_features', Pipeline([
|
||||
('selector', FunctionTransformer(get_text_data)),
|
||||
('tfidf', TfidfVectorizer())
|
||||
]))
|
||||
])
|
||||
```
|
||||
|
||||
**Note**: ColumnTransformer is usually more convenient than FeatureUnion for heterogeneous data.
|
||||
|
||||
## Common Pipeline Patterns
|
||||
|
||||
### Classification Pipeline
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.feature_selection import SelectKBest, f_classif
|
||||
from sklearn.svm import SVC
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('feature_selection', SelectKBest(f_classif, k=10)),
|
||||
('classifier', SVC(kernel='rbf'))
|
||||
])
|
||||
```
|
||||
|
||||
### Regression Pipeline
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
|
||||
from sklearn.linear_model import Ridge
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('poly', PolynomialFeatures(degree=2)),
|
||||
('ridge', Ridge(alpha=1.0))
|
||||
])
|
||||
```
|
||||
|
||||
### Text Classification Pipeline
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
|
||||
pipeline = Pipeline([
|
||||
('tfidf', TfidfVectorizer(max_features=1000)),
|
||||
('classifier', MultinomialNB())
|
||||
])
|
||||
|
||||
# Works directly with text
|
||||
pipeline.fit(X_train_text, y_train)
|
||||
y_pred = pipeline.predict(X_test_text)
|
||||
```
|
||||
|
||||
### Image Processing Pipeline
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('pca', PCA(n_components=100)),
|
||||
('mlp', MLPClassifier(hidden_layer_sizes=(100, 50)))
|
||||
])
|
||||
```
|
||||
|
||||
### Dimensionality Reduction + Clustering
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('pca', PCA(n_components=10)),
|
||||
('kmeans', KMeans(n_clusters=5))
|
||||
])
|
||||
|
||||
labels = pipeline.fit_predict(X)
|
||||
```
|
||||
|
||||
## Custom Transformers
|
||||
|
||||
### Using FunctionTransformer
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import FunctionTransformer
|
||||
import numpy as np
|
||||
|
||||
# Log transformation
|
||||
log_transformer = FunctionTransformer(np.log1p)
|
||||
|
||||
# Custom function
|
||||
def custom_transform(X):
|
||||
# Your transformation logic
|
||||
return X_transformed
|
||||
|
||||
custom_transformer = FunctionTransformer(custom_transform)
|
||||
|
||||
# In pipeline
|
||||
pipeline = Pipeline([
|
||||
('log', log_transformer),
|
||||
('scaler', StandardScaler()),
|
||||
('model', LinearRegression())
|
||||
])
|
||||
```
|
||||
|
||||
### Creating Custom Transformer Class
|
||||
|
||||
```python
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
|
||||
class CustomTransformer(BaseEstimator, TransformerMixin):
|
||||
def __init__(self, parameter=1.0):
|
||||
self.parameter = parameter
|
||||
|
||||
def fit(self, X, y=None):
|
||||
# Learn parameters from X
|
||||
self.learned_param_ = X.mean() # Example
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
# Transform X using learned parameters
|
||||
return X * self.parameter - self.learned_param_
|
||||
|
||||
# Optional: for pipelines that need inverse transform
|
||||
def inverse_transform(self, X):
|
||||
return (X + self.learned_param_) / self.parameter
|
||||
|
||||
# Use in pipeline
|
||||
pipeline = Pipeline([
|
||||
('custom', CustomTransformer(parameter=2.0)),
|
||||
('model', LinearRegression())
|
||||
])
|
||||
```
|
||||
|
||||
**Key requirements**:
|
||||
- Inherit from `BaseEstimator` and `TransformerMixin`
|
||||
- Implement `fit()` and `transform()` methods
|
||||
- `fit()` must return `self`
|
||||
- Use trailing underscore for learned attributes (`learned_param_`)
|
||||
- Constructor parameters should be stored as attributes
|
||||
|
||||
### Transformer for Pandas DataFrames
|
||||
|
||||
```python
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
import pandas as pd
|
||||
|
||||
class DataFrameTransformer(BaseEstimator, TransformerMixin):
|
||||
def __init__(self, columns=None):
|
||||
self.columns = columns
|
||||
|
||||
def fit(self, X, y=None):
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
if isinstance(X, pd.DataFrame):
|
||||
if self.columns:
|
||||
return X[self.columns].values
|
||||
return X.values
|
||||
return X
|
||||
```
|
||||
|
||||
## Visualization
|
||||
|
||||
### Display Pipeline in Jupyter
|
||||
|
||||
```python
|
||||
from sklearn import set_config
|
||||
|
||||
# Enable HTML display
|
||||
set_config(display='diagram')
|
||||
|
||||
# Now displaying the pipeline shows interactive diagram
|
||||
pipeline
|
||||
```
|
||||
|
||||
### Print Pipeline Structure
|
||||
|
||||
```python
|
||||
from sklearn.utils import estimator_html_repr
|
||||
|
||||
# Get HTML representation
|
||||
html = estimator_html_repr(pipeline)
|
||||
|
||||
# Or just print
|
||||
print(pipeline)
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Conditional Transformations
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler, FunctionTransformer
|
||||
|
||||
def conditional_scale(X, scale=True):
|
||||
if scale:
|
||||
return StandardScaler().fit_transform(X)
|
||||
return X
|
||||
|
||||
pipeline = Pipeline([
|
||||
('conditional_scaler', FunctionTransformer(
|
||||
conditional_scale,
|
||||
kw_args={'scale': True}
|
||||
)),
|
||||
('model', LogisticRegression())
|
||||
])
|
||||
```
|
||||
|
||||
### Multiple Preprocessing Paths
|
||||
|
||||
```python
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
# Different preprocessing for different feature types
|
||||
preprocessor = ColumnTransformer([
|
||||
# Numeric: impute + scale
|
||||
('num_standard', Pipeline([
|
||||
('imputer', SimpleImputer(strategy='mean')),
|
||||
('scaler', StandardScaler())
|
||||
]), ['age', 'income']),
|
||||
|
||||
# Numeric: impute + log + scale
|
||||
('num_skewed', Pipeline([
|
||||
('imputer', SimpleImputer(strategy='median')),
|
||||
('log', FunctionTransformer(np.log1p)),
|
||||
('scaler', StandardScaler())
|
||||
]), ['price', 'revenue']),
|
||||
|
||||
# Categorical: impute + one-hot
|
||||
('cat', Pipeline([
|
||||
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore'))
|
||||
]), ['category', 'region']),
|
||||
|
||||
# Text: TF-IDF
|
||||
('text', TfidfVectorizer(), 'description')
|
||||
])
|
||||
```
|
||||
|
||||
### Feature Engineering Pipeline
|
||||
|
||||
```python
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
|
||||
class FeatureEngineer(BaseEstimator, TransformerMixin):
|
||||
def fit(self, X, y=None):
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
X = X.copy()
|
||||
# Add engineered features
|
||||
X['age_income_ratio'] = X['age'] / (X['income'] + 1)
|
||||
X['total_score'] = X['score1'] + X['score2'] + X['score3']
|
||||
return X
|
||||
|
||||
pipeline = Pipeline([
|
||||
('engineer', FeatureEngineer()),
|
||||
('preprocessor', preprocessor),
|
||||
('model', RandomForestClassifier())
|
||||
])
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Always Use Pipelines When
|
||||
|
||||
1. **Preprocessing is needed**: Scaling, encoding, imputation
|
||||
2. **Cross-validation**: Ensures proper fit/transform split
|
||||
3. **Hyperparameter tuning**: Joint optimization of preprocessing and model
|
||||
4. **Production deployment**: Single object to serialize
|
||||
5. **Multiple steps**: Any workflow with >1 step
|
||||
|
||||
### Pipeline Do's
|
||||
|
||||
- ✅ Fit pipeline only on training data
|
||||
- ✅ Use ColumnTransformer for heterogeneous data
|
||||
- ✅ Cache expensive transformations during grid search
|
||||
- ✅ Use make_pipeline for simple cases
|
||||
- ✅ Set verbose=True to debug issues
|
||||
- ✅ Use remainder='passthrough' when appropriate
|
||||
|
||||
### Pipeline Don'ts
|
||||
|
||||
- ❌ Fit preprocessing on full dataset before split (data leakage!)
|
||||
- ❌ Manually transform test data (use pipeline.predict())
|
||||
- ❌ Forget to handle missing values before scaling
|
||||
- ❌ Mix pandas DataFrames and arrays inconsistently
|
||||
- ❌ Skip using pipelines for "just one preprocessing step"
|
||||
|
||||
### Data Leakage Prevention
|
||||
|
||||
```python
|
||||
# ❌ WRONG - Data leakage
|
||||
scaler = StandardScaler().fit(X) # Fit on all data
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
X_train_scaled = scaler.transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# ✅ CORRECT - No leakage with pipeline
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('model', LogisticRegression())
|
||||
])
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
pipeline.fit(X_train, y_train) # Scaler fits only on train
|
||||
y_pred = pipeline.predict(X_test) # Scaler transforms only on test
|
||||
|
||||
# ✅ CORRECT - No leakage in cross-validation
|
||||
scores = cross_val_score(pipeline, X, y, cv=5)
|
||||
# Each fold: scaler fits on train folds, transforms on test fold
|
||||
```
|
||||
|
||||
### Debugging Pipelines
|
||||
|
||||
```python
|
||||
# Examine intermediate outputs
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('pca', PCA(n_components=10)),
|
||||
('model', LogisticRegression())
|
||||
])
|
||||
|
||||
# Fit pipeline
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# Get output after scaling
|
||||
X_scaled = pipeline.named_steps['scaler'].transform(X_train)
|
||||
|
||||
# Get output after PCA
|
||||
X_pca = pipeline[:-1].transform(X_train) # All steps except last
|
||||
|
||||
# Or build partial pipeline
|
||||
partial_pipeline = Pipeline(pipeline.steps[:-1])
|
||||
X_transformed = partial_pipeline.transform(X_train)
|
||||
```
|
||||
|
||||
### Saving and Loading Pipelines
|
||||
|
||||
```python
|
||||
import joblib
|
||||
|
||||
# Save pipeline
|
||||
joblib.dump(pipeline, 'model_pipeline.pkl')
|
||||
|
||||
# Load pipeline
|
||||
pipeline = joblib.load('model_pipeline.pkl')
|
||||
|
||||
# Use loaded pipeline
|
||||
y_pred = pipeline.predict(X_new)
|
||||
```
|
||||
|
||||
## Common Errors and Solutions
|
||||
|
||||
**Error**: `ValueError: could not convert string to float`
|
||||
- **Cause**: Categorical features not encoded
|
||||
- **Solution**: Add OneHotEncoder or OrdinalEncoder to pipeline
|
||||
|
||||
**Error**: `All intermediate steps should be transformers`
|
||||
- **Cause**: Non-transformer in non-final position
|
||||
- **Solution**: Ensure only last step is predictor
|
||||
|
||||
**Error**: `X has different number of features than during fitting`
|
||||
- **Cause**: Different columns in train and test
|
||||
- **Solution**: Ensure consistent column handling, use `handle_unknown='ignore'` in OneHotEncoder
|
||||
|
||||
**Error**: Different results in cross-validation vs train-test split
|
||||
- **Cause**: Data leakage (fitting preprocessing on all data)
|
||||
- **Solution**: Always use Pipeline for preprocessing
|
||||
|
||||
**Error**: Pipeline too slow during grid search
|
||||
- **Solution**: Use caching with `memory` parameter
|
||||
@@ -1,413 +0,0 @@
|
||||
# Data Preprocessing in scikit-learn
|
||||
|
||||
## Overview
|
||||
Preprocessing transforms raw data into a format suitable for machine learning algorithms. Many algorithms require standardized or normalized data to perform well.
|
||||
|
||||
## Standardization and Scaling
|
||||
|
||||
### StandardScaler
|
||||
Removes mean and scales to unit variance (z-score normalization).
|
||||
|
||||
**Formula**: `z = (x - μ) / σ`
|
||||
|
||||
**Use cases**:
|
||||
- Most ML algorithms (especially SVM, neural networks, PCA)
|
||||
- When features have different units or scales
|
||||
- When assuming Gaussian-like distribution
|
||||
|
||||
**Important**: Fit only on training data, then transform both train and test sets.
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test) # Use same parameters
|
||||
```
|
||||
|
||||
### MinMaxScaler
|
||||
Scales features to a specified range, typically [0, 1].
|
||||
|
||||
**Formula**: `X_scaled = (X - X_min) / (X_max - X_min)`
|
||||
|
||||
**Use cases**:
|
||||
- When bounded range is needed
|
||||
- Neural networks (often prefer [0, 1] range)
|
||||
- When distribution is not Gaussian
|
||||
- Image pixel values
|
||||
|
||||
**Parameters**:
|
||||
- `feature_range`: Tuple (min, max), default (0, 1)
|
||||
|
||||
**Warning**: Sensitive to outliers since it uses min/max.
|
||||
|
||||
### MaxAbsScaler
|
||||
Scales to [-1, 1] by dividing by maximum absolute value.
|
||||
|
||||
**Use cases**:
|
||||
- Sparse data (preserves sparsity)
|
||||
- Data already centered at zero
|
||||
- When sign of values is meaningful
|
||||
|
||||
**Advantage**: Doesn't shift/center the data, preserves zero entries.
|
||||
|
||||
### RobustScaler
|
||||
Uses median and interquartile range (IQR) instead of mean and standard deviation.
|
||||
|
||||
**Formula**: `X_scaled = (X - median) / IQR`
|
||||
|
||||
**Use cases**:
|
||||
- When outliers are present
|
||||
- When StandardScaler produces skewed results
|
||||
- Robust statistics preferred
|
||||
|
||||
**Parameters**:
|
||||
- `quantile_range`: Tuple (q_min, q_max), default (25.0, 75.0)
|
||||
|
||||
## Normalization
|
||||
|
||||
### normalize() function and Normalizer
|
||||
Scales individual samples (rows) to unit norm, not features (columns).
|
||||
|
||||
**Use cases**:
|
||||
- Text classification (TF-IDF vectors)
|
||||
- When similarity metrics (dot product, cosine) are used
|
||||
- When each sample should have equal weight
|
||||
|
||||
**Norms**:
|
||||
- `l1`: Manhattan norm (sum of absolutes = 1)
|
||||
- `l2`: Euclidean norm (sum of squares = 1) - **most common**
|
||||
- `max`: Maximum absolute value = 1
|
||||
|
||||
**Key difference from scalers**: Operates on rows (samples), not columns (features).
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import Normalizer
|
||||
normalizer = Normalizer(norm='l2')
|
||||
X_normalized = normalizer.transform(X)
|
||||
```
|
||||
|
||||
## Encoding Categorical Features
|
||||
|
||||
### OrdinalEncoder
|
||||
Converts categories to integers (0 to n_categories - 1).
|
||||
|
||||
**Use cases**:
|
||||
- Ordinal relationships exist (small < medium < large)
|
||||
- Preprocessing before other transformations
|
||||
- Tree-based algorithms (which can handle integers)
|
||||
|
||||
**Parameters**:
|
||||
- `handle_unknown`: 'error' or 'use_encoded_value'
|
||||
- `unknown_value`: Value for unknown categories
|
||||
- `encoded_missing_value`: Value for missing data
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import OrdinalEncoder
|
||||
encoder = OrdinalEncoder()
|
||||
X_encoded = encoder.fit_transform(X_categorical)
|
||||
```
|
||||
|
||||
### OneHotEncoder
|
||||
Creates binary columns for each category.
|
||||
|
||||
**Use cases**:
|
||||
- Nominal categories (no order)
|
||||
- Linear models, neural networks
|
||||
- When category relationships shouldn't be assumed
|
||||
|
||||
**Parameters**:
|
||||
- `drop`: 'first', 'if_binary', array-like (prevents multicollinearity)
|
||||
- `sparse_output`: True (default, memory efficient) or False
|
||||
- `handle_unknown`: 'error', 'ignore', 'infrequent_if_exist'
|
||||
- `min_frequency`: Group infrequent categories
|
||||
- `max_categories`: Limit number of categories
|
||||
|
||||
**High cardinality handling**:
|
||||
```python
|
||||
encoder = OneHotEncoder(min_frequency=100, handle_unknown='infrequent_if_exist')
|
||||
# Groups categories appearing < 100 times into 'infrequent' category
|
||||
```
|
||||
|
||||
**Memory tip**: Use `sparse_output=True` (default) for high-cardinality features.
|
||||
|
||||
### TargetEncoder
|
||||
Uses target statistics to encode categories.
|
||||
|
||||
**Use cases**:
|
||||
- High-cardinality categorical features (zip codes, user IDs)
|
||||
- When linear relationships with target are expected
|
||||
- Often improves performance over one-hot encoding
|
||||
|
||||
**How it works**:
|
||||
- Replaces category with mean of target for that category
|
||||
- Uses cross-fitting during fit_transform() to prevent target leakage
|
||||
- Applies smoothing to handle rare categories
|
||||
|
||||
**Parameters**:
|
||||
- `smooth`: Smoothing parameter for rare categories
|
||||
- `cv`: Cross-validation strategy
|
||||
|
||||
**Warning**: Only for supervised learning. Requires target variable.
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import TargetEncoder
|
||||
encoder = TargetEncoder()
|
||||
X_encoded = encoder.fit_transform(X_categorical, y)
|
||||
```
|
||||
|
||||
### LabelEncoder
|
||||
Encodes target labels into integers 0 to n_classes - 1.
|
||||
|
||||
**Use cases**: Encoding target variable for classification (not features!)
|
||||
|
||||
**Important**: Use `LabelEncoder` for targets, not features. For features, use OrdinalEncoder or OneHotEncoder.
|
||||
|
||||
### Binarizer
|
||||
Converts numeric values to binary (0 or 1) based on threshold.
|
||||
|
||||
**Use cases**: Creating binary features from continuous values
|
||||
|
||||
## Non-linear Transformations
|
||||
|
||||
### QuantileTransformer
|
||||
Maps features to uniform or normal distribution using rank transformation.
|
||||
|
||||
**Use cases**:
|
||||
- Unusual distributions (bimodal, heavy tails)
|
||||
- Reducing outlier impact
|
||||
- When normal distribution is desired
|
||||
|
||||
**Parameters**:
|
||||
- `output_distribution`: 'uniform' (default) or 'normal'
|
||||
- `n_quantiles`: Number of quantiles (default: min(1000, n_samples))
|
||||
|
||||
**Effect**: Strong transformation that reduces outlier influence and makes data more Gaussian-like.
|
||||
|
||||
### PowerTransformer
|
||||
Applies parametric monotonic transformation to make data more Gaussian.
|
||||
|
||||
**Methods**:
|
||||
- `yeo-johnson`: Works with positive and negative values (default)
|
||||
- `box-cox`: Only positive values
|
||||
|
||||
**Use cases**:
|
||||
- Skewed distributions
|
||||
- When Gaussian assumption is important
|
||||
- Variance stabilization
|
||||
|
||||
**Advantage**: Less radical than QuantileTransformer, preserves more of original relationships.
|
||||
|
||||
## Discretization
|
||||
|
||||
### KBinsDiscretizer
|
||||
Bins continuous features into discrete intervals.
|
||||
|
||||
**Strategies**:
|
||||
- `uniform`: Equal-width bins
|
||||
- `quantile`: Equal-frequency bins
|
||||
- `kmeans`: K-means clustering to determine bins
|
||||
|
||||
**Encoding**:
|
||||
- `ordinal`: Integer encoding (0 to n_bins - 1)
|
||||
- `onehot`: One-hot encoding
|
||||
- `onehot-dense`: Dense one-hot encoding
|
||||
|
||||
**Use cases**:
|
||||
- Making linear models handle non-linear relationships
|
||||
- Reducing noise in features
|
||||
- Making features more interpretable
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import KBinsDiscretizer
|
||||
disc = KBinsDiscretizer(n_bins=5, encode='onehot', strategy='quantile')
|
||||
X_binned = disc.fit_transform(X)
|
||||
```
|
||||
|
||||
## Feature Generation
|
||||
|
||||
### PolynomialFeatures
|
||||
Generates polynomial and interaction features.
|
||||
|
||||
**Parameters**:
|
||||
- `degree`: Polynomial degree
|
||||
- `interaction_only`: Only multiplicative interactions (no x²)
|
||||
- `include_bias`: Include constant feature
|
||||
|
||||
**Use cases**:
|
||||
- Adding non-linearity to linear models
|
||||
- Feature engineering
|
||||
- Polynomial regression
|
||||
|
||||
**Warning**: Number of features grows rapidly: (n+d)!/d!n! for degree d.
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
poly = PolynomialFeatures(degree=2, include_bias=False)
|
||||
X_poly = poly.fit_transform(X)
|
||||
# [x1, x2] → [x1, x2, x1², x1·x2, x2²]
|
||||
```
|
||||
|
||||
### SplineTransformer
|
||||
Generates B-spline basis functions.
|
||||
|
||||
**Use cases**:
|
||||
- Smooth non-linear transformations
|
||||
- Alternative to PolynomialFeatures (less oscillation at boundaries)
|
||||
- Generalized additive models (GAMs)
|
||||
|
||||
**Parameters**:
|
||||
- `n_knots`: Number of knots
|
||||
- `degree`: Spline degree
|
||||
- `knots`: Knot positions ('uniform', 'quantile', or array)
|
||||
|
||||
## Missing Value Handling
|
||||
|
||||
### SimpleImputer
|
||||
Imputes missing values with various strategies.
|
||||
|
||||
**Strategies**:
|
||||
- `mean`: Mean of column (numeric only)
|
||||
- `median`: Median of column (numeric only)
|
||||
- `most_frequent`: Mode (numeric or categorical)
|
||||
- `constant`: Fill with constant value
|
||||
|
||||
**Parameters**:
|
||||
- `strategy`: Imputation strategy
|
||||
- `fill_value`: Value when strategy='constant'
|
||||
- `missing_values`: What represents missing (np.nan, None, specific value)
|
||||
|
||||
```python
|
||||
from sklearn.impute import SimpleImputer
|
||||
imputer = SimpleImputer(strategy='median')
|
||||
X_imputed = imputer.fit_transform(X)
|
||||
```
|
||||
|
||||
### KNNImputer
|
||||
Imputes using k-nearest neighbors.
|
||||
|
||||
**Use cases**: When relationships between features should inform imputation
|
||||
|
||||
**Parameters**:
|
||||
- `n_neighbors`: Number of neighbors
|
||||
- `weights`: 'uniform' or 'distance'
|
||||
|
||||
### IterativeImputer
|
||||
Models each feature with missing values as function of other features.
|
||||
|
||||
**Use cases**:
|
||||
- Complex relationships between features
|
||||
- When multiple features have missing values
|
||||
- Higher quality imputation (but slower)
|
||||
|
||||
**Parameters**:
|
||||
- `estimator`: Estimator for regression (default: BayesianRidge)
|
||||
- `max_iter`: Maximum iterations
|
||||
|
||||
## Function Transformers
|
||||
|
||||
### FunctionTransformer
|
||||
Applies custom function to data.
|
||||
|
||||
**Use cases**:
|
||||
- Custom transformations in pipelines
|
||||
- Log transformation, square root, etc.
|
||||
- Domain-specific preprocessing
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import FunctionTransformer
|
||||
import numpy as np
|
||||
|
||||
log_transformer = FunctionTransformer(np.log1p, validate=True)
|
||||
X_log = log_transformer.transform(X)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Feature Scaling Guidelines
|
||||
|
||||
**Always scale**:
|
||||
- SVM, neural networks
|
||||
- K-nearest neighbors
|
||||
- Linear/Logistic regression with regularization
|
||||
- PCA, LDA
|
||||
- Gradient descent-based algorithms
|
||||
|
||||
**Don't need to scale**:
|
||||
- Tree-based algorithms (Decision Trees, Random Forests, Gradient Boosting)
|
||||
- Naive Bayes
|
||||
|
||||
### Pipeline Integration
|
||||
|
||||
Always use preprocessing within pipelines to prevent data leakage:
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('classifier', LogisticRegression())
|
||||
])
|
||||
|
||||
pipeline.fit(X_train, y_train) # Scaler fit only on train data
|
||||
y_pred = pipeline.predict(X_test) # Scaler transform only on test data
|
||||
```
|
||||
|
||||
### Common Transformations by Data Type
|
||||
|
||||
**Numeric - Continuous**:
|
||||
- StandardScaler (most common)
|
||||
- MinMaxScaler (neural networks)
|
||||
- RobustScaler (outliers present)
|
||||
- PowerTransformer (skewed data)
|
||||
|
||||
**Numeric - Count Data**:
|
||||
- sqrt or log transformation
|
||||
- QuantileTransformer
|
||||
- StandardScaler after transformation
|
||||
|
||||
**Categorical - Low Cardinality (<10 categories)**:
|
||||
- OneHotEncoder
|
||||
|
||||
**Categorical - High Cardinality (>10 categories)**:
|
||||
- TargetEncoder (supervised)
|
||||
- Frequency encoding
|
||||
- OneHotEncoder with min_frequency parameter
|
||||
|
||||
**Categorical - Ordinal**:
|
||||
- OrdinalEncoder
|
||||
|
||||
**Text**:
|
||||
- CountVectorizer or TfidfVectorizer
|
||||
- Normalizer after vectorization
|
||||
|
||||
### Data Leakage Prevention
|
||||
|
||||
1. **Fit only on training data**: Never include test data when fitting preprocessors
|
||||
2. **Use pipelines**: Ensures proper fit/transform separation
|
||||
3. **Cross-validation**: Use Pipeline with cross_val_score() for proper evaluation
|
||||
4. **Target encoding**: Use cv parameter in TargetEncoder for cross-fitting
|
||||
|
||||
```python
|
||||
# WRONG - data leakage
|
||||
scaler = StandardScaler().fit(X_full)
|
||||
X_train_scaled = scaler.transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# CORRECT - no leakage
|
||||
scaler = StandardScaler().fit(X_train)
|
||||
X_train_scaled = scaler.transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
```
|
||||
|
||||
## Preprocessing Checklist
|
||||
|
||||
Before modeling:
|
||||
1. Handle missing values (imputation or removal)
|
||||
2. Encode categorical variables appropriately
|
||||
3. Scale/normalize numeric features (if needed for algorithm)
|
||||
4. Handle outliers (RobustScaler, clipping, removal)
|
||||
5. Create additional features if beneficial (PolynomialFeatures, domain knowledge)
|
||||
6. Check for data leakage in preprocessing steps
|
||||
7. Wrap everything in a Pipeline
|
||||
@@ -1,625 +0,0 @@
|
||||
# Scikit-learn Quick Reference
|
||||
|
||||
## Essential Imports
|
||||
|
||||
```python
|
||||
# Core
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
|
||||
from sklearn.pipeline import Pipeline, make_pipeline
|
||||
from sklearn.compose import ColumnTransformer
|
||||
|
||||
# Preprocessing
|
||||
from sklearn.preprocessing import (
|
||||
StandardScaler, MinMaxScaler, RobustScaler,
|
||||
OneHotEncoder, OrdinalEncoder, LabelEncoder,
|
||||
PolynomialFeatures
|
||||
)
|
||||
from sklearn.impute import SimpleImputer
|
||||
|
||||
# Models - Classification
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.ensemble import (
|
||||
RandomForestClassifier,
|
||||
GradientBoostingClassifier,
|
||||
HistGradientBoostingClassifier
|
||||
)
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
|
||||
# Models - Regression
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
from sklearn.ensemble import (
|
||||
RandomForestRegressor,
|
||||
GradientBoostingRegressor,
|
||||
HistGradientBoostingRegressor
|
||||
)
|
||||
|
||||
# Clustering
|
||||
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
|
||||
from sklearn.mixture import GaussianMixture
|
||||
|
||||
# Dimensionality Reduction
|
||||
from sklearn.decomposition import PCA, NMF, TruncatedSVD
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Metrics
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, precision_score, recall_score, f1_score,
|
||||
confusion_matrix, classification_report,
|
||||
mean_squared_error, r2_score, mean_absolute_error
|
||||
)
|
||||
```
|
||||
|
||||
## Basic Workflow Template
|
||||
|
||||
### Classification
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
# Split data
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# Scale features
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# Train model
|
||||
model = RandomForestClassifier(n_estimators=100, random_state=42)
|
||||
model.fit(X_train_scaled, y_train)
|
||||
|
||||
# Predict and evaluate
|
||||
y_pred = model.predict(X_test_scaled)
|
||||
print(classification_report(y_test, y_pred))
|
||||
```
|
||||
|
||||
### Regression
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
|
||||
# Split data
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Scale features
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# Train model
|
||||
model = RandomForestRegressor(n_estimators=100, random_state=42)
|
||||
model.fit(X_train_scaled, y_train)
|
||||
|
||||
# Predict and evaluate
|
||||
y_pred = model.predict(X_test_scaled)
|
||||
print(f"RMSE: {mean_squared_error(y_test, y_pred, squared=False):.3f}")
|
||||
print(f"R²: {r2_score(y_test, y_pred):.3f}")
|
||||
```
|
||||
|
||||
### With Pipeline (Recommended)
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.model_selection import train_test_split, cross_val_score
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
|
||||
])
|
||||
|
||||
# Split and train
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42
|
||||
)
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# Evaluate
|
||||
score = pipeline.score(X_test, y_test)
|
||||
cv_scores = cross_val_score(pipeline, X_train, y_train, cv=5)
|
||||
print(f"Test accuracy: {score:.3f}")
|
||||
print(f"CV accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
|
||||
```
|
||||
|
||||
## Common Preprocessing Patterns
|
||||
|
||||
### Numeric Data
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
numeric_transformer = Pipeline([
|
||||
('imputer', SimpleImputer(strategy='median')),
|
||||
('scaler', StandardScaler())
|
||||
])
|
||||
```
|
||||
|
||||
### Categorical Data
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
categorical_transformer = Pipeline([
|
||||
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore'))
|
||||
])
|
||||
```
|
||||
|
||||
### Mixed Data with ColumnTransformer
|
||||
|
||||
```python
|
||||
from sklearn.compose import ColumnTransformer
|
||||
|
||||
numeric_features = ['age', 'income', 'credit_score']
|
||||
categorical_features = ['country', 'occupation']
|
||||
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', numeric_transformer, numeric_features),
|
||||
('cat', categorical_transformer, categorical_features)
|
||||
])
|
||||
|
||||
# Complete pipeline
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
pipeline = Pipeline([
|
||||
('preprocessor', preprocessor),
|
||||
('classifier', RandomForestClassifier())
|
||||
])
|
||||
```
|
||||
|
||||
## Model Selection Cheat Sheet
|
||||
|
||||
### Quick Decision Tree
|
||||
|
||||
```
|
||||
Is it supervised?
|
||||
├─ Yes
|
||||
│ ├─ Predicting categories? → Classification
|
||||
│ │ ├─ Start with: LogisticRegression (baseline)
|
||||
│ │ ├─ Then try: RandomForestClassifier
|
||||
│ │ └─ Best performance: HistGradientBoostingClassifier
|
||||
│ └─ Predicting numbers? → Regression
|
||||
│ ├─ Start with: LinearRegression/Ridge (baseline)
|
||||
│ ├─ Then try: RandomForestRegressor
|
||||
│ └─ Best performance: HistGradientBoostingRegressor
|
||||
└─ No
|
||||
├─ Grouping similar items? → Clustering
|
||||
│ ├─ Know # clusters: KMeans
|
||||
│ └─ Unknown # clusters: DBSCAN or HDBSCAN
|
||||
├─ Reducing dimensions?
|
||||
│ ├─ For preprocessing: PCA
|
||||
│ └─ For visualization: t-SNE or UMAP
|
||||
└─ Finding outliers? → IsolationForest or LocalOutlierFactor
|
||||
```
|
||||
|
||||
### Algorithm Selection by Data Size
|
||||
|
||||
- **Small (<1K samples)**: Any algorithm
|
||||
- **Medium (1K-100K)**: Random Forests, Gradient Boosting, Neural Networks
|
||||
- **Large (>100K)**: SGDClassifier/Regressor, HistGradientBoosting, LinearSVC
|
||||
|
||||
### When to Scale Features
|
||||
|
||||
**Always scale**:
|
||||
- SVM, Neural Networks
|
||||
- K-Nearest Neighbors
|
||||
- Linear/Logistic Regression (with regularization)
|
||||
- PCA, LDA
|
||||
- Any gradient descent algorithm
|
||||
|
||||
**Don't need to scale**:
|
||||
- Tree-based (Decision Trees, Random Forests, Gradient Boosting)
|
||||
- Naive Bayes
|
||||
|
||||
## Hyperparameter Tuning
|
||||
|
||||
### GridSearchCV
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
param_grid = {
|
||||
'n_estimators': [100, 200, 500],
|
||||
'max_depth': [10, 20, None],
|
||||
'min_samples_split': [2, 5, 10]
|
||||
}
|
||||
|
||||
grid_search = GridSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_grid,
|
||||
cv=5,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1
|
||||
)
|
||||
|
||||
grid_search.fit(X_train, y_train)
|
||||
best_model = grid_search.best_estimator_
|
||||
print(f"Best params: {grid_search.best_params_}")
|
||||
```
|
||||
|
||||
### RandomizedSearchCV (Faster)
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from scipy.stats import randint, uniform
|
||||
|
||||
param_distributions = {
|
||||
'n_estimators': randint(100, 1000),
|
||||
'max_depth': randint(5, 50),
|
||||
'min_samples_split': randint(2, 20)
|
||||
}
|
||||
|
||||
random_search = RandomizedSearchCV(
|
||||
RandomForestClassifier(random_state=42),
|
||||
param_distributions,
|
||||
n_iter=50, # Number of combinations to try
|
||||
cv=5,
|
||||
n_jobs=-1,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
random_search.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### Pipeline with GridSearchCV
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
pipeline = Pipeline([
|
||||
('scaler', StandardScaler()),
|
||||
('svm', SVC())
|
||||
])
|
||||
|
||||
param_grid = {
|
||||
'svm__C': [0.1, 1, 10],
|
||||
'svm__kernel': ['rbf', 'linear'],
|
||||
'svm__gamma': ['scale', 'auto']
|
||||
}
|
||||
|
||||
grid = GridSearchCV(pipeline, param_grid, cv=5)
|
||||
grid.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Cross-Validation
|
||||
|
||||
### Basic Cross-Validation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
||||
scores = cross_val_score(model, X, y, cv=5, scoring='accuracy')
|
||||
print(f"Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")
|
||||
```
|
||||
|
||||
### Multiple Metrics
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import cross_validate
|
||||
|
||||
scoring = ['accuracy', 'precision_weighted', 'recall_weighted', 'f1_weighted']
|
||||
results = cross_validate(model, X, y, cv=5, scoring=scoring)
|
||||
|
||||
for metric in scoring:
|
||||
scores = results[f'test_{metric}']
|
||||
print(f"{metric}: {scores.mean():.3f} (+/- {scores.std():.3f})")
|
||||
```
|
||||
|
||||
### Custom CV Strategies
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit
|
||||
|
||||
# For imbalanced classification
|
||||
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
||||
|
||||
# For time series
|
||||
cv = TimeSeriesSplit(n_splits=5)
|
||||
|
||||
scores = cross_val_score(model, X, y, cv=cv)
|
||||
```
|
||||
|
||||
## Common Metrics
|
||||
|
||||
### Classification
|
||||
|
||||
```python
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, balanced_accuracy_score,
|
||||
precision_score, recall_score, f1_score,
|
||||
confusion_matrix, classification_report,
|
||||
roc_auc_score
|
||||
)
|
||||
|
||||
# Basic metrics
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
f1 = f1_score(y_true, y_pred, average='weighted')
|
||||
|
||||
# Comprehensive report
|
||||
print(classification_report(y_true, y_pred))
|
||||
|
||||
# ROC AUC (requires probabilities)
|
||||
y_proba = model.predict_proba(X_test)[:, 1]
|
||||
auc = roc_auc_score(y_true, y_proba)
|
||||
```
|
||||
|
||||
### Regression
|
||||
|
||||
```python
|
||||
from sklearn.metrics import (
|
||||
mean_squared_error,
|
||||
mean_absolute_error,
|
||||
r2_score
|
||||
)
|
||||
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
rmse = mean_squared_error(y_true, y_pred, squared=False)
|
||||
mae = mean_absolute_error(y_true, y_pred)
|
||||
r2 = r2_score(y_true, y_pred)
|
||||
|
||||
print(f"RMSE: {rmse:.3f}")
|
||||
print(f"MAE: {mae:.3f}")
|
||||
print(f"R²: {r2:.3f}")
|
||||
```
|
||||
|
||||
## Feature Engineering
|
||||
|
||||
### Polynomial Features
|
||||
|
||||
```python
|
||||
from sklearn.preprocessing import PolynomialFeatures
|
||||
|
||||
poly = PolynomialFeatures(degree=2, include_bias=False)
|
||||
X_poly = poly.fit_transform(X)
|
||||
# [x1, x2] → [x1, x2, x1², x1·x2, x2²]
|
||||
```
|
||||
|
||||
### Feature Selection
|
||||
|
||||
```python
|
||||
from sklearn.feature_selection import (
|
||||
SelectKBest, f_classif,
|
||||
RFE,
|
||||
SelectFromModel
|
||||
)
|
||||
|
||||
# Univariate selection
|
||||
selector = SelectKBest(f_classif, k=10)
|
||||
X_selected = selector.fit_transform(X, y)
|
||||
|
||||
# Recursive feature elimination
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
rfe = RFE(RandomForestClassifier(), n_features_to_select=10)
|
||||
X_selected = rfe.fit_transform(X, y)
|
||||
|
||||
# Model-based selection
|
||||
selector = SelectFromModel(
|
||||
RandomForestClassifier(n_estimators=100),
|
||||
threshold='median'
|
||||
)
|
||||
X_selected = selector.fit_transform(X, y)
|
||||
```
|
||||
|
||||
### Feature Importance
|
||||
|
||||
```python
|
||||
# Tree-based models
|
||||
model = RandomForestClassifier()
|
||||
model.fit(X_train, y_train)
|
||||
importances = model.feature_importances_
|
||||
|
||||
# Visualize
|
||||
import matplotlib.pyplot as plt
|
||||
indices = np.argsort(importances)[::-1]
|
||||
plt.bar(range(X.shape[1]), importances[indices])
|
||||
plt.xticks(range(X.shape[1]), feature_names[indices], rotation=90)
|
||||
plt.show()
|
||||
|
||||
# Permutation importance (works for any model)
|
||||
from sklearn.inspection import permutation_importance
|
||||
result = permutation_importance(model, X_test, y_test, n_repeats=10)
|
||||
importances = result.importances_mean
|
||||
```
|
||||
|
||||
## Clustering
|
||||
|
||||
### K-Means
|
||||
|
||||
```python
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
# Always scale for k-means
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
|
||||
# Fit k-means
|
||||
kmeans = KMeans(n_clusters=3, random_state=42)
|
||||
labels = kmeans.fit_predict(X_scaled)
|
||||
|
||||
# Evaluate
|
||||
from sklearn.metrics import silhouette_score
|
||||
score = silhouette_score(X_scaled, labels)
|
||||
print(f"Silhouette score: {score:.3f}")
|
||||
```
|
||||
|
||||
### Elbow Method
|
||||
|
||||
```python
|
||||
inertias = []
|
||||
K_range = range(2, 11)
|
||||
|
||||
for k in K_range:
|
||||
kmeans = KMeans(n_clusters=k, random_state=42)
|
||||
kmeans.fit(X_scaled)
|
||||
inertias.append(kmeans.inertia_)
|
||||
|
||||
plt.plot(K_range, inertias, 'bo-')
|
||||
plt.xlabel('k')
|
||||
plt.ylabel('Inertia')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
### DBSCAN
|
||||
|
||||
```python
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
dbscan = DBSCAN(eps=0.5, min_samples=5)
|
||||
labels = dbscan.fit_predict(X_scaled)
|
||||
|
||||
# -1 indicates noise/outliers
|
||||
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
|
||||
n_noise = list(labels).count(-1)
|
||||
print(f"Clusters: {n_clusters}, Noise points: {n_noise}")
|
||||
```
|
||||
|
||||
## Dimensionality Reduction
|
||||
|
||||
### PCA
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
# Always scale before PCA
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
|
||||
# Specify n_components
|
||||
pca = PCA(n_components=2)
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
|
||||
# Or specify variance to retain
|
||||
pca = PCA(n_components=0.95) # Keep 95% variance
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
|
||||
print(f"Explained variance: {pca.explained_variance_ratio_}")
|
||||
print(f"Components needed: {pca.n_components_}")
|
||||
```
|
||||
|
||||
### t-SNE (Visualization Only)
|
||||
|
||||
```python
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Reduce to 50 dimensions with PCA first (recommended)
|
||||
pca = PCA(n_components=50)
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
|
||||
# Apply t-SNE
|
||||
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
|
||||
X_tsne = tsne.fit_transform(X_pca)
|
||||
|
||||
# Visualize
|
||||
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='viridis')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
## Saving and Loading Models
|
||||
|
||||
```python
|
||||
import joblib
|
||||
|
||||
# Save model
|
||||
joblib.dump(model, 'model.pkl')
|
||||
|
||||
# Save pipeline
|
||||
joblib.dump(pipeline, 'pipeline.pkl')
|
||||
|
||||
# Load
|
||||
model = joblib.load('model.pkl')
|
||||
pipeline = joblib.load('pipeline.pkl')
|
||||
|
||||
# Use loaded model
|
||||
y_pred = model.predict(X_new)
|
||||
```
|
||||
|
||||
## Common Pitfalls and Solutions
|
||||
|
||||
### Data Leakage
|
||||
❌ **Wrong**: Fit on all data before split
|
||||
```python
|
||||
scaler = StandardScaler().fit(X)
|
||||
X_train, X_test = train_test_split(scaler.transform(X))
|
||||
```
|
||||
|
||||
✅ **Correct**: Use pipeline or fit only on train
|
||||
```python
|
||||
X_train, X_test = train_test_split(X)
|
||||
pipeline = Pipeline([('scaler', StandardScaler()), ('model', model)])
|
||||
pipeline.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### Not Scaling
|
||||
❌ **Wrong**: Using SVM without scaling
|
||||
```python
|
||||
svm = SVC()
|
||||
svm.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
✅ **Correct**: Scale for SVM
|
||||
```python
|
||||
pipeline = Pipeline([('scaler', StandardScaler()), ('svm', SVC())])
|
||||
pipeline.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### Wrong Metric for Imbalanced Data
|
||||
❌ **Wrong**: Using accuracy for 99:1 imbalance
|
||||
```python
|
||||
accuracy = accuracy_score(y_true, y_pred) # Can be misleading
|
||||
```
|
||||
|
||||
✅ **Correct**: Use appropriate metrics
|
||||
```python
|
||||
f1 = f1_score(y_true, y_pred, average='weighted')
|
||||
balanced_acc = balanced_accuracy_score(y_true, y_pred)
|
||||
```
|
||||
|
||||
### Not Using Stratification
|
||||
❌ **Wrong**: Random split for imbalanced data
|
||||
```python
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
|
||||
```
|
||||
|
||||
✅ **Correct**: Stratify for imbalanced classes
|
||||
```python
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, stratify=y
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Use n_jobs=-1** for parallel processing (RandomForest, GridSearchCV)
|
||||
2. **Use HistGradientBoosting** for large datasets (>10K samples)
|
||||
3. **Use MiniBatchKMeans** for large clustering tasks
|
||||
4. **Use IncrementalPCA** for data that doesn't fit in memory
|
||||
5. **Use sparse matrices** for high-dimensional sparse data (text)
|
||||
6. **Cache transformers** in pipelines during grid search
|
||||
7. **Use RandomizedSearchCV** instead of GridSearchCV for large parameter spaces
|
||||
8. **Reduce dimensionality** with PCA before applying expensive algorithms
|
||||
@@ -1,261 +0,0 @@
|
||||
# Supervised Learning in scikit-learn
|
||||
|
||||
## Overview
|
||||
Supervised learning algorithms learn patterns from labeled training data to make predictions on new data. Scikit-learn organizes supervised learning into 17 major categories.
|
||||
|
||||
## Linear Models
|
||||
|
||||
### Regression
|
||||
- **LinearRegression**: Ordinary least squares regression
|
||||
- **Ridge**: L2-regularized regression, good for multicollinearity
|
||||
- **Lasso**: L1-regularized regression, performs feature selection
|
||||
- **ElasticNet**: Combined L1/L2 regularization
|
||||
- **LassoLars**: Lasso using Least Angle Regression algorithm
|
||||
- **BayesianRidge**: Bayesian approach with automatic relevance determination
|
||||
|
||||
### Classification
|
||||
- **LogisticRegression**: Binary and multiclass classification
|
||||
- **RidgeClassifier**: Ridge regression for classification
|
||||
- **SGDClassifier**: Linear classifiers with SGD training
|
||||
|
||||
**Use cases**: Baseline models, interpretable predictions, high-dimensional data, when linear relationships are expected
|
||||
|
||||
**Key parameters**:
|
||||
- `alpha`: Regularization strength (higher = more regularization)
|
||||
- `fit_intercept`: Whether to calculate intercept
|
||||
- `solver`: Optimization algorithm ('lbfgs', 'saga', 'liblinear')
|
||||
|
||||
## Support Vector Machines (SVM)
|
||||
|
||||
- **SVC**: Support Vector Classification
|
||||
- **SVR**: Support Vector Regression
|
||||
- **LinearSVC**: Linear SVM using liblinear (faster for large datasets)
|
||||
- **OneClassSVM**: Unsupervised outlier detection
|
||||
|
||||
**Use cases**: Complex non-linear decision boundaries, high-dimensional spaces, when clear margin of separation exists
|
||||
|
||||
**Key parameters**:
|
||||
- `kernel`: 'linear', 'poly', 'rbf', 'sigmoid'
|
||||
- `C`: Regularization parameter (lower = more regularization)
|
||||
- `gamma`: Kernel coefficient ('scale', 'auto', or float)
|
||||
- `degree`: Polynomial degree (for poly kernel)
|
||||
|
||||
**Performance tip**: SVMs don't scale well beyond tens of thousands of samples. Use LinearSVC for large datasets with linear kernel.
|
||||
|
||||
## Decision Trees
|
||||
|
||||
- **DecisionTreeClassifier**: Classification tree
|
||||
- **DecisionTreeRegressor**: Regression tree
|
||||
- **ExtraTreeClassifier/Regressor**: Extremely randomized tree
|
||||
|
||||
**Use cases**: Non-linear relationships, feature importance analysis, interpretable rules, handling mixed data types
|
||||
|
||||
**Key parameters**:
|
||||
- `max_depth`: Maximum tree depth (controls overfitting)
|
||||
- `min_samples_split`: Minimum samples to split a node
|
||||
- `min_samples_leaf`: Minimum samples in leaf node
|
||||
- `max_features`: Number of features to consider for splits
|
||||
- `criterion`: 'gini', 'entropy' (classification); 'squared_error', 'absolute_error' (regression)
|
||||
|
||||
**Overfitting prevention**: Limit `max_depth`, increase `min_samples_split/leaf`, use pruning with `ccp_alpha`
|
||||
|
||||
## Ensemble Methods
|
||||
|
||||
### Random Forests
|
||||
- **RandomForestClassifier**: Ensemble of decision trees
|
||||
- **RandomForestRegressor**: Regression variant
|
||||
|
||||
**Use cases**: Robust general-purpose algorithm, reduces overfitting vs single trees, handles non-linear relationships
|
||||
|
||||
**Key parameters**:
|
||||
- `n_estimators`: Number of trees (higher = better but slower)
|
||||
- `max_depth`: Maximum tree depth
|
||||
- `max_features`: Features per split ('sqrt', 'log2', int, float)
|
||||
- `bootstrap`: Whether to use bootstrap samples
|
||||
- `n_jobs`: Parallel processing (-1 uses all cores)
|
||||
|
||||
### Gradient Boosting
|
||||
- **HistGradientBoostingClassifier/Regressor**: Histogram-based, fast for large datasets (>10k samples)
|
||||
- **GradientBoostingClassifier/Regressor**: Traditional implementation, better for small datasets
|
||||
|
||||
**Use cases**: High-performance predictions, winning Kaggle competitions, structured/tabular data
|
||||
|
||||
**Key parameters**:
|
||||
- `n_estimators`: Number of boosting stages
|
||||
- `learning_rate`: Shrinks contribution of each tree
|
||||
- `max_depth`: Maximum tree depth (typically 3-8)
|
||||
- `subsample`: Fraction of samples per tree (enables stochastic gradient boosting)
|
||||
- `early_stopping`: Stop when validation score stops improving
|
||||
|
||||
**Performance tip**: HistGradientBoosting is orders of magnitude faster for large datasets
|
||||
|
||||
### AdaBoost
|
||||
- **AdaBoostClassifier/Regressor**: Adaptive boosting
|
||||
|
||||
**Use cases**: Boosting weak learners, less prone to overfitting than other methods
|
||||
|
||||
**Key parameters**:
|
||||
- `estimator`: Base estimator (default: DecisionTreeClassifier with max_depth=1)
|
||||
- `n_estimators`: Number of boosting iterations
|
||||
- `learning_rate`: Weight applied to each classifier
|
||||
|
||||
### Bagging
|
||||
- **BaggingClassifier/Regressor**: Bootstrap aggregating with any base estimator
|
||||
|
||||
**Use cases**: Reducing variance of unstable models, parallel ensemble creation
|
||||
|
||||
**Key parameters**:
|
||||
- `estimator`: Base estimator to fit
|
||||
- `n_estimators`: Number of estimators
|
||||
- `max_samples`: Samples to draw per estimator
|
||||
- `bootstrap`: Whether to use replacement
|
||||
|
||||
### Voting & Stacking
|
||||
- **VotingClassifier/Regressor**: Combines different model types
|
||||
- **StackingClassifier/Regressor**: Meta-learner trained on base predictions
|
||||
|
||||
**Use cases**: Combining diverse models, leveraging different model strengths
|
||||
|
||||
## Neural Networks
|
||||
|
||||
- **MLPClassifier**: Multi-layer perceptron classifier
|
||||
- **MLPRegressor**: Multi-layer perceptron regressor
|
||||
|
||||
**Use cases**: Complex non-linear patterns, when gradient boosting is too slow, deep feature learning
|
||||
|
||||
**Key parameters**:
|
||||
- `hidden_layer_sizes`: Tuple of hidden layer sizes (e.g., (100, 50))
|
||||
- `activation`: 'relu', 'tanh', 'logistic'
|
||||
- `solver`: 'adam', 'lbfgs', 'sgd'
|
||||
- `alpha`: L2 regularization term
|
||||
- `learning_rate`: Learning rate schedule
|
||||
- `early_stopping`: Stop when validation score stops improving
|
||||
|
||||
**Important**: Feature scaling is critical for neural networks. Always use StandardScaler or similar.
|
||||
|
||||
## Nearest Neighbors
|
||||
|
||||
- **KNeighborsClassifier/Regressor**: K-nearest neighbors
|
||||
- **RadiusNeighborsClassifier/Regressor**: Radius-based neighbors
|
||||
- **NearestCentroid**: Classification using class centroids
|
||||
|
||||
**Use cases**: Simple baseline, irregular decision boundaries, when interpretability isn't critical
|
||||
|
||||
**Key parameters**:
|
||||
- `n_neighbors`: Number of neighbors (typically 3-11)
|
||||
- `weights`: 'uniform' or 'distance' (distance-weighted voting)
|
||||
- `metric`: Distance metric ('euclidean', 'manhattan', 'minkowski')
|
||||
- `algorithm`: 'auto', 'ball_tree', 'kd_tree', 'brute'
|
||||
|
||||
## Naive Bayes
|
||||
|
||||
- **GaussianNB**: Assumes Gaussian distribution of features
|
||||
- **MultinomialNB**: For discrete counts (text classification)
|
||||
- **BernoulliNB**: For binary/boolean features
|
||||
- **CategoricalNB**: For categorical features
|
||||
- **ComplementNB**: Adapted for imbalanced datasets
|
||||
|
||||
**Use cases**: Text classification, fast baseline, when features are independent, small training sets
|
||||
|
||||
**Key parameters**:
|
||||
- `alpha`: Smoothing parameter (Laplace/Lidstone smoothing)
|
||||
- `fit_prior`: Whether to learn class prior probabilities
|
||||
|
||||
## Linear/Quadratic Discriminant Analysis
|
||||
|
||||
- **LinearDiscriminantAnalysis**: Linear decision boundary with dimensionality reduction
|
||||
- **QuadraticDiscriminantAnalysis**: Quadratic decision boundary
|
||||
|
||||
**Use cases**: When classes have Gaussian distributions, dimensionality reduction, when covariance assumptions hold
|
||||
|
||||
## Gaussian Processes
|
||||
|
||||
- **GaussianProcessClassifier**: Probabilistic classification
|
||||
- **GaussianProcessRegressor**: Probabilistic regression with uncertainty estimates
|
||||
|
||||
**Use cases**: When uncertainty quantification is important, small datasets, smooth function approximation
|
||||
|
||||
**Key parameters**:
|
||||
- `kernel`: Covariance function (RBF, Matern, RationalQuadratic, etc.)
|
||||
- `alpha`: Noise level
|
||||
|
||||
**Limitation**: Doesn't scale well to large datasets (O(n³) complexity)
|
||||
|
||||
## Stochastic Gradient Descent
|
||||
|
||||
- **SGDClassifier**: Linear classifiers with SGD
|
||||
- **SGDRegressor**: Linear regressors with SGD
|
||||
|
||||
**Use cases**: Very large datasets (>100k samples), online learning, when data doesn't fit in memory
|
||||
|
||||
**Key parameters**:
|
||||
- `loss`: Loss function ('hinge', 'log_loss', 'squared_error', etc.)
|
||||
- `penalty`: Regularization ('l2', 'l1', 'elasticnet')
|
||||
- `alpha`: Regularization strength
|
||||
- `learning_rate`: Learning rate schedule
|
||||
|
||||
## Semi-Supervised Learning
|
||||
|
||||
- **SelfTrainingClassifier**: Self-training with any base classifier
|
||||
- **LabelPropagation**: Label propagation through graph
|
||||
- **LabelSpreading**: Label spreading (modified label propagation)
|
||||
|
||||
**Use cases**: When labeled data is scarce but unlabeled data is abundant
|
||||
|
||||
## Feature Selection
|
||||
|
||||
- **VarianceThreshold**: Remove low-variance features
|
||||
- **SelectKBest**: Select K highest scoring features
|
||||
- **SelectPercentile**: Select top percentile of features
|
||||
- **RFE**: Recursive feature elimination
|
||||
- **RFECV**: RFE with cross-validation
|
||||
- **SelectFromModel**: Select features based on importance
|
||||
- **SequentialFeatureSelector**: Forward/backward feature selection
|
||||
|
||||
**Use cases**: Reducing dimensionality, removing irrelevant features, improving interpretability, reducing overfitting
|
||||
|
||||
## Probability Calibration
|
||||
|
||||
- **CalibratedClassifierCV**: Calibrate classifier probabilities
|
||||
|
||||
**Use cases**: When probability estimates are important (not just class predictions), especially with SVM and Naive Bayes
|
||||
|
||||
**Methods**:
|
||||
- `sigmoid`: Platt scaling
|
||||
- `isotonic`: Isotonic regression (more flexible, needs more data)
|
||||
|
||||
## Multi-Output Methods
|
||||
|
||||
- **MultiOutputClassifier**: Fit one classifier per target
|
||||
- **MultiOutputRegressor**: Fit one regressor per target
|
||||
- **ClassifierChain**: Models dependencies between targets
|
||||
- **RegressorChain**: Regression variant
|
||||
|
||||
**Use cases**: Predicting multiple related targets simultaneously
|
||||
|
||||
## Specialized Regression
|
||||
|
||||
- **IsotonicRegression**: Monotonic regression
|
||||
- **QuantileRegressor**: Quantile regression for prediction intervals
|
||||
|
||||
## Algorithm Selection Guidelines
|
||||
|
||||
**Start with**:
|
||||
1. **Logistic Regression** (classification) or **LinearRegression/Ridge** (regression) as baseline
|
||||
2. **RandomForestClassifier/Regressor** for general non-linear problems
|
||||
3. **HistGradientBoostingClassifier/Regressor** when best performance is needed
|
||||
|
||||
**Consider dataset size**:
|
||||
- Small (<1k samples): SVM, Gaussian Processes, any algorithm
|
||||
- Medium (1k-100k): Random Forests, Gradient Boosting, Neural Networks
|
||||
- Large (>100k): SGD, HistGradientBoosting, LinearSVC
|
||||
|
||||
**Consider interpretability needs**:
|
||||
- High interpretability: Linear models, Decision Trees, Naive Bayes
|
||||
- Medium: Random Forests (feature importance), Rule extraction
|
||||
- Low (black box acceptable): Gradient Boosting, Neural Networks, SVM with RBF kernel
|
||||
|
||||
**Consider training time**:
|
||||
- Fast: Linear models, Naive Bayes, Decision Trees
|
||||
- Medium: Random Forests (parallelizable), SVM (small data)
|
||||
- Slow: Gradient Boosting, Neural Networks, SVM (large data), Gaussian Processes
|
||||
@@ -1,728 +0,0 @@
|
||||
# Unsupervised Learning in scikit-learn
|
||||
|
||||
## Overview
|
||||
Unsupervised learning discovers patterns in data without labeled targets. Main tasks include clustering (grouping similar samples), dimensionality reduction (reducing feature count), and anomaly detection (finding outliers).
|
||||
|
||||
## Clustering Algorithms
|
||||
|
||||
### K-Means
|
||||
|
||||
Groups data into k clusters by minimizing within-cluster variance.
|
||||
|
||||
**Algorithm**:
|
||||
1. Initialize k centroids (k-means++ initialization recommended)
|
||||
2. Assign each point to nearest centroid
|
||||
3. Update centroids to mean of assigned points
|
||||
4. Repeat until convergence
|
||||
|
||||
```python
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
kmeans = KMeans(
|
||||
n_clusters=3,
|
||||
init='k-means++', # Smart initialization
|
||||
n_init=10, # Number of times to run with different seeds
|
||||
max_iter=300,
|
||||
random_state=42
|
||||
)
|
||||
labels = kmeans.fit_predict(X)
|
||||
centroids = kmeans.cluster_centers_
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Customer segmentation
|
||||
- Image compression
|
||||
- Data preprocessing (clustering as features)
|
||||
|
||||
**Strengths**:
|
||||
- Fast and scalable
|
||||
- Simple to understand
|
||||
- Works well with spherical clusters
|
||||
|
||||
**Limitations**:
|
||||
- Assumes spherical clusters of similar size
|
||||
- Sensitive to initialization (mitigated by k-means++)
|
||||
- Must specify k beforehand
|
||||
- Sensitive to outliers
|
||||
|
||||
**Choosing k**: Use elbow method, silhouette score, or domain knowledge
|
||||
|
||||
**Variants**:
|
||||
- **MiniBatchKMeans**: Faster for large datasets, uses mini-batches
|
||||
- **KMeans with n_init='auto'**: Adaptive number of initializations
|
||||
|
||||
### DBSCAN
|
||||
|
||||
Density-Based Spatial Clustering of Applications with Noise. Identifies clusters as dense regions separated by sparse areas.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
dbscan = DBSCAN(
|
||||
eps=0.5, # Maximum distance between neighbors
|
||||
min_samples=5, # Minimum points to form dense region
|
||||
metric='euclidean'
|
||||
)
|
||||
labels = dbscan.fit_predict(X)
|
||||
# -1 indicates noise/outliers
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Arbitrary cluster shapes
|
||||
- Outlier detection
|
||||
- When cluster count is unknown
|
||||
- Geographic/spatial data
|
||||
|
||||
**Strengths**:
|
||||
- Discovers arbitrary-shaped clusters
|
||||
- Automatically detects outliers
|
||||
- Doesn't require specifying number of clusters
|
||||
- Robust to outliers
|
||||
|
||||
**Limitations**:
|
||||
- Struggles with varying densities
|
||||
- Sensitive to eps and min_samples parameters
|
||||
- Not deterministic (border points may vary)
|
||||
|
||||
**Parameter tuning**:
|
||||
- `eps`: Plot k-distance graph, look for elbow
|
||||
- `min_samples`: Rule of thumb: 2 * dimensions
|
||||
|
||||
### HDBSCAN
|
||||
|
||||
Hierarchical DBSCAN that handles variable cluster densities.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import HDBSCAN
|
||||
|
||||
hdbscan = HDBSCAN(
|
||||
min_cluster_size=5,
|
||||
min_samples=None, # Defaults to min_cluster_size
|
||||
metric='euclidean'
|
||||
)
|
||||
labels = hdbscan.fit_predict(X)
|
||||
```
|
||||
|
||||
**Advantages over DBSCAN**:
|
||||
- Handles variable density clusters
|
||||
- More robust parameter selection
|
||||
- Provides cluster membership probabilities
|
||||
- Hierarchical structure
|
||||
|
||||
**Use cases**: When DBSCAN struggles with varying densities
|
||||
|
||||
### Hierarchical Clustering
|
||||
|
||||
Builds nested cluster hierarchies using agglomerative (bottom-up) approach.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
|
||||
agg_clust = AgglomerativeClustering(
|
||||
n_clusters=3,
|
||||
linkage='ward', # 'ward', 'complete', 'average', 'single'
|
||||
metric='euclidean'
|
||||
)
|
||||
labels = agg_clust.fit_predict(X)
|
||||
|
||||
# Visualize with dendrogram
|
||||
from scipy.cluster.hierarchy import dendrogram, linkage as scipy_linkage
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
linkage_matrix = scipy_linkage(X, method='ward')
|
||||
dendrogram(linkage_matrix)
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Linkage methods**:
|
||||
- `ward`: Minimizes variance (only with Euclidean) - **most common**
|
||||
- `complete`: Maximum distance between clusters
|
||||
- `average`: Average distance between clusters
|
||||
- `single`: Minimum distance between clusters
|
||||
|
||||
**Use cases**:
|
||||
- When hierarchical structure is meaningful
|
||||
- Taxonomy/phylogenetic trees
|
||||
- When visualization is important (dendrograms)
|
||||
|
||||
**Strengths**:
|
||||
- No need to specify k initially (cut dendrogram at desired level)
|
||||
- Produces hierarchy of clusters
|
||||
- Deterministic
|
||||
|
||||
**Limitations**:
|
||||
- Computationally expensive (O(n²) to O(n³))
|
||||
- Not suitable for large datasets
|
||||
- Cannot undo previous merges
|
||||
|
||||
### Spectral Clustering
|
||||
|
||||
Performs dimensionality reduction using affinity matrix before clustering.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import SpectralClustering
|
||||
|
||||
spectral = SpectralClustering(
|
||||
n_clusters=3,
|
||||
affinity='rbf', # 'rbf', 'nearest_neighbors', 'precomputed'
|
||||
gamma=1.0,
|
||||
n_neighbors=10,
|
||||
random_state=42
|
||||
)
|
||||
labels = spectral.fit_predict(X)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Non-convex clusters
|
||||
- Image segmentation
|
||||
- Graph clustering
|
||||
- When similarity matrix is available
|
||||
|
||||
**Strengths**:
|
||||
- Handles non-convex clusters
|
||||
- Works with similarity matrices
|
||||
- Often better than k-means for complex shapes
|
||||
|
||||
**Limitations**:
|
||||
- Computationally expensive
|
||||
- Requires specifying number of clusters
|
||||
- Memory intensive
|
||||
|
||||
### Mean Shift
|
||||
|
||||
Discovers clusters through iterative centroid updates based on density.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import MeanShift, estimate_bandwidth
|
||||
|
||||
# Estimate bandwidth
|
||||
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
|
||||
|
||||
mean_shift = MeanShift(bandwidth=bandwidth)
|
||||
labels = mean_shift.fit_predict(X)
|
||||
cluster_centers = mean_shift.cluster_centers_
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- When cluster count is unknown
|
||||
- Computer vision applications
|
||||
- Object tracking
|
||||
|
||||
**Strengths**:
|
||||
- Automatically determines number of clusters
|
||||
- Handles arbitrary shapes
|
||||
- No assumptions about cluster shape
|
||||
|
||||
**Limitations**:
|
||||
- Computationally expensive
|
||||
- Very sensitive to bandwidth parameter
|
||||
- Doesn't scale well
|
||||
|
||||
### Affinity Propagation
|
||||
|
||||
Uses message-passing between samples to identify exemplars.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import AffinityPropagation
|
||||
|
||||
affinity_prop = AffinityPropagation(
|
||||
damping=0.5, # Damping factor (0.5-1.0)
|
||||
preference=None, # Self-preference (controls number of clusters)
|
||||
random_state=42
|
||||
)
|
||||
labels = affinity_prop.fit_predict(X)
|
||||
exemplars = affinity_prop.cluster_centers_indices_
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- When number of clusters is unknown
|
||||
- When exemplars (representative samples) are needed
|
||||
|
||||
**Strengths**:
|
||||
- Automatically determines number of clusters
|
||||
- Identifies exemplar samples
|
||||
- No initialization required
|
||||
|
||||
**Limitations**:
|
||||
- Very slow: O(n²t) where t is iterations
|
||||
- Not suitable for large datasets
|
||||
- Memory intensive
|
||||
|
||||
### Gaussian Mixture Models (GMM)
|
||||
|
||||
Probabilistic model assuming data comes from mixture of Gaussian distributions.
|
||||
|
||||
```python
|
||||
from sklearn.mixture import GaussianMixture
|
||||
|
||||
gmm = GaussianMixture(
|
||||
n_components=3,
|
||||
covariance_type='full', # 'full', 'tied', 'diag', 'spherical'
|
||||
random_state=42
|
||||
)
|
||||
labels = gmm.fit_predict(X)
|
||||
probabilities = gmm.predict_proba(X) # Soft clustering
|
||||
```
|
||||
|
||||
**Covariance types**:
|
||||
- `full`: Each component has its own covariance matrix
|
||||
- `tied`: All components share same covariance
|
||||
- `diag`: Diagonal covariance (independent features)
|
||||
- `spherical`: Spherical covariance (isotropic)
|
||||
|
||||
**Use cases**:
|
||||
- When soft clustering is needed (probabilities)
|
||||
- When clusters have different shapes/sizes
|
||||
- Generative modeling
|
||||
- Density estimation
|
||||
|
||||
**Strengths**:
|
||||
- Provides probabilities (soft clustering)
|
||||
- Can handle elliptical clusters
|
||||
- Generative model (can sample new data)
|
||||
- Model selection with BIC/AIC
|
||||
|
||||
**Limitations**:
|
||||
- Assumes Gaussian distributions
|
||||
- Sensitive to initialization
|
||||
- Can converge to local optima
|
||||
|
||||
**Model selection**:
|
||||
```python
|
||||
from sklearn.mixture import GaussianMixture
|
||||
import numpy as np
|
||||
|
||||
n_components_range = range(2, 10)
|
||||
bic_scores = []
|
||||
|
||||
for n in n_components_range:
|
||||
gmm = GaussianMixture(n_components=n, random_state=42)
|
||||
gmm.fit(X)
|
||||
bic_scores.append(gmm.bic(X))
|
||||
|
||||
optimal_n = n_components_range[np.argmin(bic_scores)]
|
||||
```
|
||||
|
||||
### BIRCH
|
||||
|
||||
Builds Clustering Feature Tree for memory-efficient processing of large datasets.
|
||||
|
||||
```python
|
||||
from sklearn.cluster import Birch
|
||||
|
||||
birch = Birch(
|
||||
n_clusters=3,
|
||||
threshold=0.5,
|
||||
branching_factor=50
|
||||
)
|
||||
labels = birch.fit_predict(X)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Very large datasets
|
||||
- Streaming data
|
||||
- Memory constraints
|
||||
|
||||
**Strengths**:
|
||||
- Memory efficient
|
||||
- Single pass over data
|
||||
- Incremental learning
|
||||
|
||||
## Dimensionality Reduction
|
||||
|
||||
### Principal Component Analysis (PCA)
|
||||
|
||||
Finds orthogonal components that explain maximum variance.
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import PCA
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Specify number of components
|
||||
pca = PCA(n_components=2, random_state=42)
|
||||
X_transformed = pca.fit_transform(X)
|
||||
|
||||
print("Explained variance ratio:", pca.explained_variance_ratio_)
|
||||
print("Total variance explained:", pca.explained_variance_ratio_.sum())
|
||||
|
||||
# Or specify variance to retain
|
||||
pca = PCA(n_components=0.95) # Keep 95% of variance
|
||||
X_transformed = pca.fit_transform(X)
|
||||
print(f"Components needed: {pca.n_components_}")
|
||||
|
||||
# Visualize explained variance
|
||||
plt.plot(np.cumsum(pca.explained_variance_ratio_))
|
||||
plt.xlabel('Number of components')
|
||||
plt.ylabel('Cumulative explained variance')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Visualization (reduce to 2-3 dimensions)
|
||||
- Remove multicollinearity
|
||||
- Noise reduction
|
||||
- Speed up training
|
||||
- Feature extraction
|
||||
|
||||
**Strengths**:
|
||||
- Fast and efficient
|
||||
- Reduces multicollinearity
|
||||
- Works well for linear relationships
|
||||
- Interpretable components
|
||||
|
||||
**Limitations**:
|
||||
- Only linear transformations
|
||||
- Sensitive to scaling (always standardize first!)
|
||||
- Components may be hard to interpret
|
||||
|
||||
**Variants**:
|
||||
- **IncrementalPCA**: For datasets that don't fit in memory
|
||||
- **KernelPCA**: Non-linear dimensionality reduction
|
||||
- **SparsePCA**: Sparse loadings for interpretability
|
||||
|
||||
### t-SNE
|
||||
|
||||
t-Distributed Stochastic Neighbor Embedding for visualization.
|
||||
|
||||
```python
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
tsne = TSNE(
|
||||
n_components=2,
|
||||
perplexity=30, # Balance local vs global structure (5-50)
|
||||
learning_rate='auto',
|
||||
n_iter=1000,
|
||||
random_state=42
|
||||
)
|
||||
X_embedded = tsne.fit_transform(X)
|
||||
|
||||
# Visualize
|
||||
plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y)
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Visualization only (do not use for preprocessing!)
|
||||
- Exploring high-dimensional data
|
||||
- Finding clusters visually
|
||||
|
||||
**Important notes**:
|
||||
- **Only for visualization**, not for preprocessing
|
||||
- Each run produces different results (use random_state for reproducibility)
|
||||
- Slow for large datasets
|
||||
- Cannot transform new data (no transform() method)
|
||||
|
||||
**Parameter tuning**:
|
||||
- `perplexity`: 5-50, larger for larger datasets
|
||||
- Lower perplexity = focus on local structure
|
||||
- Higher perplexity = focus on global structure
|
||||
|
||||
### UMAP
|
||||
|
||||
Uniform Manifold Approximation and Projection (requires umap-learn package).
|
||||
|
||||
**Advantages over t-SNE**:
|
||||
- Preserves global structure better
|
||||
- Faster
|
||||
- Can transform new data
|
||||
- Can be used for preprocessing (not just visualization)
|
||||
|
||||
### Truncated SVD (LSA)
|
||||
|
||||
Similar to PCA but works with sparse matrices (e.g., TF-IDF).
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
|
||||
svd = TruncatedSVD(n_components=100, random_state=42)
|
||||
X_reduced = svd.fit_transform(X_sparse)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Text data (after TF-IDF)
|
||||
- Sparse matrices
|
||||
- Latent Semantic Analysis (LSA)
|
||||
|
||||
### Non-negative Matrix Factorization (NMF)
|
||||
|
||||
Factorizes data into non-negative components.
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import NMF
|
||||
|
||||
nmf = NMF(n_components=10, init='nndsvd', random_state=42)
|
||||
W = nmf.fit_transform(X) # Document-topic matrix
|
||||
H = nmf.components_ # Topic-word matrix
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Topic modeling
|
||||
- Audio source separation
|
||||
- Image processing
|
||||
- When non-negativity is important (e.g., counts)
|
||||
|
||||
**Strengths**:
|
||||
- Interpretable components (additive, non-negative)
|
||||
- Sparse representations
|
||||
|
||||
### Independent Component Analysis (ICA)
|
||||
|
||||
Separates multivariate signal into independent components.
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import FastICA
|
||||
|
||||
ica = FastICA(n_components=10, random_state=42)
|
||||
X_independent = ica.fit_transform(X)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Blind source separation
|
||||
- Signal processing
|
||||
- Feature extraction when independence is expected
|
||||
|
||||
### Factor Analysis
|
||||
|
||||
Models observed variables as linear combinations of latent factors plus noise.
|
||||
|
||||
```python
|
||||
from sklearn.decomposition import FactorAnalysis
|
||||
|
||||
fa = FactorAnalysis(n_components=5, random_state=42)
|
||||
X_factors = fa.fit_transform(X)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- When noise is heteroscedastic
|
||||
- Latent variable modeling
|
||||
- Psychology/social science research
|
||||
|
||||
**Difference from PCA**: Models noise explicitly, assumes features have independent noise
|
||||
|
||||
## Anomaly Detection
|
||||
|
||||
### One-Class SVM
|
||||
|
||||
Learns boundary around normal data.
|
||||
|
||||
```python
|
||||
from sklearn.svm import OneClassSVM
|
||||
|
||||
oc_svm = OneClassSVM(
|
||||
nu=0.1, # Proportion of outliers expected
|
||||
kernel='rbf',
|
||||
gamma='auto'
|
||||
)
|
||||
oc_svm.fit(X_train)
|
||||
predictions = oc_svm.predict(X_test) # 1 for inliers, -1 for outliers
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Novelty detection
|
||||
- When only normal data is available for training
|
||||
|
||||
### Isolation Forest
|
||||
|
||||
Isolates outliers using random forests.
|
||||
|
||||
```python
|
||||
from sklearn.ensemble import IsolationForest
|
||||
|
||||
iso_forest = IsolationForest(
|
||||
contamination=0.1, # Expected proportion of outliers
|
||||
random_state=42
|
||||
)
|
||||
predictions = iso_forest.fit_predict(X) # 1 for inliers, -1 for outliers
|
||||
scores = iso_forest.score_samples(X) # Anomaly scores
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- General anomaly detection
|
||||
- Works well with high-dimensional data
|
||||
- Fast and scalable
|
||||
|
||||
**Strengths**:
|
||||
- Fast
|
||||
- Effective in high dimensions
|
||||
- Low memory requirements
|
||||
|
||||
### Local Outlier Factor (LOF)
|
||||
|
||||
Detects outliers based on local density deviation.
|
||||
|
||||
```python
|
||||
from sklearn.neighbors import LocalOutlierFactor
|
||||
|
||||
lof = LocalOutlierFactor(
|
||||
n_neighbors=20,
|
||||
contamination=0.1
|
||||
)
|
||||
predictions = lof.fit_predict(X) # 1 for inliers, -1 for outliers
|
||||
scores = lof.negative_outlier_factor_ # Anomaly scores (negative)
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Finding local outliers
|
||||
- When global methods fail
|
||||
|
||||
## Clustering Evaluation
|
||||
|
||||
### With Ground Truth Labels
|
||||
|
||||
When true labels are available (for validation):
|
||||
|
||||
**Adjusted Rand Index (ARI)**:
|
||||
```python
|
||||
from sklearn.metrics import adjusted_rand_score
|
||||
ari = adjusted_rand_score(y_true, y_pred)
|
||||
# Range: [-1, 1], 1 = perfect, 0 = random
|
||||
```
|
||||
|
||||
**Normalized Mutual Information (NMI)**:
|
||||
```python
|
||||
from sklearn.metrics import normalized_mutual_info_score
|
||||
nmi = normalized_mutual_info_score(y_true, y_pred)
|
||||
# Range: [0, 1], 1 = perfect
|
||||
```
|
||||
|
||||
**V-Measure**:
|
||||
```python
|
||||
from sklearn.metrics import v_measure_score
|
||||
v = v_measure_score(y_true, y_pred)
|
||||
# Range: [0, 1], harmonic mean of homogeneity and completeness
|
||||
```
|
||||
|
||||
### Without Ground Truth Labels
|
||||
|
||||
When true labels are unavailable (unsupervised evaluation):
|
||||
|
||||
**Silhouette Score**:
|
||||
Measures how similar objects are to their own cluster vs other clusters.
|
||||
|
||||
```python
|
||||
from sklearn.metrics import silhouette_score, silhouette_samples
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
score = silhouette_score(X, labels)
|
||||
# Range: [-1, 1], higher is better
|
||||
# >0.7: Strong structure
|
||||
# 0.5-0.7: Reasonable structure
|
||||
# 0.25-0.5: Weak structure
|
||||
# <0.25: No substantial structure
|
||||
|
||||
# Per-sample scores for detailed analysis
|
||||
sample_scores = silhouette_samples(X, labels)
|
||||
|
||||
# Visualize silhouette plot
|
||||
for i in range(n_clusters):
|
||||
cluster_scores = sample_scores[labels == i]
|
||||
cluster_scores.sort()
|
||||
plt.barh(range(len(cluster_scores)), cluster_scores)
|
||||
plt.axvline(x=score, color='red', linestyle='--')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
**Davies-Bouldin Index**:
|
||||
```python
|
||||
from sklearn.metrics import davies_bouldin_score
|
||||
db = davies_bouldin_score(X, labels)
|
||||
# Lower is better, 0 = perfect
|
||||
```
|
||||
|
||||
**Calinski-Harabasz Index** (Variance Ratio Criterion):
|
||||
```python
|
||||
from sklearn.metrics import calinski_harabasz_score
|
||||
ch = calinski_harabasz_score(X, labels)
|
||||
# Higher is better
|
||||
```
|
||||
|
||||
**Inertia** (K-Means specific):
|
||||
```python
|
||||
inertia = kmeans.inertia_
|
||||
# Sum of squared distances to nearest cluster center
|
||||
# Use for elbow method
|
||||
```
|
||||
|
||||
### Elbow Method (K-Means)
|
||||
|
||||
```python
|
||||
from sklearn.cluster import KMeans
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
inertias = []
|
||||
K_range = range(2, 11)
|
||||
|
||||
for k in K_range:
|
||||
kmeans = KMeans(n_clusters=k, random_state=42)
|
||||
kmeans.fit(X)
|
||||
inertias.append(kmeans.inertia_)
|
||||
|
||||
plt.plot(K_range, inertias, 'bo-')
|
||||
plt.xlabel('Number of clusters (k)')
|
||||
plt.ylabel('Inertia')
|
||||
plt.title('Elbow Method')
|
||||
plt.show()
|
||||
# Look for "elbow" where inertia starts decreasing more slowly
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Clustering Algorithm Selection
|
||||
|
||||
**Use K-Means when**:
|
||||
- Clusters are spherical and similar size
|
||||
- Speed is important
|
||||
- Data is not too high-dimensional
|
||||
|
||||
**Use DBSCAN when**:
|
||||
- Arbitrary cluster shapes
|
||||
- Number of clusters unknown
|
||||
- Outlier detection needed
|
||||
|
||||
**Use Hierarchical when**:
|
||||
- Hierarchy is meaningful
|
||||
- Small to medium datasets
|
||||
- Visualization is important
|
||||
|
||||
**Use GMM when**:
|
||||
- Soft clustering needed
|
||||
- Clusters have different shapes/sizes
|
||||
- Probabilistic interpretation needed
|
||||
|
||||
**Use Spectral Clustering when**:
|
||||
- Non-convex clusters
|
||||
- Have similarity matrix
|
||||
- Moderate dataset size
|
||||
|
||||
### Preprocessing for Clustering
|
||||
|
||||
1. **Always scale features**: Use StandardScaler or MinMaxScaler
|
||||
2. **Handle outliers**: Remove or use robust algorithms (DBSCAN, HDBSCAN)
|
||||
3. **Reduce dimensionality if needed**: PCA for speed, careful with interpretation
|
||||
4. **Check for categorical variables**: Encode appropriately or use specialized algorithms
|
||||
|
||||
### Dimensionality Reduction Guidelines
|
||||
|
||||
**For preprocessing/feature extraction**:
|
||||
- PCA (linear relationships)
|
||||
- TruncatedSVD (sparse data)
|
||||
- NMF (non-negative data)
|
||||
|
||||
**For visualization only**:
|
||||
- t-SNE (preserves local structure)
|
||||
- UMAP (preserves both local and global structure)
|
||||
|
||||
**Always**:
|
||||
- Standardize features before PCA
|
||||
- Use appropriate n_components (elbow plot, explained variance)
|
||||
- Don't use t-SNE for anything except visualization
|
||||
|
||||
### Common Pitfalls
|
||||
|
||||
1. **Not scaling data**: Most algorithms sensitive to scale
|
||||
2. **Using t-SNE for preprocessing**: Only for visualization!
|
||||
3. **Overfitting cluster count**: Too many clusters = overfitting noise
|
||||
4. **Ignoring outliers**: Can severely affect centroid-based methods
|
||||
5. **Wrong metric**: Euclidean assumes all features equally important
|
||||
6. **Not validating results**: Always check with multiple metrics and domain knowledge
|
||||
7. **PCA without standardization**: Components dominated by high-variance features
|
||||
@@ -1,219 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Complete classification pipeline with preprocessing, training, evaluation, and hyperparameter tuning.
|
||||
Demonstrates best practices for scikit-learn workflows.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
|
||||
import joblib
|
||||
|
||||
|
||||
def create_preprocessing_pipeline(numeric_features, categorical_features):
|
||||
"""
|
||||
Create preprocessing pipeline for mixed data types.
|
||||
|
||||
Args:
|
||||
numeric_features: List of numeric column names
|
||||
categorical_features: List of categorical column names
|
||||
|
||||
Returns:
|
||||
ColumnTransformer with appropriate preprocessing for each data type
|
||||
"""
|
||||
numeric_transformer = Pipeline(steps=[
|
||||
('imputer', SimpleImputer(strategy='median')),
|
||||
('scaler', StandardScaler())
|
||||
])
|
||||
|
||||
categorical_transformer = Pipeline(steps=[
|
||||
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=True))
|
||||
])
|
||||
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', numeric_transformer, numeric_features),
|
||||
('cat', categorical_transformer, categorical_features)
|
||||
])
|
||||
|
||||
return preprocessor
|
||||
|
||||
|
||||
def create_full_pipeline(preprocessor, classifier=None):
|
||||
"""
|
||||
Create complete ML pipeline with preprocessing and classification.
|
||||
|
||||
Args:
|
||||
preprocessor: Preprocessing ColumnTransformer
|
||||
classifier: Classifier instance (default: RandomForestClassifier)
|
||||
|
||||
Returns:
|
||||
Complete Pipeline
|
||||
"""
|
||||
if classifier is None:
|
||||
classifier = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
|
||||
|
||||
pipeline = Pipeline(steps=[
|
||||
('preprocessor', preprocessor),
|
||||
('classifier', classifier)
|
||||
])
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def evaluate_model(pipeline, X_train, y_train, X_test, y_test, cv=5):
|
||||
"""
|
||||
Evaluate model using cross-validation and test set.
|
||||
|
||||
Args:
|
||||
pipeline: Trained pipeline
|
||||
X_train, y_train: Training data
|
||||
X_test, y_test: Test data
|
||||
cv: Number of cross-validation folds
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation results
|
||||
"""
|
||||
# Cross-validation on training set
|
||||
cv_scores = cross_val_score(pipeline, X_train, y_train, cv=cv, scoring='accuracy')
|
||||
|
||||
# Test set evaluation
|
||||
y_pred = pipeline.predict(X_test)
|
||||
test_score = pipeline.score(X_test, y_test)
|
||||
|
||||
# Get probabilities if available
|
||||
try:
|
||||
y_proba = pipeline.predict_proba(X_test)
|
||||
if len(np.unique(y_test)) == 2:
|
||||
# Binary classification
|
||||
auc = roc_auc_score(y_test, y_proba[:, 1])
|
||||
else:
|
||||
# Multiclass
|
||||
auc = roc_auc_score(y_test, y_proba, multi_class='ovr')
|
||||
except:
|
||||
auc = None
|
||||
|
||||
results = {
|
||||
'cv_mean': cv_scores.mean(),
|
||||
'cv_std': cv_scores.std(),
|
||||
'test_score': test_score,
|
||||
'auc': auc,
|
||||
'classification_report': classification_report(y_test, y_pred),
|
||||
'confusion_matrix': confusion_matrix(y_test, y_pred)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def tune_hyperparameters(pipeline, X_train, y_train, param_grid, cv=5):
|
||||
"""
|
||||
Perform hyperparameter tuning using GridSearchCV.
|
||||
|
||||
Args:
|
||||
pipeline: Pipeline to tune
|
||||
X_train, y_train: Training data
|
||||
param_grid: Dictionary of parameters to search
|
||||
cv: Number of cross-validation folds
|
||||
|
||||
Returns:
|
||||
GridSearchCV object with best model
|
||||
"""
|
||||
grid_search = GridSearchCV(
|
||||
pipeline,
|
||||
param_grid,
|
||||
cv=cv,
|
||||
scoring='f1_weighted',
|
||||
n_jobs=-1,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
print(f"Best parameters: {grid_search.best_params_}")
|
||||
print(f"Best CV score: {grid_search.best_score_:.3f}")
|
||||
|
||||
return grid_search
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Example usage of the classification pipeline.
|
||||
"""
|
||||
# Load your data here
|
||||
# X, y = load_data()
|
||||
|
||||
# Example with synthetic data
|
||||
from sklearn.datasets import make_classification
|
||||
X, y = make_classification(
|
||||
n_samples=1000,
|
||||
n_features=20,
|
||||
n_informative=15,
|
||||
n_redundant=5,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
# Convert to DataFrame for demonstration
|
||||
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
|
||||
X = pd.DataFrame(X, columns=feature_names)
|
||||
|
||||
# Split features into numeric and categorical (all numeric in this example)
|
||||
numeric_features = feature_names
|
||||
categorical_features = []
|
||||
|
||||
# Split data (use stratify for imbalanced classes)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# Create preprocessing pipeline
|
||||
preprocessor = create_preprocessing_pipeline(numeric_features, categorical_features)
|
||||
|
||||
# Create full pipeline
|
||||
pipeline = create_full_pipeline(preprocessor)
|
||||
|
||||
# Train model
|
||||
print("Training model...")
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# Evaluate model
|
||||
print("\nEvaluating model...")
|
||||
results = evaluate_model(pipeline, X_train, y_train, X_test, y_test)
|
||||
|
||||
print(f"CV Accuracy: {results['cv_mean']:.3f} (+/- {results['cv_std']:.3f})")
|
||||
print(f"Test Accuracy: {results['test_score']:.3f}")
|
||||
if results['auc']:
|
||||
print(f"ROC-AUC: {results['auc']:.3f}")
|
||||
print("\nClassification Report:")
|
||||
print(results['classification_report'])
|
||||
|
||||
# Hyperparameter tuning (optional)
|
||||
print("\nTuning hyperparameters...")
|
||||
param_grid = {
|
||||
'classifier__n_estimators': [100, 200],
|
||||
'classifier__max_depth': [10, 20, None],
|
||||
'classifier__min_samples_split': [2, 5]
|
||||
}
|
||||
|
||||
grid_search = tune_hyperparameters(pipeline, X_train, y_train, param_grid)
|
||||
|
||||
# Evaluate best model
|
||||
print("\nEvaluating tuned model...")
|
||||
best_pipeline = grid_search.best_estimator_
|
||||
y_pred = best_pipeline.predict(X_test)
|
||||
print(classification_report(y_test, y_pred))
|
||||
|
||||
# Save model
|
||||
print("\nSaving model...")
|
||||
joblib.dump(best_pipeline, 'best_model.pkl')
|
||||
print("Model saved as 'best_model.pkl'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,291 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clustering analysis script with multiple algorithms and evaluation.
|
||||
Demonstrates k-means, DBSCAN, and hierarchical clustering with visualization.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
|
||||
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
|
||||
from sklearn.decomposition import PCA
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
def scale_data(X):
|
||||
"""
|
||||
Scale features using StandardScaler.
|
||||
ALWAYS scale data before clustering!
|
||||
|
||||
Args:
|
||||
X: Feature matrix
|
||||
|
||||
Returns:
|
||||
Scaled feature matrix and fitted scaler
|
||||
"""
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
return X_scaled, scaler
|
||||
|
||||
|
||||
def find_optimal_k(X_scaled, k_range=range(2, 11)):
|
||||
"""
|
||||
Find optimal number of clusters using elbow method and silhouette score.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
k_range: Range of k values to try
|
||||
|
||||
Returns:
|
||||
Dictionary with inertias and silhouette scores
|
||||
"""
|
||||
inertias = []
|
||||
silhouette_scores = []
|
||||
|
||||
for k in k_range:
|
||||
kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
|
||||
labels = kmeans.fit_predict(X_scaled)
|
||||
inertias.append(kmeans.inertia_)
|
||||
silhouette_scores.append(silhouette_score(X_scaled, labels))
|
||||
|
||||
return {
|
||||
'k_values': list(k_range),
|
||||
'inertias': inertias,
|
||||
'silhouette_scores': silhouette_scores
|
||||
}
|
||||
|
||||
|
||||
def plot_elbow_silhouette(results):
|
||||
"""
|
||||
Plot elbow method and silhouette scores.
|
||||
|
||||
Args:
|
||||
results: Dictionary from find_optimal_k
|
||||
"""
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# Elbow plot
|
||||
ax1.plot(results['k_values'], results['inertias'], 'bo-')
|
||||
ax1.set_xlabel('Number of clusters (k)')
|
||||
ax1.set_ylabel('Inertia')
|
||||
ax1.set_title('Elbow Method')
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Silhouette plot
|
||||
ax2.plot(results['k_values'], results['silhouette_scores'], 'ro-')
|
||||
ax2.set_xlabel('Number of clusters (k)')
|
||||
ax2.set_ylabel('Silhouette Score')
|
||||
ax2.set_title('Silhouette Score vs k')
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('elbow_silhouette.png', dpi=300, bbox_inches='tight')
|
||||
print("Saved elbow and silhouette plots to 'elbow_silhouette.png'")
|
||||
plt.close()
|
||||
|
||||
|
||||
def evaluate_clustering(X_scaled, labels, algorithm_name):
|
||||
"""
|
||||
Evaluate clustering using multiple metrics.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
labels: Cluster labels
|
||||
algorithm_name: Name of clustering algorithm
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation metrics
|
||||
"""
|
||||
# Filter out noise points for DBSCAN (-1 labels)
|
||||
mask = labels != -1
|
||||
X_filtered = X_scaled[mask]
|
||||
labels_filtered = labels[mask]
|
||||
|
||||
n_clusters = len(set(labels_filtered))
|
||||
n_noise = list(labels).count(-1)
|
||||
|
||||
results = {
|
||||
'algorithm': algorithm_name,
|
||||
'n_clusters': n_clusters,
|
||||
'n_noise': n_noise
|
||||
}
|
||||
|
||||
# Calculate metrics if we have valid clusters
|
||||
if n_clusters > 1:
|
||||
results['silhouette'] = silhouette_score(X_filtered, labels_filtered)
|
||||
results['davies_bouldin'] = davies_bouldin_score(X_filtered, labels_filtered)
|
||||
results['calinski_harabasz'] = calinski_harabasz_score(X_filtered, labels_filtered)
|
||||
else:
|
||||
results['silhouette'] = None
|
||||
results['davies_bouldin'] = None
|
||||
results['calinski_harabasz'] = None
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def perform_kmeans(X_scaled, n_clusters=3):
|
||||
"""
|
||||
Perform k-means clustering.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
n_clusters: Number of clusters
|
||||
|
||||
Returns:
|
||||
Fitted KMeans model and labels
|
||||
"""
|
||||
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
|
||||
labels = kmeans.fit_predict(X_scaled)
|
||||
return kmeans, labels
|
||||
|
||||
|
||||
def perform_dbscan(X_scaled, eps=0.5, min_samples=5):
|
||||
"""
|
||||
Perform DBSCAN clustering.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
eps: Maximum distance between neighbors
|
||||
min_samples: Minimum points to form dense region
|
||||
|
||||
Returns:
|
||||
Fitted DBSCAN model and labels
|
||||
"""
|
||||
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
|
||||
labels = dbscan.fit_predict(X_scaled)
|
||||
return dbscan, labels
|
||||
|
||||
|
||||
def perform_hierarchical(X_scaled, n_clusters=3, linkage='ward'):
|
||||
"""
|
||||
Perform hierarchical clustering.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
n_clusters: Number of clusters
|
||||
linkage: Linkage criterion ('ward', 'complete', 'average', 'single')
|
||||
|
||||
Returns:
|
||||
Fitted AgglomerativeClustering model and labels
|
||||
"""
|
||||
hierarchical = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage)
|
||||
labels = hierarchical.fit_predict(X_scaled)
|
||||
return hierarchical, labels
|
||||
|
||||
|
||||
def visualize_clusters_2d(X_scaled, labels, algorithm_name, method='pca'):
|
||||
"""
|
||||
Visualize clusters in 2D using PCA or t-SNE.
|
||||
|
||||
Args:
|
||||
X_scaled: Scaled feature matrix
|
||||
labels: Cluster labels
|
||||
algorithm_name: Name of algorithm for title
|
||||
method: 'pca' or 'tsne'
|
||||
"""
|
||||
# Reduce to 2D
|
||||
if method == 'pca':
|
||||
pca = PCA(n_components=2, random_state=42)
|
||||
X_2d = pca.fit_transform(X_scaled)
|
||||
variance = pca.explained_variance_ratio_
|
||||
xlabel = f'PC1 ({variance[0]:.1%} variance)'
|
||||
ylabel = f'PC2 ({variance[1]:.1%} variance)'
|
||||
else:
|
||||
from sklearn.manifold import TSNE
|
||||
# Use PCA first to speed up t-SNE
|
||||
pca = PCA(n_components=min(50, X_scaled.shape[1]), random_state=42)
|
||||
X_pca = pca.fit_transform(X_scaled)
|
||||
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
|
||||
X_2d = tsne.fit_transform(X_pca)
|
||||
xlabel = 't-SNE 1'
|
||||
ylabel = 't-SNE 2'
|
||||
|
||||
# Plot
|
||||
plt.figure(figsize=(10, 8))
|
||||
scatter = plt.scatter(X_2d[:, 0], X_2d[:, 1], c=labels, cmap='viridis', alpha=0.6, s=50)
|
||||
plt.colorbar(scatter, label='Cluster')
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(f'{algorithm_name} Clustering ({method.upper()})')
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
filename = f'{algorithm_name.lower().replace(" ", "_")}_{method}.png'
|
||||
plt.savefig(filename, dpi=300, bbox_inches='tight')
|
||||
print(f"Saved visualization to '{filename}'")
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Example clustering analysis workflow.
|
||||
"""
|
||||
# Load your data here
|
||||
# X = load_data()
|
||||
|
||||
# Example with synthetic data
|
||||
from sklearn.datasets import make_blobs
|
||||
X, y_true = make_blobs(
|
||||
n_samples=500,
|
||||
n_features=10,
|
||||
centers=4,
|
||||
cluster_std=1.0,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
print(f"Dataset shape: {X.shape}")
|
||||
|
||||
# Scale data (ALWAYS scale for clustering!)
|
||||
print("\nScaling data...")
|
||||
X_scaled, scaler = scale_data(X)
|
||||
|
||||
# Find optimal k
|
||||
print("\nFinding optimal number of clusters...")
|
||||
results = find_optimal_k(X_scaled)
|
||||
plot_elbow_silhouette(results)
|
||||
|
||||
# Based on elbow/silhouette, choose optimal k
|
||||
optimal_k = 4 # Adjust based on plots
|
||||
|
||||
# Perform k-means
|
||||
print(f"\nPerforming k-means with k={optimal_k}...")
|
||||
kmeans, kmeans_labels = perform_kmeans(X_scaled, n_clusters=optimal_k)
|
||||
kmeans_results = evaluate_clustering(X_scaled, kmeans_labels, 'K-Means')
|
||||
|
||||
# Perform DBSCAN
|
||||
print("\nPerforming DBSCAN...")
|
||||
dbscan, dbscan_labels = perform_dbscan(X_scaled, eps=0.5, min_samples=5)
|
||||
dbscan_results = evaluate_clustering(X_scaled, dbscan_labels, 'DBSCAN')
|
||||
|
||||
# Perform hierarchical clustering
|
||||
print("\nPerforming hierarchical clustering...")
|
||||
hierarchical, hier_labels = perform_hierarchical(X_scaled, n_clusters=optimal_k)
|
||||
hier_results = evaluate_clustering(X_scaled, hier_labels, 'Hierarchical')
|
||||
|
||||
# Print results
|
||||
print("\n" + "="*60)
|
||||
print("CLUSTERING RESULTS")
|
||||
print("="*60)
|
||||
|
||||
for results in [kmeans_results, dbscan_results, hier_results]:
|
||||
print(f"\n{results['algorithm']}:")
|
||||
print(f" Clusters: {results['n_clusters']}")
|
||||
if results['n_noise'] > 0:
|
||||
print(f" Noise points: {results['n_noise']}")
|
||||
if results['silhouette']:
|
||||
print(f" Silhouette Score: {results['silhouette']:.3f}")
|
||||
print(f" Davies-Bouldin Index: {results['davies_bouldin']:.3f} (lower is better)")
|
||||
print(f" Calinski-Harabasz Index: {results['calinski_harabasz']:.1f} (higher is better)")
|
||||
|
||||
# Visualize clusters
|
||||
print("\nCreating visualizations...")
|
||||
visualize_clusters_2d(X_scaled, kmeans_labels, 'K-Means', method='pca')
|
||||
visualize_clusters_2d(X_scaled, dbscan_labels, 'DBSCAN', method='pca')
|
||||
visualize_clusters_2d(X_scaled, hier_labels, 'Hierarchical', method='pca')
|
||||
|
||||
print("\nClustering analysis complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,290 +0,0 @@
|
||||
---
|
||||
name: tooluniverse
|
||||
description: Use this skill when working with scientific research tools and workflows across bioinformatics, cheminformatics, genomics, structural biology, proteomics, and drug discovery. This skill provides access to 600+ scientific tools including machine learning models, datasets, APIs, and analysis packages. Use when searching for scientific tools, executing computational biology workflows, composing multi-step research pipelines, accessing databases like OpenTargets/PubChem/UniProt/PDB/ChEMBL, performing tool discovery for research tasks, or integrating scientific computational resources into LLM workflows.
|
||||
---
|
||||
|
||||
# ToolUniverse
|
||||
|
||||
## Overview
|
||||
|
||||
ToolUniverse is a unified ecosystem that enables AI agents to function as research scientists by providing standardized access to 600+ scientific resources. Use this skill to discover, execute, and compose scientific tools across multiple research domains including bioinformatics, cheminformatics, genomics, structural biology, proteomics, and drug discovery.
|
||||
|
||||
**Key Capabilities:**
|
||||
- Access 600+ scientific tools, models, datasets, and APIs
|
||||
- Discover tools using natural language, semantic search, or keywords
|
||||
- Execute tools through standardized AI-Tool Interaction Protocol
|
||||
- Compose multi-step workflows for complex research problems
|
||||
- Integration with Claude Desktop/Code via Model Context Protocol (MCP)
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use this skill when:
|
||||
- Searching for scientific tools by function or domain (e.g., "find protein structure prediction tools")
|
||||
- Executing computational biology workflows (e.g., disease target identification, drug discovery, genomics analysis)
|
||||
- Accessing scientific databases (OpenTargets, PubChem, UniProt, PDB, ChEMBL, KEGG, etc.)
|
||||
- Composing multi-step research pipelines (e.g., target discovery → structure prediction → virtual screening)
|
||||
- Working with bioinformatics, cheminformatics, or structural biology tasks
|
||||
- Analyzing gene expression, protein sequences, molecular structures, or clinical data
|
||||
- Performing literature searches, pathway enrichment, or variant annotation
|
||||
- Building automated scientific research workflows
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Setup
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
# Initialize and load tools
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools() # Loads 600+ scientific tools
|
||||
|
||||
# Discover tools
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "disease target associations",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
|
||||
# Execute a tool
|
||||
result = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": "EFO_0000537"} # Hypertension
|
||||
})
|
||||
```
|
||||
|
||||
### Model Context Protocol (MCP)
|
||||
For Claude Desktop/Code integration:
|
||||
```bash
|
||||
tooluniverse-smcp
|
||||
```
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### 1. Tool Discovery
|
||||
|
||||
Find relevant tools for your research task:
|
||||
|
||||
**Three discovery methods:**
|
||||
- `Tool_Finder` - Embedding-based semantic search (requires GPU)
|
||||
- `Tool_Finder_LLM` - LLM-based semantic search (no GPU required)
|
||||
- `Tool_Finder_Keyword` - Fast keyword search
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Search by natural language description
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_LLM",
|
||||
"arguments": {
|
||||
"description": "Find tools for RNA sequencing differential expression analysis",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
|
||||
# Review available tools
|
||||
for tool in tools:
|
||||
print(f"{tool['name']}: {tool['description']}")
|
||||
```
|
||||
|
||||
**See `references/tool-discovery.md` for:**
|
||||
- Detailed discovery methods and search strategies
|
||||
- Domain-specific keyword suggestions
|
||||
- Best practices for finding tools
|
||||
|
||||
### 2. Tool Execution
|
||||
|
||||
Execute individual tools through the standardized interface:
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Execute disease-target lookup
|
||||
targets = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": "EFO_0000616"} # Breast cancer
|
||||
})
|
||||
|
||||
# Get protein structure
|
||||
structure = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {"uniprot_id": "P12345"}
|
||||
})
|
||||
|
||||
# Calculate molecular properties
|
||||
properties = tu.run({
|
||||
"name": "RDKit_calculate_descriptors",
|
||||
"arguments": {"smiles": "CCO"} # Ethanol
|
||||
})
|
||||
```
|
||||
|
||||
**See `references/tool-execution.md` for:**
|
||||
- Real-world execution examples across domains
|
||||
- Tool parameter handling and validation
|
||||
- Result processing and error handling
|
||||
- Best practices for production use
|
||||
|
||||
### 3. Tool Composition and Workflows
|
||||
|
||||
Compose multiple tools for complex research workflows:
|
||||
|
||||
**Drug Discovery Example:**
|
||||
```python
|
||||
# 1. Find disease targets
|
||||
targets = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": "EFO_0000616"}
|
||||
})
|
||||
|
||||
# 2. Get protein structures
|
||||
structures = []
|
||||
for target in targets[:5]:
|
||||
structure = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {"uniprot_id": target['uniprot_id']}
|
||||
})
|
||||
structures.append(structure)
|
||||
|
||||
# 3. Screen compounds
|
||||
hits = []
|
||||
for structure in structures:
|
||||
compounds = tu.run({
|
||||
"name": "ZINC_virtual_screening",
|
||||
"arguments": {
|
||||
"structure": structure,
|
||||
"library": "lead-like",
|
||||
"top_n": 100
|
||||
}
|
||||
})
|
||||
hits.extend(compounds)
|
||||
|
||||
# 4. Evaluate drug-likeness
|
||||
drug_candidates = []
|
||||
for compound in hits:
|
||||
props = tu.run({
|
||||
"name": "RDKit_calculate_drug_properties",
|
||||
"arguments": {"smiles": compound['smiles']}
|
||||
})
|
||||
if props['lipinski_pass']:
|
||||
drug_candidates.append(compound)
|
||||
```
|
||||
|
||||
**See `references/tool-composition.md` for:**
|
||||
- Complete workflow examples (drug discovery, genomics, clinical)
|
||||
- Sequential and parallel tool composition patterns
|
||||
- Output processing hooks
|
||||
- Workflow best practices
|
||||
|
||||
## Scientific Domains
|
||||
|
||||
ToolUniverse supports 600+ tools across major scientific domains:
|
||||
|
||||
**Bioinformatics:**
|
||||
- Sequence analysis, alignment, BLAST
|
||||
- Gene expression (RNA-seq, DESeq2)
|
||||
- Pathway enrichment (KEGG, Reactome, GO)
|
||||
- Variant annotation (VEP, ClinVar)
|
||||
|
||||
**Cheminformatics:**
|
||||
- Molecular descriptors and fingerprints
|
||||
- Drug discovery and virtual screening
|
||||
- ADMET prediction and drug-likeness
|
||||
- Chemical databases (PubChem, ChEMBL, ZINC)
|
||||
|
||||
**Structural Biology:**
|
||||
- Protein structure prediction (AlphaFold)
|
||||
- Structure retrieval (PDB)
|
||||
- Binding site detection
|
||||
- Protein-protein interactions
|
||||
|
||||
**Proteomics:**
|
||||
- Mass spectrometry analysis
|
||||
- Protein databases (UniProt, STRING)
|
||||
- Post-translational modifications
|
||||
|
||||
**Genomics:**
|
||||
- Genome assembly and annotation
|
||||
- Copy number variation
|
||||
- Clinical genomics workflows
|
||||
|
||||
**Medical/Clinical:**
|
||||
- Disease databases (OpenTargets, OMIM)
|
||||
- Clinical trials and FDA data
|
||||
- Variant classification
|
||||
|
||||
**See `references/domains.md` for:**
|
||||
- Complete domain categorization
|
||||
- Tool examples by discipline
|
||||
- Cross-domain applications
|
||||
- Search strategies by domain
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
This skill includes comprehensive reference files that provide detailed information for specific aspects:
|
||||
|
||||
- **`references/installation.md`** - Installation, setup, MCP configuration, platform integration
|
||||
- **`references/tool-discovery.md`** - Discovery methods, search strategies, listing tools
|
||||
- **`references/tool-execution.md`** - Execution patterns, real-world examples, error handling
|
||||
- **`references/tool-composition.md`** - Workflow composition, complex pipelines, parallel execution
|
||||
- **`references/domains.md`** - Tool categorization by domain, use case examples
|
||||
- **`references/api_reference.md`** - Python API documentation, hooks, protocols
|
||||
|
||||
**Workflow:** When helping with specific tasks, reference the appropriate file for detailed instructions. For example, if searching for tools, consult `references/tool-discovery.md` for search strategies.
|
||||
|
||||
## Example Scripts
|
||||
|
||||
Two executable example scripts demonstrate common use cases:
|
||||
|
||||
**`scripts/example_tool_search.py`** - Demonstrates all three discovery methods:
|
||||
- Keyword-based search
|
||||
- LLM-based search
|
||||
- Domain-specific searches
|
||||
- Getting detailed tool information
|
||||
|
||||
**`scripts/example_workflow.py`** - Complete workflow examples:
|
||||
- Drug discovery pipeline (disease → targets → structures → screening → candidates)
|
||||
- Genomics analysis (expression data → differential analysis → pathways)
|
||||
|
||||
Run examples to understand typical usage patterns and workflow composition.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Tool Discovery:**
|
||||
- Start with broad searches, then refine based on results
|
||||
- Use `Tool_Finder_Keyword` for fast searches with known terms
|
||||
- Use `Tool_Finder_LLM` for complex semantic queries
|
||||
- Set appropriate `limit` parameter (default: 10)
|
||||
|
||||
2. **Tool Execution:**
|
||||
- Always verify tool parameters before execution
|
||||
- Implement error handling for production workflows
|
||||
- Validate input data formats (SMILES, UniProt IDs, gene symbols)
|
||||
- Check result types and structures
|
||||
|
||||
3. **Workflow Composition:**
|
||||
- Test each step individually before composing full workflows
|
||||
- Implement checkpointing for long workflows
|
||||
- Consider rate limits for remote APIs
|
||||
- Use parallel execution when tools are independent
|
||||
|
||||
4. **Integration:**
|
||||
- Initialize ToolUniverse once and reuse the instance
|
||||
- Call `load_tools()` once at startup
|
||||
- Cache frequently used tool information
|
||||
- Enable logging for debugging
|
||||
|
||||
## Key Terminology
|
||||
|
||||
- **Tool**: A scientific resource (model, dataset, API, package) accessible through ToolUniverse
|
||||
- **Tool Discovery**: Finding relevant tools using search methods (Finder, LLM, Keyword)
|
||||
- **Tool Execution**: Running a tool with specific arguments via `tu.run()`
|
||||
- **Tool Composition**: Chaining multiple tools for multi-step workflows
|
||||
- **MCP**: Model Context Protocol for integration with Claude Desktop/Code
|
||||
- **AI-Tool Interaction Protocol**: Standardized interface for LLM-tool communication
|
||||
|
||||
## Resources
|
||||
|
||||
- **Official Website**: https://aiscientist.tools
|
||||
- **GitHub**: https://github.com/mims-harvard/ToolUniverse
|
||||
- **Documentation**: https://zitniklab.hms.harvard.edu/ToolUniverse/
|
||||
- **Installation**: `uv pip install tooluniverse`
|
||||
- **MCP Server**: `tooluniverse-smcp`
|
||||
@@ -1,298 +0,0 @@
|
||||
# ToolUniverse Python API Reference
|
||||
|
||||
## Core Classes
|
||||
|
||||
### ToolUniverse
|
||||
|
||||
Main class for interacting with the ToolUniverse ecosystem.
|
||||
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
tu = ToolUniverse()
|
||||
```
|
||||
|
||||
#### Methods
|
||||
|
||||
##### `load_tools()`
|
||||
Load all available tools into the ToolUniverse instance.
|
||||
|
||||
```python
|
||||
tu.load_tools()
|
||||
```
|
||||
|
||||
**Returns:** None
|
||||
|
||||
**Side effects:** Loads 600+ tools into memory for discovery and execution.
|
||||
|
||||
---
|
||||
|
||||
##### `run(tool_config)`
|
||||
Execute a tool with specified arguments.
|
||||
|
||||
**Parameters:**
|
||||
- `tool_config` (dict): Configuration dictionary with keys:
|
||||
- `name` (str): Tool name to execute
|
||||
- `arguments` (dict): Tool-specific arguments
|
||||
|
||||
**Returns:** Tool-specific output (dict, list, str, or other types)
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
result = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {
|
||||
"efoId": "EFO_0000537"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
##### `list_tools(limit=None)`
|
||||
List all available tools or a subset.
|
||||
|
||||
**Parameters:**
|
||||
- `limit` (int, optional): Maximum number of tools to return. If None, returns all tools.
|
||||
|
||||
**Returns:** List of tool dictionaries
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# List all tools
|
||||
all_tools = tu.list_tools()
|
||||
|
||||
# List first 20 tools
|
||||
tools = tu.list_tools(limit=20)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
##### `get_tool_info(tool_name)`
|
||||
Get detailed information about a specific tool.
|
||||
|
||||
**Parameters:**
|
||||
- `tool_name` (str): Name of the tool
|
||||
|
||||
**Returns:** Dictionary containing tool metadata, parameters, and documentation
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
info = tu.get_tool_info("AlphaFold_get_structure")
|
||||
print(info['description'])
|
||||
print(info['parameters'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Built-in Discovery Tools
|
||||
|
||||
These are special tools that help find other tools in the ecosystem.
|
||||
|
||||
### Tool_Finder
|
||||
|
||||
Embedding-based semantic search for tools. Requires GPU.
|
||||
|
||||
```python
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder",
|
||||
"arguments": {
|
||||
"description": "protein structure prediction",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `description` (str): Natural language description of desired functionality
|
||||
- `limit` (int): Maximum number of tools to return
|
||||
|
||||
**Returns:** List of relevant tools with similarity scores
|
||||
|
||||
---
|
||||
|
||||
### Tool_Finder_LLM
|
||||
|
||||
LLM-based semantic search for tools. No GPU required.
|
||||
|
||||
```python
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_LLM",
|
||||
"arguments": {
|
||||
"description": "Find tools for RNA sequencing analysis",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `description` (str): Natural language query
|
||||
- `limit` (int): Maximum number of tools to return
|
||||
|
||||
**Returns:** List of relevant tools
|
||||
|
||||
---
|
||||
|
||||
### Tool_Finder_Keyword
|
||||
|
||||
Fast keyword-based search through tool names and descriptions.
|
||||
|
||||
```python
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "pathway enrichment",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `description` (str): Keywords to search for
|
||||
- `limit` (int): Maximum number of tools to return
|
||||
|
||||
**Returns:** List of matching tools
|
||||
|
||||
---
|
||||
|
||||
## Tool Output Hooks
|
||||
|
||||
Post-processing hooks for tool results.
|
||||
|
||||
### Summarization Hook
|
||||
```python
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
},
|
||||
hooks={
|
||||
"summarize": {
|
||||
"format": "brief" # or "detailed"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### File Saving Hook
|
||||
```python
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
},
|
||||
hooks={
|
||||
"save_to_file": {
|
||||
"filename": "output.json",
|
||||
"format": "json" # or "csv", "txt"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Context Protocol (MCP)
|
||||
|
||||
### Starting MCP Server
|
||||
|
||||
Command-line interface:
|
||||
```bash
|
||||
tooluniverse-smcp
|
||||
```
|
||||
|
||||
This launches an MCP server that exposes all ToolUniverse tools through the Model Context Protocol.
|
||||
|
||||
**Configuration:**
|
||||
- Default port: Automatically assigned
|
||||
- Protocol: MCP standard
|
||||
- Authentication: None required for local use
|
||||
|
||||
---
|
||||
|
||||
## Integration Modules
|
||||
|
||||
### OpenRouter Integration
|
||||
|
||||
Access 100+ LLMs through OpenRouter API:
|
||||
|
||||
```python
|
||||
from tooluniverse import OpenRouterClient
|
||||
|
||||
client = OpenRouterClient(api_key="your_key")
|
||||
response = client.chat("Analyze this protein sequence", model="anthropic/claude-3-5-sonnet")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## AI-Tool Interaction Protocol
|
||||
|
||||
ToolUniverse uses a standardized protocol for LLM-tool communication:
|
||||
|
||||
**Request Format:**
|
||||
```json
|
||||
{
|
||||
"name": "tool_name",
|
||||
"arguments": {
|
||||
"param1": "value1",
|
||||
"param2": "value2"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response Format:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"data": { ... },
|
||||
"metadata": {
|
||||
"execution_time": 1.23,
|
||||
"tool_version": "1.0.0"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
try:
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
})
|
||||
except ToolNotFoundError as e:
|
||||
print(f"Tool not found: {e}")
|
||||
except InvalidArgumentError as e:
|
||||
print(f"Invalid arguments: {e}")
|
||||
except ToolExecutionError as e:
|
||||
print(f"Execution failed: {e}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Type Hints
|
||||
|
||||
```python
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
def run_tool(
|
||||
tu: ToolUniverse,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute a tool with type-safe arguments."""
|
||||
return tu.run({
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Initialize Once**: Create a single ToolUniverse instance and reuse it
|
||||
2. **Load Tools Early**: Call `load_tools()` once at startup
|
||||
3. **Cache Tool Info**: Store frequently used tool information
|
||||
4. **Error Handling**: Always wrap tool execution in try-except blocks
|
||||
5. **Type Validation**: Validate argument types before execution
|
||||
6. **Resource Management**: Consider rate limits for remote APIs
|
||||
7. **Logging**: Enable logging for production environments
|
||||
@@ -1,272 +0,0 @@
|
||||
# ToolUniverse Tool Domains and Categories
|
||||
|
||||
## Overview
|
||||
|
||||
ToolUniverse integrates 600+ scientific tools across multiple research domains. This document categorizes tools by scientific discipline and use case.
|
||||
|
||||
## Major Scientific Domains
|
||||
|
||||
### Bioinformatics
|
||||
|
||||
**Sequence Analysis:**
|
||||
- Sequence alignment and comparison
|
||||
- Multiple sequence alignment (MSA)
|
||||
- BLAST and homology searches
|
||||
- Motif finding and pattern matching
|
||||
|
||||
**Genomics:**
|
||||
- Gene expression analysis
|
||||
- RNA-seq data processing
|
||||
- Variant calling and annotation
|
||||
- Genome assembly and annotation
|
||||
- Copy number variation analysis
|
||||
|
||||
**Functional Analysis:**
|
||||
- Gene Ontology (GO) enrichment
|
||||
- Pathway analysis (KEGG, Reactome)
|
||||
- Gene set enrichment analysis (GSEA)
|
||||
- Protein domain analysis
|
||||
|
||||
**Example Tools:**
|
||||
- GEO data download and analysis
|
||||
- DESeq2 differential expression
|
||||
- KEGG pathway enrichment
|
||||
- UniProt sequence retrieval
|
||||
- VEP variant annotation
|
||||
|
||||
### Cheminformatics
|
||||
|
||||
**Molecular Descriptors:**
|
||||
- Chemical property calculation
|
||||
- Molecular fingerprints
|
||||
- SMILES/InChI conversion
|
||||
- 3D conformer generation
|
||||
|
||||
**Drug Discovery:**
|
||||
- Virtual screening
|
||||
- Molecular docking
|
||||
- ADMET prediction
|
||||
- Drug-likeness assessment (Lipinski's Rule of Five)
|
||||
- Toxicity prediction
|
||||
|
||||
**Chemical Databases:**
|
||||
- PubChem compound search
|
||||
- ChEMBL bioactivity data
|
||||
- ZINC compound libraries
|
||||
- DrugBank drug information
|
||||
|
||||
**Example Tools:**
|
||||
- RDKit molecular descriptors
|
||||
- AutoDock molecular docking
|
||||
- ZINC library screening
|
||||
- ChEMBL target-compound associations
|
||||
|
||||
### Structural Biology
|
||||
|
||||
**Protein Structure:**
|
||||
- AlphaFold structure prediction
|
||||
- PDB structure retrieval
|
||||
- Structure alignment and comparison
|
||||
- Binding site prediction
|
||||
- Protein-protein interaction prediction
|
||||
|
||||
**Structure Analysis:**
|
||||
- Secondary structure prediction
|
||||
- Solvent accessibility calculation
|
||||
- Structure quality assessment
|
||||
- Ramachandran plot analysis
|
||||
|
||||
**Example Tools:**
|
||||
- AlphaFold structure prediction
|
||||
- PDB structure download
|
||||
- Fpocket binding site detection
|
||||
- DSSP secondary structure assignment
|
||||
|
||||
### Proteomics
|
||||
|
||||
**Protein Analysis:**
|
||||
- Mass spectrometry data analysis
|
||||
- Protein identification
|
||||
- Post-translational modification analysis
|
||||
- Protein quantification
|
||||
|
||||
**Protein Databases:**
|
||||
- UniProt protein information
|
||||
- STRING protein interactions
|
||||
- IntAct interaction databases
|
||||
|
||||
**Example Tools:**
|
||||
- UniProt data retrieval
|
||||
- STRING interaction networks
|
||||
- Mass spec peak analysis
|
||||
|
||||
### Machine Learning
|
||||
|
||||
**Model Types:**
|
||||
- Classification models
|
||||
- Regression models
|
||||
- Clustering algorithms
|
||||
- Neural networks
|
||||
- Deep learning models
|
||||
|
||||
**Applications:**
|
||||
- Predictive modeling
|
||||
- Feature selection
|
||||
- Dimensionality reduction
|
||||
- Pattern recognition
|
||||
- Biomarker discovery
|
||||
|
||||
**Example Tools:**
|
||||
- Scikit-learn models
|
||||
- TensorFlow/PyTorch models
|
||||
- XGBoost predictors
|
||||
- Random forest classifiers
|
||||
|
||||
### Medical/Clinical
|
||||
|
||||
**Disease Databases:**
|
||||
- OpenTargets disease-target associations
|
||||
- OMIM genetic disorders
|
||||
- ClinVar pathogenic variants
|
||||
- DisGeNET disease-gene associations
|
||||
|
||||
**Clinical Data:**
|
||||
- Electronic health records analysis
|
||||
- Clinical trial data
|
||||
- Diagnostic tools
|
||||
- Treatment recommendations
|
||||
|
||||
**Example Tools:**
|
||||
- OpenTargets disease queries
|
||||
- ClinVar variant classification
|
||||
- OMIM disease lookup
|
||||
- FDA drug approval data
|
||||
|
||||
### Neuroscience
|
||||
|
||||
**Brain Imaging:**
|
||||
- fMRI data analysis
|
||||
- Brain atlas mapping
|
||||
- Connectivity analysis
|
||||
- Neuroimaging pipelines
|
||||
|
||||
**Neural Data:**
|
||||
- Electrophysiology analysis
|
||||
- Spike train analysis
|
||||
- Neural network simulation
|
||||
|
||||
### Image Processing
|
||||
|
||||
**Biomedical Imaging:**
|
||||
- Microscopy image analysis
|
||||
- Cell segmentation
|
||||
- Object detection
|
||||
- Image enhancement
|
||||
- Feature extraction
|
||||
|
||||
**Image Analysis:**
|
||||
- ImageJ/Fiji tools
|
||||
- CellProfiler pipelines
|
||||
- Deep learning segmentation
|
||||
|
||||
### Systems Biology
|
||||
|
||||
**Network Analysis:**
|
||||
- Biological network construction
|
||||
- Network topology analysis
|
||||
- Module identification
|
||||
- Hub gene identification
|
||||
|
||||
**Modeling:**
|
||||
- Systems biology models
|
||||
- Metabolic network modeling
|
||||
- Signaling pathway simulation
|
||||
|
||||
## Tool Categories by Use Case
|
||||
|
||||
### Literature and Knowledge
|
||||
|
||||
**Literature Search:**
|
||||
- PubMed article search
|
||||
- Article summarization
|
||||
- Citation analysis
|
||||
- Knowledge extraction
|
||||
|
||||
**Knowledge Bases:**
|
||||
- Ontology queries (GO, DO, HPO)
|
||||
- Database cross-referencing
|
||||
- Entity recognition
|
||||
|
||||
### Data Access
|
||||
|
||||
**Public Repositories:**
|
||||
- GEO (Gene Expression Omnibus)
|
||||
- SRA (Sequence Read Archive)
|
||||
- PDB (Protein Data Bank)
|
||||
- ChEMBL (Bioactivity database)
|
||||
|
||||
**API Access:**
|
||||
- RESTful API clients
|
||||
- Database query tools
|
||||
- Batch data retrieval
|
||||
|
||||
### Visualization
|
||||
|
||||
**Plot Generation:**
|
||||
- Heatmaps
|
||||
- Volcano plots
|
||||
- Manhattan plots
|
||||
- Network graphs
|
||||
- Molecular structures
|
||||
|
||||
### Utilities
|
||||
|
||||
**Data Processing:**
|
||||
- Format conversion
|
||||
- Data normalization
|
||||
- Statistical analysis
|
||||
- Quality control
|
||||
|
||||
**Workflow Management:**
|
||||
- Pipeline construction
|
||||
- Task orchestration
|
||||
- Result aggregation
|
||||
|
||||
## Finding Tools by Domain
|
||||
|
||||
Use domain-specific keywords with Tool_Finder:
|
||||
|
||||
```python
|
||||
# Bioinformatics
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {"description": "RNA-seq genomics", "limit": 10}
|
||||
})
|
||||
|
||||
# Cheminformatics
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {"description": "molecular docking SMILES", "limit": 10}
|
||||
})
|
||||
|
||||
# Structural biology
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {"description": "protein structure PDB", "limit": 10}
|
||||
})
|
||||
|
||||
# Clinical
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {"description": "disease clinical variants", "limit": 10}
|
||||
})
|
||||
```
|
||||
|
||||
## Cross-Domain Applications
|
||||
|
||||
Many scientific problems require tools from multiple domains:
|
||||
|
||||
- **Precision Medicine**: Genomics + Clinical + Proteomics
|
||||
- **Drug Discovery**: Cheminformatics + Structural Biology + Machine Learning
|
||||
- **Cancer Research**: Genomics + Pathways + Literature
|
||||
- **Neurodegenerative Diseases**: Genomics + Proteomics + Imaging
|
||||
@@ -1,89 +0,0 @@
|
||||
# ToolUniverse Installation and Setup
|
||||
|
||||
## Installation
|
||||
|
||||
### Using uv (Recommended)
|
||||
```bash
|
||||
uv pip install tooluniverse
|
||||
```
|
||||
|
||||
### Using pip
|
||||
```bash
|
||||
pip install tooluniverse
|
||||
```
|
||||
|
||||
## Basic Setup
|
||||
|
||||
### Python SDK
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
# Initialize ToolUniverse
|
||||
tu = ToolUniverse()
|
||||
|
||||
# Load all available tools (600+ scientific tools)
|
||||
tu.load_tools()
|
||||
```
|
||||
|
||||
## Model Context Protocol (MCP) Setup
|
||||
|
||||
ToolUniverse provides native MCP support for integration with Claude Desktop, Claude Code, and other MCP-compatible systems.
|
||||
|
||||
### Starting MCP Server
|
||||
```bash
|
||||
tooluniverse-smcp
|
||||
```
|
||||
|
||||
This launches an MCP server that exposes ToolUniverse's 600+ tools through the Model Context Protocol.
|
||||
|
||||
### Claude Desktop Integration
|
||||
|
||||
Add to Claude Desktop configuration (~/.config/Claude/claude_desktop_config.json):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"tooluniverse": {
|
||||
"command": "tooluniverse-smcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Claude Code Integration
|
||||
|
||||
ToolUniverse MCP server works natively with Claude Code through the MCP protocol.
|
||||
|
||||
## Integration with Other Platforms
|
||||
|
||||
### OpenRouter Integration
|
||||
ToolUniverse integrates with OpenRouter for access to 100+ LLMs through a single API:
|
||||
- GPT-5, Claude, Gemini
|
||||
- Qwen, Deepseek
|
||||
- Open-source models
|
||||
|
||||
### Supported LLM Platforms
|
||||
- Claude Desktop and Claude Code
|
||||
- Gemini CLI
|
||||
- Qwen Code
|
||||
- ChatGPT API
|
||||
- GPT Codex CLI
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.8+
|
||||
- For Tool_Finder (embedding-based search): GPU recommended
|
||||
- For Tool_Finder_LLM: No GPU required (uses LLM-based search)
|
||||
|
||||
## Verification
|
||||
|
||||
Test installation:
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
# List first 5 tools to verify setup
|
||||
tools = tu.list_tools(limit=5)
|
||||
print(f"Loaded {len(tools)} tools successfully")
|
||||
```
|
||||
@@ -1,249 +0,0 @@
|
||||
# Tool Composition and Workflows in ToolUniverse
|
||||
|
||||
## Overview
|
||||
|
||||
ToolUniverse enables chaining multiple tools together to create complex scientific workflows. Tools can be composed sequentially or in parallel to solve multi-step research problems.
|
||||
|
||||
## Sequential Tool Composition
|
||||
|
||||
Execute tools in sequence where each tool's output feeds into the next tool.
|
||||
|
||||
### Basic Pattern
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
# Step 1: Get disease-associated targets
|
||||
targets = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": "EFO_0000537"} # Hypertension
|
||||
})
|
||||
|
||||
# Step 2: For each target, get protein structure
|
||||
structures = []
|
||||
for target in targets[:5]: # First 5 targets
|
||||
structure = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {"uniprot_id": target['uniprot_id']}
|
||||
})
|
||||
structures.append(structure)
|
||||
|
||||
# Step 3: Analyze structures
|
||||
for structure in structures:
|
||||
analysis = tu.run({
|
||||
"name": "ProteinAnalysis_calculate_properties",
|
||||
"arguments": {"structure": structure}
|
||||
})
|
||||
```
|
||||
|
||||
## Complex Workflow Examples
|
||||
|
||||
### Drug Discovery Workflow
|
||||
|
||||
Complete workflow from disease to drug candidates:
|
||||
|
||||
```python
|
||||
# 1. Find disease-associated targets
|
||||
print("Finding disease targets...")
|
||||
targets = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": "EFO_0000616"} # Breast cancer
|
||||
})
|
||||
|
||||
# 2. Get target protein sequences
|
||||
print("Retrieving protein sequences...")
|
||||
sequences = []
|
||||
for target in targets[:10]:
|
||||
seq = tu.run({
|
||||
"name": "UniProt_get_sequence",
|
||||
"arguments": {"uniprot_id": target['uniprot_id']}
|
||||
})
|
||||
sequences.append(seq)
|
||||
|
||||
# 3. Predict protein structures
|
||||
print("Predicting structures...")
|
||||
structures = []
|
||||
for seq in sequences:
|
||||
structure = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {"sequence": seq}
|
||||
})
|
||||
structures.append(structure)
|
||||
|
||||
# 4. Find binding sites
|
||||
print("Identifying binding sites...")
|
||||
binding_sites = []
|
||||
for structure in structures:
|
||||
sites = tu.run({
|
||||
"name": "Fpocket_find_binding_sites",
|
||||
"arguments": {"structure": structure}
|
||||
})
|
||||
binding_sites.append(sites)
|
||||
|
||||
# 5. Screen compound libraries
|
||||
print("Screening compounds...")
|
||||
hits = []
|
||||
for site in binding_sites:
|
||||
compounds = tu.run({
|
||||
"name": "ZINC_virtual_screening",
|
||||
"arguments": {
|
||||
"binding_site": site,
|
||||
"library": "lead-like",
|
||||
"top_n": 100
|
||||
}
|
||||
})
|
||||
hits.extend(compounds)
|
||||
|
||||
# 6. Calculate drug-likeness
|
||||
print("Evaluating drug-likeness...")
|
||||
drug_candidates = []
|
||||
for compound in hits:
|
||||
properties = tu.run({
|
||||
"name": "RDKit_calculate_drug_properties",
|
||||
"arguments": {"smiles": compound['smiles']}
|
||||
})
|
||||
if properties['lipinski_pass']:
|
||||
drug_candidates.append(compound)
|
||||
|
||||
print(f"Found {len(drug_candidates)} drug candidates")
|
||||
```
|
||||
|
||||
### Genomics Analysis Workflow
|
||||
|
||||
```python
|
||||
# 1. Download gene expression data
|
||||
expression_data = tu.run({
|
||||
"name": "GEO_download_dataset",
|
||||
"arguments": {"geo_id": "GSE12345"}
|
||||
})
|
||||
|
||||
# 2. Perform differential expression analysis
|
||||
de_genes = tu.run({
|
||||
"name": "DESeq2_differential_expression",
|
||||
"arguments": {
|
||||
"data": expression_data,
|
||||
"condition1": "control",
|
||||
"condition2": "treated"
|
||||
}
|
||||
})
|
||||
|
||||
# 3. Pathway enrichment analysis
|
||||
pathways = tu.run({
|
||||
"name": "KEGG_pathway_enrichment",
|
||||
"arguments": {
|
||||
"gene_list": de_genes['significant_genes'],
|
||||
"organism": "hsa"
|
||||
}
|
||||
})
|
||||
|
||||
# 4. Find relevant literature
|
||||
papers = tu.run({
|
||||
"name": "PubMed_search",
|
||||
"arguments": {
|
||||
"query": f"{pathways[0]['pathway_name']} AND cancer",
|
||||
"max_results": 20
|
||||
}
|
||||
})
|
||||
|
||||
# 5. Summarize findings
|
||||
summary = tu.run({
|
||||
"name": "LLM_summarize",
|
||||
"arguments": {
|
||||
"text": papers,
|
||||
"focus": "therapeutic implications"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### Clinical Genomics Workflow
|
||||
|
||||
```python
|
||||
# 1. Load patient variants
|
||||
variants = tu.run({
|
||||
"name": "VCF_parse",
|
||||
"arguments": {"vcf_file": "patient_001.vcf"}
|
||||
})
|
||||
|
||||
# 2. Annotate variants
|
||||
annotated = tu.run({
|
||||
"name": "VEP_annotate_variants",
|
||||
"arguments": {"variants": variants}
|
||||
})
|
||||
|
||||
# 3. Filter pathogenic variants
|
||||
pathogenic = tu.run({
|
||||
"name": "ClinVar_filter_pathogenic",
|
||||
"arguments": {"variants": annotated}
|
||||
})
|
||||
|
||||
# 4. Find disease associations
|
||||
diseases = tu.run({
|
||||
"name": "OMIM_disease_lookup",
|
||||
"arguments": {"genes": pathogenic['affected_genes']}
|
||||
})
|
||||
|
||||
# 5. Generate clinical report
|
||||
report = tu.run({
|
||||
"name": "Report_generator",
|
||||
"arguments": {
|
||||
"variants": pathogenic,
|
||||
"diseases": diseases,
|
||||
"format": "clinical"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Parallel Tool Execution
|
||||
|
||||
Execute multiple tools simultaneously when they don't depend on each other:
|
||||
|
||||
```python
|
||||
import concurrent.futures
|
||||
|
||||
def run_tool(tu, tool_config):
|
||||
return tu.run(tool_config)
|
||||
|
||||
# Define parallel tasks
|
||||
tasks = [
|
||||
{"name": "PubMed_search", "arguments": {"query": "cancer", "max_results": 10}},
|
||||
{"name": "OpenTargets_get_diseases", "arguments": {"therapeutic_area": "oncology"}},
|
||||
{"name": "ChEMBL_search_compounds", "arguments": {"target": "EGFR"}}
|
||||
]
|
||||
|
||||
# Execute in parallel
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [executor.submit(run_tool, tu, task) for task in tasks]
|
||||
results = [future.result() for future in concurrent.futures.as_completed(futures)]
|
||||
```
|
||||
|
||||
## Output Processing Hooks
|
||||
|
||||
ToolUniverse supports post-processing hooks for:
|
||||
- Summarization
|
||||
- File saving
|
||||
- Data transformation
|
||||
- Visualization
|
||||
|
||||
```python
|
||||
# Example: Save results to file
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
},
|
||||
hooks={
|
||||
"save_to_file": {"filename": "results.json"},
|
||||
"summarize": {"format": "brief"}
|
||||
})
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Error Handling**: Implement try-except blocks for each tool in workflow
|
||||
2. **Data Validation**: Verify output from each step before passing to next tool
|
||||
3. **Checkpointing**: Save intermediate results for long workflows
|
||||
4. **Logging**: Track progress through complex workflows
|
||||
5. **Resource Management**: Consider rate limits and computational resources
|
||||
6. **Modularity**: Break complex workflows into reusable functions
|
||||
7. **Testing**: Test each step individually before composing full workflow
|
||||
@@ -1,126 +0,0 @@
|
||||
# Tool Discovery in ToolUniverse
|
||||
|
||||
## Overview
|
||||
|
||||
ToolUniverse provides multiple methods to discover and search through 600+ scientific tools using natural language, keywords, or embeddings.
|
||||
|
||||
## Discovery Methods
|
||||
|
||||
### 1. Tool_Finder (Embedding-Based Search)
|
||||
|
||||
Uses semantic embeddings to find relevant tools. **Requires GPU** for optimal performance.
|
||||
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
# Search by natural language description
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder",
|
||||
"arguments": {
|
||||
"description": "protein structure prediction",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
|
||||
print(tools)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Natural language queries
|
||||
- Semantic similarity search
|
||||
- When GPU is available
|
||||
|
||||
### 2. Tool_Finder_LLM (LLM-Based Search)
|
||||
|
||||
Alternative to embedding-based search that uses LLM reasoning. **No GPU required**.
|
||||
|
||||
```python
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_LLM",
|
||||
"arguments": {
|
||||
"description": "Find tools for analyzing gene expression data",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- When GPU is not available
|
||||
- Complex queries requiring reasoning
|
||||
- Semantic understanding needed
|
||||
|
||||
### 3. Tool_Finder_Keyword (Keyword Search)
|
||||
|
||||
Fast keyword-based search through tool names and descriptions.
|
||||
|
||||
```python
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "disease target associations",
|
||||
"limit": 10
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Fast searches
|
||||
- Known keywords
|
||||
- Exact term matching
|
||||
|
||||
## Listing Available Tools
|
||||
|
||||
### List All Tools
|
||||
```python
|
||||
all_tools = tu.list_tools()
|
||||
print(f"Total tools available: {len(all_tools)}")
|
||||
```
|
||||
|
||||
### List Tools with Limit
|
||||
```python
|
||||
tools = tu.list_tools(limit=20)
|
||||
for tool in tools:
|
||||
print(f"{tool['name']}: {tool['description']}")
|
||||
```
|
||||
|
||||
## Tool Information
|
||||
|
||||
### Get Tool Details
|
||||
```python
|
||||
# After finding a tool, inspect its details
|
||||
tool_info = tu.get_tool_info("OpenTargets_get_associated_targets_by_disease_efoId")
|
||||
print(tool_info)
|
||||
```
|
||||
|
||||
## Search Strategies
|
||||
|
||||
### By Domain
|
||||
Use domain-specific keywords:
|
||||
- Bioinformatics: "sequence alignment", "genomics", "RNA-seq"
|
||||
- Cheminformatics: "molecular dynamics", "drug design", "SMILES"
|
||||
- Machine Learning: "classification", "prediction", "neural network"
|
||||
- Structural Biology: "protein structure", "PDB", "crystallography"
|
||||
|
||||
### By Functionality
|
||||
Search by what you want to accomplish:
|
||||
- "Find disease-gene associations"
|
||||
- "Predict protein interactions"
|
||||
- "Analyze clinical trial data"
|
||||
- "Generate molecular descriptors"
|
||||
|
||||
### By Data Source
|
||||
Search for specific databases or APIs:
|
||||
- "OpenTargets", "PubChem", "UniProt"
|
||||
- "AlphaFold", "ChEMBL", "PDB"
|
||||
- "KEGG", "Reactome", "STRING"
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Start Broad**: Begin with general terms, then refine
|
||||
2. **Use Multiple Methods**: Try different discovery methods if results aren't satisfactory
|
||||
3. **Set Appropriate Limits**: Use `limit` parameter to control result size (default: 10)
|
||||
4. **Check Tool Descriptions**: Review returned tool descriptions to verify relevance
|
||||
5. **Iterate**: Refine search terms based on initial results
|
||||
@@ -1,177 +0,0 @@
|
||||
# Tool Execution in ToolUniverse
|
||||
|
||||
## Overview
|
||||
|
||||
Execute individual tools through ToolUniverse's standardized interface using the `run()` method.
|
||||
|
||||
## Basic Tool Execution
|
||||
|
||||
### Standard Pattern
|
||||
```python
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
# Execute a tool
|
||||
result = tu.run({
|
||||
"name": "tool_name_here",
|
||||
"arguments": {
|
||||
"param1": "value1",
|
||||
"param2": "value2"
|
||||
}
|
||||
})
|
||||
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Real-World Examples
|
||||
|
||||
### Example 1: Disease-Target Associations (OpenTargets)
|
||||
```python
|
||||
# Find targets associated with hypertension
|
||||
result = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {
|
||||
"efoId": "EFO_0000537" # Hypertension
|
||||
}
|
||||
})
|
||||
|
||||
print(f"Found {len(result)} targets associated with hypertension")
|
||||
```
|
||||
|
||||
### Example 2: Protein Structure Prediction
|
||||
```python
|
||||
# Get AlphaFold structure prediction
|
||||
result = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {
|
||||
"uniprot_id": "P12345"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### Example 3: Chemical Property Calculation
|
||||
```python
|
||||
# Calculate molecular descriptors
|
||||
result = tu.run({
|
||||
"name": "RDKit_calculate_descriptors",
|
||||
"arguments": {
|
||||
"smiles": "CCO" # Ethanol
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### Example 4: Gene Expression Analysis
|
||||
```python
|
||||
# Analyze differential gene expression
|
||||
result = tu.run({
|
||||
"name": "GeneExpression_differential_analysis",
|
||||
"arguments": {
|
||||
"dataset_id": "GSE12345",
|
||||
"condition1": "control",
|
||||
"condition2": "treatment"
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Tool Execution Workflow
|
||||
|
||||
### 1. Discover the Tool
|
||||
```python
|
||||
# Find relevant tools
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "pathway enrichment",
|
||||
"limit": 5
|
||||
}
|
||||
})
|
||||
|
||||
# Review available tools
|
||||
for tool in tools:
|
||||
print(f"Name: {tool['name']}")
|
||||
print(f"Description: {tool['description']}")
|
||||
print(f"Parameters: {tool['parameters']}")
|
||||
print("---")
|
||||
```
|
||||
|
||||
### 2. Check Tool Parameters
|
||||
```python
|
||||
# Get detailed tool information
|
||||
tool_info = tu.get_tool_info("KEGG_pathway_enrichment")
|
||||
print(tool_info['parameters'])
|
||||
```
|
||||
|
||||
### 3. Execute with Proper Arguments
|
||||
```python
|
||||
# Execute the tool
|
||||
result = tu.run({
|
||||
"name": "KEGG_pathway_enrichment",
|
||||
"arguments": {
|
||||
"gene_list": ["TP53", "BRCA1", "EGFR"],
|
||||
"organism": "hsa" # Homo sapiens
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Handling Tool Results
|
||||
|
||||
### Check Result Type
|
||||
```python
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
})
|
||||
|
||||
# Results can be various types
|
||||
if isinstance(result, dict):
|
||||
print("Dictionary result")
|
||||
elif isinstance(result, list):
|
||||
print(f"List with {len(result)} items")
|
||||
elif isinstance(result, str):
|
||||
print("String result")
|
||||
```
|
||||
|
||||
### Process Results
|
||||
```python
|
||||
# Example: Processing multiple results
|
||||
results = tu.run({
|
||||
"name": "PubMed_search",
|
||||
"arguments": {
|
||||
"query": "cancer immunotherapy",
|
||||
"max_results": 10
|
||||
}
|
||||
})
|
||||
|
||||
for idx, paper in enumerate(results, 1):
|
||||
print(f"{idx}. {paper['title']}")
|
||||
print(f" PMID: {paper['pmid']}")
|
||||
print(f" Authors: {', '.join(paper['authors'][:3])}")
|
||||
print()
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
try:
|
||||
result = tu.run({
|
||||
"name": "some_tool",
|
||||
"arguments": {"param": "value"}
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Tool execution failed: {e}")
|
||||
# Check if tool exists
|
||||
# Verify parameter names and types
|
||||
# Review tool documentation
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Verify Tool Parameters**: Always check required parameters before execution
|
||||
2. **Start Simple**: Test with simple cases before complex workflows
|
||||
3. **Handle Results Appropriately**: Check result type and structure
|
||||
4. **Error Recovery**: Implement try-except blocks for production code
|
||||
5. **Documentation**: Review tool descriptions for parameter requirements and output formats
|
||||
6. **Rate Limiting**: Be aware of API rate limits for remote tools
|
||||
7. **Data Validation**: Validate input data format (e.g., SMILES, UniProt IDs, gene symbols)
|
||||
@@ -1,91 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example script demonstrating tool discovery in ToolUniverse.
|
||||
|
||||
This script shows how to search for tools using different methods:
|
||||
- Embedding-based search (Tool_Finder)
|
||||
- LLM-based search (Tool_Finder_LLM)
|
||||
- Keyword-based search (Tool_Finder_Keyword)
|
||||
"""
|
||||
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize ToolUniverse
|
||||
print("Initializing ToolUniverse...")
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
print(f"Loaded {len(tu.list_tools())} tools\n")
|
||||
|
||||
# Example 1: Keyword-based search (fastest)
|
||||
print("=" * 60)
|
||||
print("Example 1: Keyword Search for Disease-Target Tools")
|
||||
print("=" * 60)
|
||||
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "disease target associations",
|
||||
"limit": 5
|
||||
}
|
||||
})
|
||||
|
||||
print(f"Found {len(tools)} tools:")
|
||||
for idx, tool in enumerate(tools, 1):
|
||||
print(f"\n{idx}. {tool['name']}")
|
||||
print(f" Description: {tool['description']}")
|
||||
|
||||
# Example 2: LLM-based search (no GPU required)
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 2: LLM Search for Protein Structure Tools")
|
||||
print("=" * 60)
|
||||
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_LLM",
|
||||
"arguments": {
|
||||
"description": "Find tools for predicting protein structures from sequences",
|
||||
"limit": 5
|
||||
}
|
||||
})
|
||||
|
||||
print(f"Found {len(tools)} tools:")
|
||||
for idx, tool in enumerate(tools, 1):
|
||||
print(f"\n{idx}. {tool['name']}")
|
||||
print(f" Description: {tool['description']}")
|
||||
|
||||
# Example 3: Search by specific domain
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 3: Search for Cheminformatics Tools")
|
||||
print("=" * 60)
|
||||
|
||||
tools = tu.run({
|
||||
"name": "Tool_Finder_Keyword",
|
||||
"arguments": {
|
||||
"description": "molecular docking SMILES compound",
|
||||
"limit": 5
|
||||
}
|
||||
})
|
||||
|
||||
print(f"Found {len(tools)} tools:")
|
||||
for idx, tool in enumerate(tools, 1):
|
||||
print(f"\n{idx}. {tool['name']}")
|
||||
print(f" Description: {tool['description']}")
|
||||
|
||||
# Example 4: Get detailed tool information
|
||||
print("\n" + "=" * 60)
|
||||
print("Example 4: Get Tool Details")
|
||||
print("=" * 60)
|
||||
|
||||
if tools:
|
||||
tool_name = tools[0]['name']
|
||||
print(f"Getting details for: {tool_name}")
|
||||
|
||||
tool_info = tu.get_tool_info(tool_name)
|
||||
print(f"\nTool: {tool_info['name']}")
|
||||
print(f"Description: {tool_info['description']}")
|
||||
print(f"Parameters: {tool_info.get('parameters', 'No parameters listed')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,219 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example workflow demonstrating tool composition in ToolUniverse.
|
||||
|
||||
This script shows a complete drug discovery workflow:
|
||||
1. Find disease-associated targets
|
||||
2. Retrieve protein sequences
|
||||
3. Get structure predictions
|
||||
4. Screen compound libraries
|
||||
5. Calculate drug-likeness properties
|
||||
"""
|
||||
|
||||
from tooluniverse import ToolUniverse
|
||||
|
||||
|
||||
def drug_discovery_workflow(disease_efo_id: str, max_targets: int = 3):
|
||||
"""
|
||||
Execute a drug discovery workflow for a given disease.
|
||||
|
||||
Args:
|
||||
disease_efo_id: EFO ID for the disease (e.g., "EFO_0000537" for hypertension)
|
||||
max_targets: Maximum number of targets to process
|
||||
"""
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
print("=" * 70)
|
||||
print("DRUG DISCOVERY WORKFLOW")
|
||||
print("=" * 70)
|
||||
|
||||
# Step 1: Find disease-associated targets
|
||||
print(f"\nStep 1: Finding targets for disease {disease_efo_id}...")
|
||||
targets = tu.run({
|
||||
"name": "OpenTargets_get_associated_targets_by_disease_efoId",
|
||||
"arguments": {"efoId": disease_efo_id}
|
||||
})
|
||||
print(f"✓ Found {len(targets)} disease-associated targets")
|
||||
|
||||
# Process top targets
|
||||
top_targets = targets[:max_targets]
|
||||
print(f" Processing top {len(top_targets)} targets:")
|
||||
for idx, target in enumerate(top_targets, 1):
|
||||
print(f" {idx}. {target.get('target_name', 'Unknown')} ({target.get('uniprot_id', 'N/A')})")
|
||||
|
||||
# Step 2: Get protein sequences
|
||||
print(f"\nStep 2: Retrieving protein sequences...")
|
||||
sequences = []
|
||||
for target in top_targets:
|
||||
try:
|
||||
seq = tu.run({
|
||||
"name": "UniProt_get_sequence",
|
||||
"arguments": {"uniprot_id": target['uniprot_id']}
|
||||
})
|
||||
sequences.append({
|
||||
"target": target,
|
||||
"sequence": seq
|
||||
})
|
||||
print(f" ✓ Retrieved sequence for {target.get('target_name', 'Unknown')}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to get sequence: {e}")
|
||||
|
||||
# Step 3: Predict protein structures
|
||||
print(f"\nStep 3: Predicting protein structures...")
|
||||
structures = []
|
||||
for seq_data in sequences:
|
||||
try:
|
||||
structure = tu.run({
|
||||
"name": "AlphaFold_get_structure",
|
||||
"arguments": {"uniprot_id": seq_data['target']['uniprot_id']}
|
||||
})
|
||||
structures.append({
|
||||
"target": seq_data['target'],
|
||||
"structure": structure
|
||||
})
|
||||
print(f" ✓ Predicted structure for {seq_data['target'].get('target_name', 'Unknown')}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to predict structure: {e}")
|
||||
|
||||
# Step 4: Find binding sites
|
||||
print(f"\nStep 4: Identifying binding sites...")
|
||||
binding_sites = []
|
||||
for struct_data in structures:
|
||||
try:
|
||||
sites = tu.run({
|
||||
"name": "Fpocket_find_binding_sites",
|
||||
"arguments": {"structure": struct_data['structure']}
|
||||
})
|
||||
binding_sites.append({
|
||||
"target": struct_data['target'],
|
||||
"sites": sites
|
||||
})
|
||||
print(f" ✓ Found {len(sites)} binding sites for {struct_data['target'].get('target_name', 'Unknown')}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to find binding sites: {e}")
|
||||
|
||||
# Step 5: Virtual screening (simplified)
|
||||
print(f"\nStep 5: Screening compound libraries...")
|
||||
all_hits = []
|
||||
for site_data in binding_sites:
|
||||
for site in site_data['sites'][:1]: # Top site only
|
||||
try:
|
||||
compounds = tu.run({
|
||||
"name": "ZINC_virtual_screening",
|
||||
"arguments": {
|
||||
"binding_site": site,
|
||||
"library": "lead-like",
|
||||
"top_n": 10
|
||||
}
|
||||
})
|
||||
all_hits.extend(compounds)
|
||||
print(f" ✓ Found {len(compounds)} hit compounds for {site_data['target'].get('target_name', 'Unknown')}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Screening failed: {e}")
|
||||
|
||||
# Step 6: Calculate drug-likeness
|
||||
print(f"\nStep 6: Evaluating drug-likeness...")
|
||||
drug_candidates = []
|
||||
for compound in all_hits:
|
||||
try:
|
||||
properties = tu.run({
|
||||
"name": "RDKit_calculate_drug_properties",
|
||||
"arguments": {"smiles": compound['smiles']}
|
||||
})
|
||||
|
||||
if properties.get('lipinski_pass', False):
|
||||
drug_candidates.append({
|
||||
"compound": compound,
|
||||
"properties": properties
|
||||
})
|
||||
except Exception as e:
|
||||
print(f" ✗ Property calculation failed: {e}")
|
||||
|
||||
print(f"\n ✓ Identified {len(drug_candidates)} drug candidates passing Lipinski's Rule of Five")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("WORKFLOW SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f"Disease targets processed: {len(top_targets)}")
|
||||
print(f"Protein structures predicted: {len(structures)}")
|
||||
print(f"Binding sites identified: {sum(len(s['sites']) for s in binding_sites)}")
|
||||
print(f"Compounds screened: {len(all_hits)}")
|
||||
print(f"Drug candidates identified: {len(drug_candidates)}")
|
||||
print("=" * 70)
|
||||
|
||||
return drug_candidates
|
||||
|
||||
|
||||
def genomics_workflow(geo_id: str):
|
||||
"""
|
||||
Execute a genomics analysis workflow.
|
||||
|
||||
Args:
|
||||
geo_id: GEO dataset ID (e.g., "GSE12345")
|
||||
"""
|
||||
tu = ToolUniverse()
|
||||
tu.load_tools()
|
||||
|
||||
print("=" * 70)
|
||||
print("GENOMICS ANALYSIS WORKFLOW")
|
||||
print("=" * 70)
|
||||
|
||||
# Step 1: Download gene expression data
|
||||
print(f"\nStep 1: Downloading dataset {geo_id}...")
|
||||
try:
|
||||
expression_data = tu.run({
|
||||
"name": "GEO_download_dataset",
|
||||
"arguments": {"geo_id": geo_id}
|
||||
})
|
||||
print(f" ✓ Downloaded expression data")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
return
|
||||
|
||||
# Step 2: Differential expression analysis
|
||||
print(f"\nStep 2: Performing differential expression analysis...")
|
||||
try:
|
||||
de_genes = tu.run({
|
||||
"name": "DESeq2_differential_expression",
|
||||
"arguments": {
|
||||
"data": expression_data,
|
||||
"condition1": "control",
|
||||
"condition2": "treated"
|
||||
}
|
||||
})
|
||||
print(f" ✓ Found {len(de_genes.get('significant_genes', []))} differentially expressed genes")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
return
|
||||
|
||||
# Step 3: Pathway enrichment
|
||||
print(f"\nStep 3: Running pathway enrichment analysis...")
|
||||
try:
|
||||
pathways = tu.run({
|
||||
"name": "KEGG_pathway_enrichment",
|
||||
"arguments": {
|
||||
"gene_list": de_genes['significant_genes'],
|
||||
"organism": "hsa"
|
||||
}
|
||||
})
|
||||
print(f" ✓ Found {len(pathways)} enriched pathways")
|
||||
if pathways:
|
||||
print(f" Top pathway: {pathways[0].get('pathway_name', 'Unknown')}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example 1: Drug discovery workflow for hypertension
|
||||
print("EXAMPLE 1: Drug Discovery for Hypertension")
|
||||
candidates = drug_discovery_workflow("EFO_0000537", max_targets=2)
|
||||
|
||||
print("\n\n")
|
||||
|
||||
# Example 2: Genomics workflow
|
||||
print("EXAMPLE 2: Genomics Analysis")
|
||||
genomics_workflow("GSE12345")
|
||||
@@ -1,349 +0,0 @@
|
||||
---
|
||||
name: transformers
|
||||
description: Work with state-of-the-art machine learning models for NLP, computer vision, audio, and multimodal tasks using HuggingFace Transformers. This skill should be used when fine-tuning pre-trained models, performing inference with pipelines, generating text, training sequence models, or working with BERT, GPT, T5, ViT, and other transformer architectures. Covers model loading, tokenization, training with Trainer API, text generation strategies, and task-specific patterns for classification, NER, QA, summarization, translation, and image tasks. (plugin:scientific-packages@claude-scientific-skills)
|
||||
---
|
||||
|
||||
# Transformers
|
||||
|
||||
## Overview
|
||||
|
||||
The Transformers library provides state-of-the-art machine learning models for NLP, computer vision, audio, and multimodal tasks. Apply this skill for quick inference through pipelines, comprehensive training via the Trainer API, and flexible text generation with various decoding strategies.
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Quick Inference with Pipelines
|
||||
|
||||
For rapid inference without complex setup, use the `pipeline()` API. Pipelines abstract away tokenization, model invocation, and post-processing.
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Text classification
|
||||
classifier = pipeline("text-classification")
|
||||
result = classifier("This product is amazing!")
|
||||
|
||||
# Named entity recognition
|
||||
ner = pipeline("token-classification")
|
||||
entities = ner("Sarah works at Microsoft in Seattle")
|
||||
|
||||
# Question answering
|
||||
qa = pipeline("question-answering")
|
||||
answer = qa(question="What is the capital?", context="Paris is the capital of France.")
|
||||
|
||||
# Text generation
|
||||
generator = pipeline("text-generation", model="gpt2")
|
||||
text = generator("Once upon a time", max_length=50)
|
||||
|
||||
# Image classification
|
||||
image_classifier = pipeline("image-classification")
|
||||
predictions = image_classifier("image.jpg")
|
||||
```
|
||||
|
||||
**When to use pipelines:**
|
||||
- Quick prototyping and testing
|
||||
- Simple inference tasks without custom logic
|
||||
- Demonstrations and examples
|
||||
- Production inference for standard tasks
|
||||
|
||||
**Available pipeline tasks:**
|
||||
- **NLP**: text-classification, token-classification, question-answering, summarization, translation, text-generation, fill-mask, zero-shot-classification
|
||||
- **Vision**: image-classification, object-detection, image-segmentation, depth-estimation, zero-shot-image-classification
|
||||
- **Audio**: automatic-speech-recognition, audio-classification, text-to-audio
|
||||
- **Multimodal**: image-to-text, visual-question-answering, image-text-to-text
|
||||
|
||||
For comprehensive pipeline documentation, see `references/pipelines.md`.
|
||||
|
||||
### 2. Model Training and Fine-Tuning
|
||||
|
||||
Use the Trainer API for comprehensive model training with support for distributed training, mixed precision, and advanced optimization.
|
||||
|
||||
**Basic training workflow:**
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
TrainingArguments,
|
||||
Trainer
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# 1. Load and tokenize data
|
||||
dataset = load_dataset("imdb")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||
|
||||
# 2. Load model
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=2
|
||||
)
|
||||
|
||||
# 3. Configure training
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
# 4. Create trainer and train
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["test"],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**Key training features:**
|
||||
- Mixed precision training (fp16/bf16)
|
||||
- Distributed training (multi-GPU, multi-node)
|
||||
- Gradient accumulation
|
||||
- Learning rate scheduling with warmup
|
||||
- Checkpoint management
|
||||
- Hyperparameter search
|
||||
- Push to Hugging Face Hub
|
||||
|
||||
For detailed training documentation, see `references/training.md`.
|
||||
|
||||
### 3. Text Generation
|
||||
|
||||
Generate text using various decoding strategies including greedy decoding, beam search, sampling, and more.
|
||||
|
||||
**Generation strategies:**
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt")
|
||||
|
||||
# Greedy decoding (deterministic)
|
||||
outputs = model.generate(**inputs, max_new_tokens=50)
|
||||
|
||||
# Beam search (explores multiple hypotheses)
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
num_beams=5,
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# Sampling (creative, diverse)
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50
|
||||
)
|
||||
```
|
||||
|
||||
**Generation parameters:**
|
||||
- `temperature`: Controls randomness (0.1-2.0)
|
||||
- `top_k`: Sample from top-k tokens
|
||||
- `top_p`: Nucleus sampling threshold
|
||||
- `num_beams`: Number of beams for beam search
|
||||
- `repetition_penalty`: Discourage repetition
|
||||
- `no_repeat_ngram_size`: Prevent repeating n-grams
|
||||
|
||||
For comprehensive generation documentation, see `references/generation_strategies.md`.
|
||||
|
||||
### 4. Task-Specific Patterns
|
||||
|
||||
Common task patterns with appropriate model classes:
|
||||
|
||||
**Text Classification:**
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=3,
|
||||
id2label={0: "negative", 1: "neutral", 2: "positive"}
|
||||
)
|
||||
```
|
||||
|
||||
**Named Entity Recognition (Token Classification):**
|
||||
```python
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=9 # Number of entity types
|
||||
)
|
||||
```
|
||||
|
||||
**Question Answering:**
|
||||
```python
|
||||
from transformers import AutoModelForQuestionAnswering
|
||||
|
||||
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
|
||||
```
|
||||
|
||||
**Summarization and Translation (Seq2Seq):**
|
||||
```python
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
```
|
||||
|
||||
**Image Classification:**
|
||||
```python
|
||||
from transformers import AutoModelForImageClassification
|
||||
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"google/vit-base-patch16-224",
|
||||
num_labels=num_classes
|
||||
)
|
||||
```
|
||||
|
||||
For detailed task-specific workflows including data preprocessing, training, and evaluation, see `references/task_patterns.md`.
|
||||
|
||||
## Auto Classes
|
||||
|
||||
Use Auto classes for automatic architecture selection based on model checkpoints:
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer, # Tokenization
|
||||
AutoModel, # Base model (hidden states)
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForCausalLM, # GPT-style
|
||||
AutoModelForMaskedLM, # BERT-style
|
||||
AutoModelForSeq2SeqLM, # T5, BART
|
||||
AutoProcessor, # For multimodal models
|
||||
AutoImageProcessor, # For vision models
|
||||
)
|
||||
|
||||
# Load any model by name
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
|
||||
```
|
||||
|
||||
For comprehensive API documentation, see `references/api_reference.md`.
|
||||
|
||||
## Model Loading and Optimization
|
||||
|
||||
**Device placement:**
|
||||
```python
|
||||
model = AutoModel.from_pretrained("bert-base-uncased", device_map="auto")
|
||||
```
|
||||
|
||||
**Mixed precision:**
|
||||
```python
|
||||
model = AutoModel.from_pretrained(
|
||||
"model-name",
|
||||
torch_dtype=torch.float16 # or torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
**Quantization:**
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Quick Inference Workflow
|
||||
1. Choose appropriate pipeline for task
|
||||
2. Load pipeline with optional model specification
|
||||
3. Pass inputs and get results
|
||||
4. For batch processing, pass list of inputs
|
||||
|
||||
**See:** `scripts/quick_inference.py` for comprehensive pipeline examples
|
||||
|
||||
### Training Workflow
|
||||
1. Load and preprocess dataset using 🤗 Datasets
|
||||
2. Tokenize data with appropriate tokenizer
|
||||
3. Load pre-trained model for specific task
|
||||
4. Configure TrainingArguments
|
||||
5. Create Trainer with model, data, and compute_metrics
|
||||
6. Train with `trainer.train()`
|
||||
7. Evaluate with `trainer.evaluate()`
|
||||
8. Save model and optionally push to Hub
|
||||
|
||||
**See:** `scripts/fine_tune_classifier.py` for complete training example
|
||||
|
||||
### Text Generation Workflow
|
||||
1. Load causal or seq2seq language model
|
||||
2. Load tokenizer and tokenize prompt
|
||||
3. Choose generation strategy (greedy, beam search, sampling)
|
||||
4. Configure generation parameters
|
||||
5. Generate with `model.generate()`
|
||||
6. Decode output tokens to text
|
||||
|
||||
**See:** `scripts/generate_text.py` for generation strategy examples
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Auto classes** for flexibility across different model architectures
|
||||
2. **Batch processing** for efficiency - process multiple inputs at once
|
||||
3. **Device management** - use `device_map="auto"` for automatic placement
|
||||
4. **Memory optimization** - enable fp16/bf16 or quantization for large models
|
||||
5. **Checkpoint management** - save checkpoints regularly and load best model
|
||||
6. **Pipeline for quick tasks** - use pipelines for standard inference tasks
|
||||
7. **Custom metrics** - define compute_metrics for task-specific evaluation
|
||||
8. **Gradient accumulation** - use for large effective batch sizes on limited memory
|
||||
9. **Learning rate warmup** - typically 5-10% of total training steps
|
||||
10. **Hub integration** - push trained models to Hub for sharing and versioning
|
||||
|
||||
## Resources
|
||||
|
||||
### scripts/
|
||||
Executable Python scripts demonstrating common Transformers workflows:
|
||||
|
||||
- `quick_inference.py` - Pipeline examples for NLP, vision, audio, and multimodal tasks
|
||||
- `fine_tune_classifier.py` - Complete fine-tuning workflow with Trainer API
|
||||
- `generate_text.py` - Text generation with various decoding strategies
|
||||
|
||||
Run scripts directly to see examples in action:
|
||||
```bash
|
||||
python scripts/quick_inference.py
|
||||
python scripts/fine_tune_classifier.py
|
||||
python scripts/generate_text.py
|
||||
```
|
||||
|
||||
### references/
|
||||
Comprehensive reference documentation loaded into context as needed:
|
||||
|
||||
- `api_reference.md` - Core classes and APIs (Auto classes, Trainer, GenerationConfig, etc.)
|
||||
- `pipelines.md` - All available pipelines organized by modality with examples
|
||||
- `training.md` - Training patterns, TrainingArguments, distributed training, callbacks
|
||||
- `generation_strategies.md` - Text generation methods, decoding strategies, parameters
|
||||
- `task_patterns.md` - Complete workflows for common tasks (classification, NER, QA, summarization, etc.)
|
||||
|
||||
When working on specific tasks or features, load the relevant reference file for detailed guidance.
|
||||
|
||||
## Additional Information
|
||||
|
||||
- **Official Documentation**: https://huggingface.co/docs/transformers/index
|
||||
- **Model Hub**: https://huggingface.co/models (1M+ pre-trained models)
|
||||
- **Datasets Hub**: https://huggingface.co/datasets
|
||||
- **Installation**: `pip install transformers datasets evaluate accelerate`
|
||||
- **GPU Support**: Requires PyTorch or TensorFlow with CUDA
|
||||
- **Framework Support**: PyTorch (primary), TensorFlow, JAX/Flax
|
||||
@@ -1,485 +0,0 @@
|
||||
# Transformers API Reference
|
||||
|
||||
This reference covers the core classes and APIs in the Transformers library.
|
||||
|
||||
## Core Auto Classes
|
||||
|
||||
Auto classes provide a convenient way to automatically select the appropriate architecture based on model name or checkpoint.
|
||||
|
||||
### AutoTokenizer
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Tokenize single text
|
||||
encoded = tokenizer("Hello, how are you?")
|
||||
# Returns: {'input_ids': [...], 'attention_mask': [...]}
|
||||
|
||||
# Tokenize with options
|
||||
encoded = tokenizer(
|
||||
"Hello, how are you?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt" # "pt" for PyTorch, "tf" for TensorFlow
|
||||
)
|
||||
|
||||
# Tokenize pairs (for classification, QA, etc.)
|
||||
encoded = tokenizer(
|
||||
"Question or sentence A",
|
||||
"Context or sentence B",
|
||||
padding=True,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
# Batch tokenization
|
||||
texts = ["Text 1", "Text 2", "Text 3"]
|
||||
encoded = tokenizer(texts, padding=True, truncation=True)
|
||||
|
||||
# Decode tokens back to text
|
||||
text = tokenizer.decode(token_ids, skip_special_tokens=True)
|
||||
|
||||
# Batch decode
|
||||
texts = tokenizer.batch_decode(batch_token_ids, skip_special_tokens=True)
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
- `padding`: "max_length", "longest", or True (pad to max in batch)
|
||||
- `truncation`: True or strategy ("longest_first", "only_first", "only_second")
|
||||
- `max_length`: Maximum sequence length
|
||||
- `return_tensors`: "pt" (PyTorch), "tf" (TensorFlow), "np" (NumPy)
|
||||
- `return_attention_mask`: Return attention masks (default True)
|
||||
- `return_token_type_ids`: Return token type IDs for pairs (default True)
|
||||
- `add_special_tokens`: Add special tokens like [CLS], [SEP] (default True)
|
||||
|
||||
**Special Properties:**
|
||||
- `tokenizer.vocab_size`: Size of vocabulary
|
||||
- `tokenizer.pad_token_id`: ID of padding token
|
||||
- `tokenizer.eos_token_id`: ID of end-of-sequence token
|
||||
- `tokenizer.bos_token_id`: ID of beginning-of-sequence token
|
||||
- `tokenizer.unk_token_id`: ID of unknown token
|
||||
|
||||
### AutoModel
|
||||
|
||||
Base model class that outputs hidden states.
|
||||
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Access hidden states
|
||||
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
|
||||
pooler_output = outputs.pooler_output # [batch_size, hidden_size]
|
||||
|
||||
# Get all hidden states
|
||||
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
|
||||
outputs = model(**inputs)
|
||||
all_hidden_states = outputs.hidden_states # Tuple of tensors
|
||||
```
|
||||
|
||||
### Task-Specific Auto Classes
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForVision2Seq,
|
||||
)
|
||||
|
||||
# Sequence classification (sentiment, topic, etc.)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=3,
|
||||
id2label={0: "negative", 1: "neutral", 2: "positive"},
|
||||
label2id={"negative": 0, "neutral": 1, "positive": 2}
|
||||
)
|
||||
|
||||
# Token classification (NER, POS tagging)
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=9 # Number of entity types
|
||||
)
|
||||
|
||||
# Question answering
|
||||
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Causal language modeling (GPT-style)
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
# Masked language modeling (BERT-style)
|
||||
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Sequence-to-sequence (T5, BART)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
# Image classification
|
||||
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
||||
|
||||
# Object detection
|
||||
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
||||
|
||||
# Vision-to-text (image captioning, VQA)
|
||||
model = AutoModelForVision2Seq.from_pretrained("microsoft/git-base")
|
||||
```
|
||||
|
||||
### AutoProcessor
|
||||
|
||||
For multimodal models that need both text and image processing.
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
# For vision-language models
|
||||
processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
||||
|
||||
# Process image and text
|
||||
from PIL import Image
|
||||
image = Image.open("image.jpg")
|
||||
inputs = processor(images=image, text="caption", return_tensors="pt")
|
||||
|
||||
# For audio models
|
||||
processor = AutoProcessor.from_pretrained("openai/whisper-base")
|
||||
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
|
||||
```
|
||||
|
||||
### AutoImageProcessor
|
||||
|
||||
For vision-only models.
|
||||
|
||||
```python
|
||||
from transformers import AutoImageProcessor
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
||||
|
||||
# Process single image
|
||||
from PIL import Image
|
||||
image = Image.open("image.jpg")
|
||||
inputs = processor(image, return_tensors="pt")
|
||||
|
||||
# Batch processing
|
||||
images = [Image.open(f"image{i}.jpg") for i in range(10)]
|
||||
inputs = processor(images, return_tensors="pt")
|
||||
```
|
||||
|
||||
## Model Loading Options
|
||||
|
||||
### from_pretrained Parameters
|
||||
|
||||
```python
|
||||
model = AutoModel.from_pretrained(
|
||||
"model-name",
|
||||
# Device and precision
|
||||
device_map="auto", # Automatic device placement
|
||||
torch_dtype=torch.float16, # Use fp16
|
||||
low_cpu_mem_usage=True, # Reduce CPU memory during loading
|
||||
|
||||
# Quantization
|
||||
load_in_8bit=True, # 8-bit quantization
|
||||
load_in_4bit=True, # 4-bit quantization
|
||||
|
||||
# Model configuration
|
||||
num_labels=3, # For classification
|
||||
id2label={...}, # Label mapping
|
||||
label2id={...},
|
||||
|
||||
# Outputs
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
|
||||
# Trust remote code
|
||||
trust_remote_code=True, # For custom models
|
||||
|
||||
# Caching
|
||||
cache_dir="./cache",
|
||||
force_download=False,
|
||||
resume_download=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Quantization with BitsAndBytes
|
||||
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
|
||||
|
||||
# 4-bit quantization
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Training Components
|
||||
|
||||
### TrainingArguments
|
||||
|
||||
See `training.md` for comprehensive coverage. Key parameters:
|
||||
|
||||
```python
|
||||
from transformers import TrainingArguments
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
weight_decay=0.01,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
fp16=True,
|
||||
logging_steps=100,
|
||||
save_total_limit=2,
|
||||
)
|
||||
```
|
||||
|
||||
### Trainer
|
||||
|
||||
```python
|
||||
from transformers import Trainer
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=data_collator,
|
||||
callbacks=[callback1, callback2],
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Resume from checkpoint
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
|
||||
# Evaluate
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
# Predict
|
||||
predictions = trainer.predict(test_dataset)
|
||||
|
||||
# Hyperparameter search
|
||||
best_trial = trainer.hyperparameter_search(
|
||||
direction="maximize",
|
||||
backend="optuna",
|
||||
n_trials=10,
|
||||
)
|
||||
|
||||
# Save model
|
||||
trainer.save_model("./final_model")
|
||||
|
||||
# Push to Hub
|
||||
trainer.push_to_hub(commit_message="Training complete")
|
||||
```
|
||||
|
||||
### Data Collators
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
DataCollatorWithPadding,
|
||||
DataCollatorForTokenClassification,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorForLanguageModeling,
|
||||
DefaultDataCollator,
|
||||
)
|
||||
|
||||
# For classification/regression with dynamic padding
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# For token classification (NER)
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
|
||||
|
||||
# For seq2seq tasks
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
|
||||
|
||||
# For language modeling
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True, # True for masked LM, False for causal LM
|
||||
mlm_probability=0.15
|
||||
)
|
||||
|
||||
# Default (no special handling)
|
||||
data_collator = DefaultDataCollator()
|
||||
```
|
||||
|
||||
## Generation Components
|
||||
|
||||
### GenerationConfig
|
||||
|
||||
See `generation_strategies.md` for comprehensive coverage.
|
||||
|
||||
```python
|
||||
from transformers import GenerationConfig
|
||||
|
||||
config = GenerationConfig(
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
num_beams=5,
|
||||
repetition_penalty=1.2,
|
||||
no_repeat_ngram_size=3,
|
||||
)
|
||||
|
||||
# Use with model
|
||||
outputs = model.generate(**inputs, generation_config=config)
|
||||
```
|
||||
|
||||
### generate() Method
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
attention_mask=inputs.attention_mask,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
num_return_sequences=3,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
```
|
||||
|
||||
## Pipeline API
|
||||
|
||||
See `pipelines.md` for comprehensive coverage.
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Basic usage
|
||||
pipe = pipeline("task-name", model="model-name", device=0)
|
||||
results = pipe(inputs)
|
||||
|
||||
# With custom model
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("model-name")
|
||||
tokenizer = AutoTokenizer.from_pretrained("model-name")
|
||||
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
## Configuration Classes
|
||||
|
||||
### Model Configuration
|
||||
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
# Load configuration
|
||||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Access configuration
|
||||
print(config.hidden_size)
|
||||
print(config.num_attention_heads)
|
||||
print(config.num_hidden_layers)
|
||||
|
||||
# Modify configuration
|
||||
config.num_labels = 5
|
||||
config.output_hidden_states = True
|
||||
|
||||
# Create model with config
|
||||
model = AutoModel.from_config(config)
|
||||
|
||||
# Save configuration
|
||||
config.save_pretrained("./config")
|
||||
```
|
||||
|
||||
## Utilities
|
||||
|
||||
### Hub Utilities
|
||||
|
||||
```python
|
||||
from huggingface_hub import login, snapshot_download
|
||||
|
||||
# Login
|
||||
login(token="hf_...")
|
||||
|
||||
# Download model
|
||||
snapshot_download(repo_id="model-name", cache_dir="./cache")
|
||||
|
||||
# Push to Hub
|
||||
model.push_to_hub("username/model-name", commit_message="Initial commit")
|
||||
tokenizer.push_to_hub("username/model-name")
|
||||
```
|
||||
|
||||
### Evaluation Metrics
|
||||
|
||||
```python
|
||||
import evaluate
|
||||
|
||||
# Load metric
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
# Compute metric
|
||||
results = metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
# Common metrics
|
||||
accuracy = evaluate.load("accuracy")
|
||||
precision = evaluate.load("precision")
|
||||
recall = evaluate.load("recall")
|
||||
f1 = evaluate.load("f1")
|
||||
bleu = evaluate.load("bleu")
|
||||
rouge = evaluate.load("rouge")
|
||||
```
|
||||
|
||||
## Model Outputs
|
||||
|
||||
All models return dataclass objects with named attributes:
|
||||
|
||||
```python
|
||||
# Sequence classification output
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits # [batch_size, num_labels]
|
||||
loss = outputs.loss # If labels provided
|
||||
|
||||
# Causal LM output
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits # [batch_size, seq_length, vocab_size]
|
||||
past_key_values = outputs.past_key_values # KV cache
|
||||
|
||||
# Seq2Seq output
|
||||
outputs = model(**inputs, labels=labels)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
encoder_last_hidden_state = outputs.encoder_last_hidden_state
|
||||
|
||||
# Access as dict
|
||||
outputs_dict = outputs.to_tuple() # or dict(outputs)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Auto classes**: AutoModel, AutoTokenizer for flexibility
|
||||
2. **Device management**: Use `device_map="auto"` for multi-GPU
|
||||
3. **Memory optimization**: Use `torch_dtype=torch.float16` and quantization
|
||||
4. **Caching**: Set `cache_dir` to avoid re-downloading
|
||||
5. **Batch processing**: Process multiple inputs at once for efficiency
|
||||
6. **Trust remote code**: Only set `trust_remote_code=True` for trusted sources
|
||||
@@ -1,373 +0,0 @@
|
||||
# Text Generation Strategies
|
||||
|
||||
Transformers provides flexible text generation capabilities through the `generate()` method, supporting multiple decoding strategies and configuration options.
|
||||
|
||||
## Basic Generation
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt")
|
||||
outputs = model.generate(**inputs, max_new_tokens=50)
|
||||
generated_text = tokenizer.decode(outputs[0])
|
||||
```
|
||||
|
||||
## Decoding Strategies
|
||||
|
||||
### 1. Greedy Decoding
|
||||
|
||||
Selects the token with highest probability at each step. Deterministic but can be repetitive.
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
do_sample=False,
|
||||
num_beams=1 # Greedy is default when num_beams=1 and do_sample=False
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Beam Search
|
||||
|
||||
Explores multiple hypotheses simultaneously, keeping top-k candidates at each step.
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
num_beams=5, # Number of beams
|
||||
early_stopping=True, # Stop when all beams reach EOS
|
||||
no_repeat_ngram_size=2, # Prevent repeating n-grams
|
||||
)
|
||||
```
|
||||
|
||||
**Key parameters:**
|
||||
- `num_beams`: Number of beams (higher = more thorough but slower)
|
||||
- `early_stopping`: Stop when all beams finish (True/False)
|
||||
- `length_penalty`: Exponential penalty for length (>1.0 favors longer sequences)
|
||||
- `no_repeat_ngram_size`: Prevent repeating n-grams
|
||||
|
||||
### 3. Sampling (Multinomial)
|
||||
|
||||
Samples from probability distribution, introducing randomness and diversity.
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
do_sample=True,
|
||||
temperature=0.7, # Controls randomness (lower = more focused)
|
||||
top_k=50, # Consider only top-k tokens
|
||||
top_p=0.9, # Nucleus sampling (cumulative probability threshold)
|
||||
)
|
||||
```
|
||||
|
||||
**Key parameters:**
|
||||
- `temperature`: Scales logits before softmax (0.1-2.0 typical range)
|
||||
- Lower (0.1-0.7): More focused, deterministic
|
||||
- Higher (0.8-1.5): More creative, random
|
||||
- `top_k`: Sample from top-k tokens only
|
||||
- `top_p`: Nucleus sampling - sample from smallest set with cumulative probability > p
|
||||
|
||||
### 4. Beam Search with Sampling
|
||||
|
||||
Combines beam search with sampling for diverse but coherent outputs.
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
num_beams=5,
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Contrastive Search
|
||||
|
||||
Balances coherence and diversity using contrastive objective.
|
||||
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=50,
|
||||
penalty_alpha=0.6, # Contrastive penalty
|
||||
top_k=4, # Consider top-k candidates
|
||||
)
|
||||
```
|
||||
|
||||
### 6. Assisted Decoding
|
||||
|
||||
Uses a smaller "assistant" model to speed up generation of larger model.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
|
||||
assistant_model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
assistant_model=assistant_model,
|
||||
max_new_tokens=50,
|
||||
)
|
||||
```
|
||||
|
||||
## GenerationConfig
|
||||
|
||||
Configure generation parameters with `GenerationConfig` for reusability.
|
||||
|
||||
```python
|
||||
from transformers import GenerationConfig
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.2,
|
||||
no_repeat_ngram_size=3,
|
||||
)
|
||||
|
||||
# Use with model
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
# Save and load
|
||||
generation_config.save_pretrained("./config")
|
||||
loaded_config = GenerationConfig.from_pretrained("./config")
|
||||
```
|
||||
|
||||
## Key Parameters Reference
|
||||
|
||||
### Output Length Control
|
||||
|
||||
- `max_length`: Maximum total tokens (input + output)
|
||||
- `max_new_tokens`: Maximum new tokens to generate (recommended over max_length)
|
||||
- `min_length`: Minimum total tokens
|
||||
- `min_new_tokens`: Minimum new tokens to generate
|
||||
|
||||
### Sampling Parameters
|
||||
|
||||
- `temperature`: Sampling temperature (0.1-2.0, default 1.0)
|
||||
- `top_k`: Top-k sampling (1-100, typically 50)
|
||||
- `top_p`: Nucleus sampling (0.0-1.0, typically 0.9)
|
||||
- `do_sample`: Enable sampling (True/False)
|
||||
|
||||
### Beam Search Parameters
|
||||
|
||||
- `num_beams`: Number of beams (1-20, typically 5)
|
||||
- `early_stopping`: Stop when beams finish (True/False)
|
||||
- `length_penalty`: Length penalty (>1.0 favors longer, <1.0 favors shorter)
|
||||
- `num_beam_groups`: Diverse beam search groups
|
||||
- `diversity_penalty`: Penalty for similar beams
|
||||
|
||||
### Repetition Control
|
||||
|
||||
- `repetition_penalty`: Penalty for repeating tokens (1.0-2.0, default 1.0)
|
||||
- `no_repeat_ngram_size`: Prevent repeating n-grams (2-5 typical)
|
||||
- `encoder_repetition_penalty`: Penalty for repeating encoder tokens
|
||||
|
||||
### Special Tokens
|
||||
|
||||
- `bos_token_id`: Beginning of sequence token
|
||||
- `eos_token_id`: End of sequence token (or list of tokens)
|
||||
- `pad_token_id`: Padding token
|
||||
- `forced_bos_token_id`: Force specific token at beginning
|
||||
- `forced_eos_token_id`: Force specific token at end
|
||||
|
||||
### Multiple Sequences
|
||||
|
||||
- `num_return_sequences`: Number of sequences to return
|
||||
- `num_beam_groups`: Number of diverse beam groups
|
||||
|
||||
## Advanced Generation Techniques
|
||||
|
||||
### Constrained Generation
|
||||
|
||||
Force generation to include specific tokens or follow patterns.
|
||||
|
||||
```python
|
||||
from transformers import PhrasalConstraint
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(tokenizer("New York", add_special_tokens=False).input_ids)
|
||||
]
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
constraints=constraints,
|
||||
num_beams=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Streaming Generation
|
||||
|
||||
Generate tokens one at a time for real-time display.
|
||||
|
||||
```python
|
||||
from transformers import TextIteratorStreamer
|
||||
from threading import Thread
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
|
||||
|
||||
generation_kwargs = dict(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
streamer=streamer,
|
||||
)
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
for new_text in streamer:
|
||||
print(new_text, end="", flush=True)
|
||||
|
||||
thread.join()
|
||||
```
|
||||
|
||||
### Logit Processors
|
||||
|
||||
Customize token selection with custom logit processors.
|
||||
|
||||
```python
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
class CustomLogitsProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids, scores):
|
||||
# Modify scores here
|
||||
return scores
|
||||
|
||||
logits_processor = LogitsProcessorList([CustomLogitsProcessor()])
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
```
|
||||
|
||||
### Stopping Criteria
|
||||
|
||||
Define custom stopping conditions.
|
||||
|
||||
```python
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
class CustomStoppingCriteria(StoppingCriteria):
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
# Return True to stop generation
|
||||
return False
|
||||
|
||||
stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria()])
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### For Creative Tasks (Stories, Dialogue)
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=200,
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.2,
|
||||
no_repeat_ngram_size=3,
|
||||
)
|
||||
```
|
||||
|
||||
### For Factual Tasks (Summaries, QA)
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
num_beams=4,
|
||||
early_stopping=True,
|
||||
no_repeat_ngram_size=2,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
```
|
||||
|
||||
### For Chat/Instruction Following
|
||||
```python
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
```
|
||||
|
||||
## Vision-Language Model Generation
|
||||
|
||||
For models like LLaVA, BLIP-2, etc.:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, AutoModelForVision2Seq
|
||||
from PIL import Image
|
||||
|
||||
model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
image = Image.open("image.jpg")
|
||||
inputs = processor(text="Describe this image", images=image, return_tensors="pt")
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
generated_text = processor.decode(outputs[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Use KV Cache
|
||||
```python
|
||||
# KV cache is enabled by default
|
||||
outputs = model.generate(**inputs, use_cache=True)
|
||||
```
|
||||
|
||||
### Mixed Precision
|
||||
```python
|
||||
import torch
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
```
|
||||
|
||||
### Batch Generation
|
||||
```python
|
||||
texts = ["Prompt 1", "Prompt 2", "Prompt 3"]
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True)
|
||||
outputs = model.generate(**inputs, max_new_tokens=50)
|
||||
```
|
||||
|
||||
### Quantization
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
@@ -1,234 +0,0 @@
|
||||
# Transformers Pipelines
|
||||
|
||||
Pipelines provide a simple and optimized interface for inference across many machine learning tasks. They abstract away the complexity of tokenization, model invocation, and post-processing.
|
||||
|
||||
## Usage Pattern
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Basic usage
|
||||
classifier = pipeline("text-classification")
|
||||
result = classifier("This movie was amazing!")
|
||||
|
||||
# With specific model
|
||||
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
|
||||
result = classifier("This movie was amazing!")
|
||||
```
|
||||
|
||||
## Natural Language Processing Pipelines
|
||||
|
||||
### Text Classification
|
||||
```python
|
||||
classifier = pipeline("text-classification")
|
||||
classifier("I love this product!")
|
||||
# [{'label': 'POSITIVE', 'score': 0.9998}]
|
||||
```
|
||||
|
||||
### Zero-Shot Classification
|
||||
```python
|
||||
classifier = pipeline("zero-shot-classification")
|
||||
classifier("This is about climate change", candidate_labels=["politics", "science", "sports"])
|
||||
```
|
||||
|
||||
### Token Classification (NER)
|
||||
```python
|
||||
ner = pipeline("token-classification")
|
||||
ner("My name is Sarah and I work at Microsoft in Seattle")
|
||||
```
|
||||
|
||||
### Question Answering
|
||||
```python
|
||||
qa = pipeline("question-answering")
|
||||
qa(question="What is the capital?", context="The capital of France is Paris.")
|
||||
```
|
||||
|
||||
### Text Generation
|
||||
```python
|
||||
generator = pipeline("text-generation")
|
||||
generator("Once upon a time", max_length=50)
|
||||
```
|
||||
|
||||
### Text2Text Generation
|
||||
```python
|
||||
generator = pipeline("text2text-generation", model="t5-base")
|
||||
generator("translate English to French: Hello")
|
||||
```
|
||||
|
||||
### Summarization
|
||||
```python
|
||||
summarizer = pipeline("summarization")
|
||||
summarizer("Long article text here...", max_length=130, min_length=30)
|
||||
```
|
||||
|
||||
### Translation
|
||||
```python
|
||||
translator = pipeline("translation_en_to_fr")
|
||||
translator("Hello, how are you?")
|
||||
```
|
||||
|
||||
### Fill Mask
|
||||
```python
|
||||
unmasker = pipeline("fill-mask")
|
||||
unmasker("Paris is the [MASK] of France.")
|
||||
```
|
||||
|
||||
### Feature Extraction
|
||||
```python
|
||||
extractor = pipeline("feature-extraction")
|
||||
embeddings = extractor("This is a sentence")
|
||||
```
|
||||
|
||||
### Document Question Answering
|
||||
```python
|
||||
doc_qa = pipeline("document-question-answering")
|
||||
doc_qa(image="document.png", question="What is the invoice number?")
|
||||
```
|
||||
|
||||
### Table Question Answering
|
||||
```python
|
||||
table_qa = pipeline("table-question-answering")
|
||||
table_qa(table=data, query="How many employees?")
|
||||
```
|
||||
|
||||
## Computer Vision Pipelines
|
||||
|
||||
### Image Classification
|
||||
```python
|
||||
classifier = pipeline("image-classification")
|
||||
classifier("cat.jpg")
|
||||
```
|
||||
|
||||
### Zero-Shot Image Classification
|
||||
```python
|
||||
classifier = pipeline("zero-shot-image-classification")
|
||||
classifier("cat.jpg", candidate_labels=["cat", "dog", "bird"])
|
||||
```
|
||||
|
||||
### Object Detection
|
||||
```python
|
||||
detector = pipeline("object-detection")
|
||||
detector("street.jpg")
|
||||
```
|
||||
|
||||
### Image Segmentation
|
||||
```python
|
||||
segmenter = pipeline("image-segmentation")
|
||||
segmenter("image.jpg")
|
||||
```
|
||||
|
||||
### Image-to-Image
|
||||
```python
|
||||
img2img = pipeline("image-to-image", model="lllyasviel/sd-controlnet-canny")
|
||||
img2img("input.jpg")
|
||||
```
|
||||
|
||||
### Depth Estimation
|
||||
```python
|
||||
depth = pipeline("depth-estimation")
|
||||
depth("image.jpg")
|
||||
```
|
||||
|
||||
### Video Classification
|
||||
```python
|
||||
classifier = pipeline("video-classification")
|
||||
classifier("video.mp4")
|
||||
```
|
||||
|
||||
### Keypoint Matching
|
||||
```python
|
||||
matcher = pipeline("keypoint-matching")
|
||||
matcher(image1="img1.jpg", image2="img2.jpg")
|
||||
```
|
||||
|
||||
## Audio Pipelines
|
||||
|
||||
### Automatic Speech Recognition
|
||||
```python
|
||||
asr = pipeline("automatic-speech-recognition")
|
||||
asr("audio.wav")
|
||||
```
|
||||
|
||||
### Audio Classification
|
||||
```python
|
||||
classifier = pipeline("audio-classification")
|
||||
classifier("audio.wav")
|
||||
```
|
||||
|
||||
### Zero-Shot Audio Classification
|
||||
```python
|
||||
classifier = pipeline("zero-shot-audio-classification")
|
||||
classifier("audio.wav", candidate_labels=["speech", "music", "noise"])
|
||||
```
|
||||
|
||||
### Text-to-Audio/Text-to-Speech
|
||||
```python
|
||||
synthesizer = pipeline("text-to-audio")
|
||||
audio = synthesizer("Hello, how are you today?")
|
||||
```
|
||||
|
||||
## Multimodal Pipelines
|
||||
|
||||
### Image-to-Text (Image Captioning)
|
||||
```python
|
||||
captioner = pipeline("image-to-text")
|
||||
captioner("image.jpg")
|
||||
```
|
||||
|
||||
### Visual Question Answering
|
||||
```python
|
||||
vqa = pipeline("visual-question-answering")
|
||||
vqa(image="image.jpg", question="What color is the car?")
|
||||
```
|
||||
|
||||
### Image-Text-to-Text (VLMs)
|
||||
```python
|
||||
vlm = pipeline("image-text-to-text")
|
||||
vlm(images="image.jpg", text="Describe this image in detail")
|
||||
```
|
||||
|
||||
### Zero-Shot Object Detection
|
||||
```python
|
||||
detector = pipeline("zero-shot-object-detection")
|
||||
detector("image.jpg", candidate_labels=["car", "person", "tree"])
|
||||
```
|
||||
|
||||
## Pipeline Configuration
|
||||
|
||||
### Common Parameters
|
||||
|
||||
- `model`: Specify model identifier or path
|
||||
- `device`: Set device (0 for GPU, -1 for CPU, or "cuda:0")
|
||||
- `batch_size`: Process multiple inputs at once
|
||||
- `torch_dtype`: Set precision (torch.float16, torch.bfloat16)
|
||||
|
||||
```python
|
||||
# GPU with half precision
|
||||
pipe = pipeline("text-generation", model="gpt2", device=0, torch_dtype=torch.float16)
|
||||
|
||||
# Batch processing
|
||||
pipe(["text 1", "text 2", "text 3"], batch_size=8)
|
||||
```
|
||||
|
||||
### Task-Specific Parameters
|
||||
|
||||
Each pipeline accepts task-specific parameters in the call:
|
||||
|
||||
```python
|
||||
# Text generation
|
||||
generator("prompt", max_length=100, temperature=0.7, top_p=0.9, num_return_sequences=3)
|
||||
|
||||
# Summarization
|
||||
summarizer("text", max_length=130, min_length=30, do_sample=False)
|
||||
|
||||
# Translation
|
||||
translator("text", max_length=512, num_beams=4)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Reuse pipelines**: Create once, use multiple times for efficiency
|
||||
2. **Batch processing**: Use batches for multiple inputs to maximize throughput
|
||||
3. **GPU acceleration**: Set `device=0` for GPU when available
|
||||
4. **Model selection**: Choose task-specific models for best results
|
||||
5. **Memory management**: Use `torch_dtype=torch.float16` for large models
|
||||
@@ -1,599 +0,0 @@
|
||||
# Common Task Patterns
|
||||
|
||||
This document provides common patterns and workflows for typical tasks using Transformers.
|
||||
|
||||
## Text Classification
|
||||
|
||||
### Binary or Multi-class Classification
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
TrainingArguments,
|
||||
Trainer
|
||||
)
|
||||
from datasets import load_dataset
|
||||
import evaluate
|
||||
import numpy as np
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("imdb")
|
||||
|
||||
# Tokenize
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||
|
||||
# Load model
|
||||
id2label = {0: "negative", 1: "positive"}
|
||||
label2id = {"negative": 0, "positive": 1}
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=2,
|
||||
id2label=id2label,
|
||||
label2id=label2id
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
logits, labels = eval_pred
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
num_train_epochs=3,
|
||||
weight_decay=0.01,
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["test"],
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Inference
|
||||
text = "This movie was fantastic!"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits.argmax(-1)
|
||||
print(id2label[predictions.item()])
|
||||
```
|
||||
|
||||
## Named Entity Recognition (Token Classification)
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForTokenClassification,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForTokenClassification
|
||||
)
|
||||
from datasets import load_dataset
|
||||
import evaluate
|
||||
import numpy as np
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("conll2003")
|
||||
|
||||
# Tokenize (align labels with tokenized words)
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def tokenize_and_align_labels(examples):
|
||||
tokenized_inputs = tokenizer(
|
||||
examples["tokens"],
|
||||
truncation=True,
|
||||
is_split_into_words=True
|
||||
)
|
||||
|
||||
labels = []
|
||||
for i, label in enumerate(examples["ner_tags"]):
|
||||
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
||||
label_ids = []
|
||||
previous_word_idx = None
|
||||
for word_idx in word_ids:
|
||||
if word_idx is None:
|
||||
label_ids.append(-100)
|
||||
elif word_idx != previous_word_idx:
|
||||
label_ids.append(label[word_idx])
|
||||
else:
|
||||
label_ids.append(-100)
|
||||
previous_word_idx = word_idx
|
||||
labels.append(label_ids)
|
||||
|
||||
tokenized_inputs["labels"] = labels
|
||||
return tokenized_inputs
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)
|
||||
|
||||
# Model
|
||||
label_list = dataset["train"].features["ner_tags"].feature.names
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=len(label_list)
|
||||
)
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||
|
||||
# Metrics
|
||||
metric = evaluate.load("seqeval")
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
|
||||
true_labels = [[label_list[l] for l in label if l != -100] for label in labels]
|
||||
true_predictions = [
|
||||
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
|
||||
return metric.compute(predictions=true_predictions, references=true_labels)
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=16,
|
||||
num_train_epochs=3,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Question Answering
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForQuestionAnswering,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DefaultDataCollator
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("squad")
|
||||
|
||||
# Tokenize
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def preprocess_function(examples):
|
||||
questions = [q.strip() for q in examples["question"]]
|
||||
inputs = tokenizer(
|
||||
questions,
|
||||
examples["context"],
|
||||
max_length=384,
|
||||
truncation="only_second",
|
||||
return_offsets_mapping=True,
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
offset_mapping = inputs.pop("offset_mapping")
|
||||
answers = examples["answers"]
|
||||
start_positions = []
|
||||
end_positions = []
|
||||
|
||||
for i, offset in enumerate(offset_mapping):
|
||||
answer = answers[i]
|
||||
start_char = answer["answer_start"][0]
|
||||
end_char = start_char + len(answer["text"][0])
|
||||
|
||||
# Find start and end token positions
|
||||
sequence_ids = inputs.sequence_ids(i)
|
||||
context_start = sequence_ids.index(1)
|
||||
context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
|
||||
|
||||
if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
|
||||
start_positions.append(0)
|
||||
end_positions.append(0)
|
||||
else:
|
||||
idx = context_start
|
||||
while idx <= context_end and offset[idx][0] <= start_char:
|
||||
idx += 1
|
||||
start_positions.append(idx - 1)
|
||||
|
||||
idx = context_end
|
||||
while idx >= context_start and offset[idx][1] >= end_char:
|
||||
idx -= 1
|
||||
end_positions.append(idx + 1)
|
||||
|
||||
inputs["start_positions"] = start_positions
|
||||
inputs["end_positions"] = end_positions
|
||||
return inputs
|
||||
|
||||
tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
|
||||
|
||||
# Model
|
||||
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=16,
|
||||
num_train_epochs=3,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["validation"],
|
||||
data_collator=DefaultDataCollator(),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Inference
|
||||
question = "What is the capital of France?"
|
||||
context = "Paris is the capital and most populous city of France."
|
||||
inputs = tokenizer(question, context, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
|
||||
start_pos = outputs.start_logits.argmax()
|
||||
end_pos = outputs.end_logits.argmax()
|
||||
answer_tokens = inputs.input_ids[0][start_pos:end_pos+1]
|
||||
answer = tokenizer.decode(answer_tokens)
|
||||
```
|
||||
|
||||
## Text Summarization
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSeq2SeqLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForSeq2Seq
|
||||
)
|
||||
from datasets import load_dataset
|
||||
import evaluate
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("cnn_dailymail", "3.0.0")
|
||||
|
||||
# Tokenize
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small")
|
||||
|
||||
def preprocess_function(examples):
|
||||
inputs = ["summarize: " + doc for doc in examples["article"]]
|
||||
model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
|
||||
|
||||
labels = tokenizer(
|
||||
text_target=examples["highlights"],
|
||||
max_length=128,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
tokenized_datasets = dataset.map(preprocess_function, batched=True)
|
||||
|
||||
# Model
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
|
||||
|
||||
# Metrics
|
||||
rouge = evaluate.load("rouge")
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
result = rouge.compute(
|
||||
predictions=decoded_preds,
|
||||
references=decoded_labels,
|
||||
use_stemmer=True
|
||||
)
|
||||
|
||||
return {k: round(v, 4) for k, v in result.items()}
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=8,
|
||||
per_device_eval_batch_size=8,
|
||||
num_train_epochs=3,
|
||||
predict_with_generate=True,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Inference
|
||||
text = "Long article text..."
|
||||
inputs = tokenizer("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True)
|
||||
outputs = model.generate(**inputs, max_length=128, num_beams=4)
|
||||
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
## Translation
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSeq2SeqLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForSeq2Seq
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wmt16", "de-en")
|
||||
|
||||
# Tokenize
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small")
|
||||
|
||||
def preprocess_function(examples):
|
||||
inputs = [f"translate German to English: {de}" for de in examples["de"]]
|
||||
model_inputs = tokenizer(inputs, max_length=128, truncation=True)
|
||||
|
||||
labels = tokenizer(
|
||||
text_target=examples["en"],
|
||||
max_length=128,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
tokenized_datasets = dataset.map(preprocess_function, batched=True)
|
||||
|
||||
# Model and training (similar to summarization)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
||||
|
||||
# Inference
|
||||
text = "Guten Tag, wie geht es Ihnen?"
|
||||
inputs = tokenizer(f"translate German to English: {text}", return_tensors="pt")
|
||||
outputs = model.generate(**inputs, max_length=128)
|
||||
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
## Causal Language Modeling (Training from Scratch or Fine-tuning)
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForLanguageModeling
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
|
||||
|
||||
# Tokenize
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], truncation=True, max_length=512)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
|
||||
# Group texts into chunks
|
||||
block_size = 128
|
||||
|
||||
def group_texts(examples):
|
||||
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
|
||||
|
||||
# Model
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=8,
|
||||
num_train_epochs=3,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=lm_datasets["train"],
|
||||
eval_dataset=lm_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Image Classification
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModelForImageClassification,
|
||||
TrainingArguments,
|
||||
Trainer
|
||||
)
|
||||
from datasets import load_dataset
|
||||
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
||||
import numpy as np
|
||||
import evaluate
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("food101", split="train[:5000]")
|
||||
|
||||
# Prepare image transforms
|
||||
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
||||
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
size = image_processor.size["height"]
|
||||
|
||||
transforms = Compose([
|
||||
Resize((size, size)),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
def preprocess_function(examples):
|
||||
examples["pixel_values"] = [transforms(img.convert("RGB")) for img in examples["image"]]
|
||||
return examples
|
||||
|
||||
dataset = dataset.with_transform(preprocess_function)
|
||||
|
||||
# Model
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"google/vit-base-patch16-224",
|
||||
num_labels=len(dataset["train"].features["label"].names),
|
||||
ignore_mismatched_sizes=True
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions = np.argmax(eval_pred.predictions, axis=1)
|
||||
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=5e-5,
|
||||
per_device_train_batch_size=16,
|
||||
num_train_epochs=3,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["validation"],
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Vision-Language Tasks (Image Captioning)
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoModelForVision2Seq,
|
||||
TrainingArguments,
|
||||
Trainer
|
||||
)
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("ybelkada/football-dataset")
|
||||
|
||||
# Processor
|
||||
processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
||||
|
||||
def preprocess_function(examples):
|
||||
images = [Image.open(img).convert("RGB") for img in examples["image"]]
|
||||
texts = examples["caption"]
|
||||
|
||||
inputs = processor(images=images, text=texts, padding="max_length", truncation=True)
|
||||
inputs["labels"] = inputs["input_ids"]
|
||||
return inputs
|
||||
|
||||
dataset = dataset.map(preprocess_function, batched=True)
|
||||
|
||||
# Model
|
||||
model = AutoModelForVision2Seq.from_pretrained("microsoft/git-base")
|
||||
|
||||
# Train
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
eval_strategy="epoch",
|
||||
learning_rate=5e-5,
|
||||
per_device_train_batch_size=8,
|
||||
num_train_epochs=3,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Inference
|
||||
image = Image.open("image.jpg")
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
outputs = model.generate(**inputs)
|
||||
caption = processor.decode(outputs[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
## Best Practices Summary
|
||||
|
||||
1. **Use appropriate Auto* classes**: AutoTokenizer, AutoModel, etc. for model loading
|
||||
2. **Proper preprocessing**: Tokenize, align labels, handle special cases
|
||||
3. **Data collators**: Use appropriate collators for dynamic padding
|
||||
4. **Metrics**: Load and compute relevant metrics for evaluation
|
||||
5. **Training arguments**: Configure properly for task and hardware
|
||||
6. **Inference**: Use pipeline() for quick inference, or manual tokenization for custom needs
|
||||
@@ -1,328 +0,0 @@
|
||||
# Training with Transformers
|
||||
|
||||
Transformers provides comprehensive training capabilities through the `Trainer` API, supporting distributed training, mixed precision, and advanced optimization techniques.
|
||||
|
||||
## Basic Training Workflow
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
Trainer,
|
||||
TrainingArguments
|
||||
)
|
||||
from datasets import load_dataset
|
||||
|
||||
# 1. Load and preprocess data
|
||||
dataset = load_dataset("imdb")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
||||
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||
|
||||
# 2. Load model
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=2
|
||||
)
|
||||
|
||||
# 3. Define training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
# 4. Create trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["test"],
|
||||
)
|
||||
|
||||
# 5. Train
|
||||
trainer.train()
|
||||
|
||||
# 6. Evaluate
|
||||
trainer.evaluate()
|
||||
|
||||
# 7. Save model
|
||||
trainer.save_model("./final_model")
|
||||
```
|
||||
|
||||
## TrainingArguments Configuration
|
||||
|
||||
### Essential Parameters
|
||||
|
||||
**Output and Logging:**
|
||||
- `output_dir`: Directory for checkpoints and outputs (required)
|
||||
- `logging_dir`: TensorBoard log directory (default: `{output_dir}/runs`)
|
||||
- `logging_steps`: Log every N steps (default: 500)
|
||||
- `logging_strategy`: "steps" or "epoch"
|
||||
|
||||
**Training Duration:**
|
||||
- `num_train_epochs`: Number of epochs (default: 3.0)
|
||||
- `max_steps`: Max training steps (overrides num_train_epochs if set)
|
||||
|
||||
**Batch Size and Gradient Accumulation:**
|
||||
- `per_device_train_batch_size`: Batch size per device (default: 8)
|
||||
- `per_device_eval_batch_size`: Eval batch size per device (default: 8)
|
||||
- `gradient_accumulation_steps`: Accumulate gradients over N steps (default: 1)
|
||||
- Effective batch size = `per_device_train_batch_size * gradient_accumulation_steps * num_gpus`
|
||||
|
||||
**Learning Rate:**
|
||||
- `learning_rate`: Peak learning rate (default: 5e-5)
|
||||
- `lr_scheduler_type`: Scheduler type ("linear", "cosine", "constant", etc.)
|
||||
- `warmup_steps`: Warmup steps (default: 0)
|
||||
- `warmup_ratio`: Warmup as fraction of total steps
|
||||
|
||||
**Evaluation:**
|
||||
- `eval_strategy`: "no", "steps", or "epoch" (default: "no")
|
||||
- `eval_steps`: Evaluate every N steps (if eval_strategy="steps")
|
||||
- `eval_delay`: Delay evaluation until N steps
|
||||
|
||||
**Checkpointing:**
|
||||
- `save_strategy`: "no", "steps", or "epoch" (default: "steps")
|
||||
- `save_steps`: Save checkpoint every N steps (default: 500)
|
||||
- `save_total_limit`: Keep only N most recent checkpoints
|
||||
- `load_best_model_at_end`: Load best checkpoint at end (default: False)
|
||||
- `metric_for_best_model`: Metric to determine best model
|
||||
|
||||
**Optimization:**
|
||||
- `optim`: Optimizer ("adamw_torch", "adamw_hf", "sgd", etc.)
|
||||
- `weight_decay`: Weight decay coefficient (default: 0.0)
|
||||
- `adam_beta1`, `adam_beta2`: Adam optimizer betas
|
||||
- `adam_epsilon`: Epsilon for Adam (default: 1e-8)
|
||||
- `max_grad_norm`: Max gradient norm for clipping (default: 1.0)
|
||||
|
||||
### Mixed Precision Training
|
||||
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
fp16=True, # Use fp16 on NVIDIA GPUs
|
||||
fp16_opt_level="O1", # O0, O1, O2, O3 (Apex levels)
|
||||
# or
|
||||
bf16=True, # Use bf16 on Ampere+ GPUs (better than fp16)
|
||||
)
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
**DataParallel (single-node multi-GPU):**
|
||||
```python
|
||||
# Automatic with multiple GPUs
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=16, # Per GPU
|
||||
)
|
||||
# Run: python script.py
|
||||
```
|
||||
|
||||
**DistributedDataParallel (multi-node or multi-GPU):**
|
||||
```bash
|
||||
# Single node, multiple GPUs
|
||||
python -m torch.distributed.launch --nproc_per_node=4 script.py
|
||||
|
||||
# Or use accelerate
|
||||
accelerate config
|
||||
accelerate launch script.py
|
||||
```
|
||||
|
||||
**DeepSpeed Integration:**
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
deepspeed="ds_config.json", # DeepSpeed config file
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
**Gradient Checkpointing (reduce memory):**
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
```
|
||||
|
||||
**Compilation with torch.compile:**
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
torch_compile=True,
|
||||
torch_compile_backend="inductor", # or "cudagraphs"
|
||||
)
|
||||
```
|
||||
|
||||
**Push to Hub:**
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
push_to_hub=True,
|
||||
hub_model_id="username/model-name",
|
||||
hub_strategy="every_save", # or "end"
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Training Components
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```python
|
||||
import evaluate
|
||||
import numpy as np
|
||||
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
logits, labels = eval_pred
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Loss Function
|
||||
|
||||
```python
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# Custom loss calculation
|
||||
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
||||
### Data Collator
|
||||
|
||||
```python
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
```
|
||||
|
||||
### Callbacks
|
||||
|
||||
```python
|
||||
from transformers import TrainerCallback
|
||||
|
||||
class CustomCallback(TrainerCallback):
|
||||
def on_epoch_end(self, args, state, control, **kwargs):
|
||||
print(f"Epoch {state.epoch} completed!")
|
||||
return control
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
callbacks=[CustomCallback],
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameter Search
|
||||
|
||||
```python
|
||||
def model_init():
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=2
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model_init=model_init,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
# Optuna-based search
|
||||
best_trial = trainer.hyperparameter_search(
|
||||
direction="maximize",
|
||||
backend="optuna",
|
||||
n_trials=10,
|
||||
hp_space=lambda trial: {
|
||||
"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
|
||||
"per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
|
||||
"num_train_epochs": trial.suggest_int("num_train_epochs", 2, 5),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Training Best Practices
|
||||
|
||||
1. **Start with small learning rates**: 2e-5 to 5e-5 for fine-tuning
|
||||
2. **Use warmup**: 5-10% of total steps for learning rate warmup
|
||||
3. **Monitor training**: Use eval_strategy="epoch" or "steps" to track progress
|
||||
4. **Save checkpoints**: Set save_strategy and save_total_limit
|
||||
5. **Use mixed precision**: Enable fp16 or bf16 for faster training
|
||||
6. **Gradient accumulation**: For large effective batch sizes on limited memory
|
||||
7. **Load best model**: Set load_best_model_at_end=True to avoid overfitting
|
||||
8. **Push to Hub**: Enable push_to_hub for easy model sharing and versioning
|
||||
|
||||
## Common Training Patterns
|
||||
|
||||
### Classification
|
||||
```python
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=num_classes,
|
||||
id2label=id2label,
|
||||
label2id=label2id
|
||||
)
|
||||
```
|
||||
|
||||
### Question Answering
|
||||
```python
|
||||
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
|
||||
```
|
||||
|
||||
### Token Classification (NER)
|
||||
```python
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
num_labels=num_tags,
|
||||
id2label=id2label,
|
||||
label2id=label2id
|
||||
)
|
||||
```
|
||||
|
||||
### Sequence-to-Sequence
|
||||
```python
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
```
|
||||
|
||||
### Causal Language Modeling
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
```
|
||||
|
||||
### Masked Language Modeling
|
||||
```python
|
||||
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
```
|
||||
@@ -1,241 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fine-tune a transformer model for text classification.
|
||||
|
||||
This script demonstrates the complete workflow for fine-tuning a pre-trained
|
||||
model on a classification task using the Trainer API.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorWithPadding,
|
||||
)
|
||||
import evaluate
|
||||
|
||||
|
||||
def load_and_prepare_data(dataset_name="imdb", model_name="distilbert-base-uncased", max_samples=None):
|
||||
"""
|
||||
Load dataset and tokenize.
|
||||
|
||||
Args:
|
||||
dataset_name: Name of the dataset to load
|
||||
model_name: Name of the model/tokenizer to use
|
||||
max_samples: Limit number of samples (for quick testing)
|
||||
|
||||
Returns:
|
||||
tokenized_datasets, tokenizer
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_name}")
|
||||
dataset = load_dataset(dataset_name)
|
||||
|
||||
# Optionally limit samples for quick testing
|
||||
if max_samples:
|
||||
dataset["train"] = dataset["train"].select(range(max_samples))
|
||||
dataset["test"] = dataset["test"].select(range(min(max_samples, len(dataset["test"]))))
|
||||
|
||||
print(f"Loading tokenizer: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
print("Tokenizing dataset...")
|
||||
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
||||
|
||||
return tokenized_datasets, tokenizer
|
||||
|
||||
|
||||
def create_model(model_name, num_labels, id2label, label2id):
|
||||
"""
|
||||
Create classification model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the pre-trained model
|
||||
num_labels: Number of classification labels
|
||||
id2label: Dictionary mapping label IDs to names
|
||||
label2id: Dictionary mapping label names to IDs
|
||||
|
||||
Returns:
|
||||
model
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name,
|
||||
num_labels=num_labels,
|
||||
id2label=id2label,
|
||||
label2id=label2id
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def define_compute_metrics(metric_name="accuracy"):
|
||||
"""
|
||||
Define function to compute metrics during evaluation.
|
||||
|
||||
Args:
|
||||
metric_name: Name of the metric to use
|
||||
|
||||
Returns:
|
||||
compute_metrics function
|
||||
"""
|
||||
metric = evaluate.load(metric_name)
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
logits, labels = eval_pred
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
return compute_metrics
|
||||
|
||||
|
||||
def train_model(model, tokenizer, train_dataset, eval_dataset, output_dir="./results"):
|
||||
"""
|
||||
Train the model.
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
tokenizer: The tokenizer
|
||||
train_dataset: Training dataset
|
||||
eval_dataset: Evaluation dataset
|
||||
output_dir: Directory for checkpoints and logs
|
||||
|
||||
Returns:
|
||||
trained model, trainer
|
||||
"""
|
||||
# Define training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_dir,
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
weight_decay=0.01,
|
||||
eval_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
logging_dir=f"{output_dir}/logs",
|
||||
logging_steps=100,
|
||||
save_total_limit=2,
|
||||
fp16=False, # Set to True if using GPU with fp16 support
|
||||
)
|
||||
|
||||
# Create data collator
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# Create trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=define_compute_metrics("accuracy"),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("\nStarting training...")
|
||||
trainer.train()
|
||||
|
||||
# Evaluate
|
||||
print("\nEvaluating model...")
|
||||
eval_results = trainer.evaluate()
|
||||
print(f"Evaluation results: {eval_results}")
|
||||
|
||||
return model, trainer
|
||||
|
||||
|
||||
def test_inference(model, tokenizer, id2label):
|
||||
"""
|
||||
Test the trained model with sample texts.
|
||||
|
||||
Args:
|
||||
model: Trained model
|
||||
tokenizer: Tokenizer
|
||||
id2label: Dictionary mapping label IDs to names
|
||||
"""
|
||||
print("\n=== Testing Inference ===")
|
||||
|
||||
test_texts = [
|
||||
"This movie was absolutely fantastic! I loved every minute of it.",
|
||||
"Terrible film. Waste of time and money.",
|
||||
"It was okay, nothing special but not bad either."
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits.argmax(-1)
|
||||
predicted_label = id2label[predictions.item()]
|
||||
confidence = outputs.logits.softmax(-1).max().item()
|
||||
|
||||
print(f"\nText: {text}")
|
||||
print(f"Prediction: {predicted_label} (confidence: {confidence:.3f})")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training pipeline."""
|
||||
# Configuration
|
||||
DATASET_NAME = "imdb"
|
||||
MODEL_NAME = "distilbert-base-uncased"
|
||||
OUTPUT_DIR = "./results"
|
||||
MAX_SAMPLES = None # Set to a small number (e.g., 1000) for quick testing
|
||||
|
||||
# Label mapping
|
||||
id2label = {0: "negative", 1: "positive"}
|
||||
label2id = {"negative": 0, "positive": 1}
|
||||
num_labels = len(id2label)
|
||||
|
||||
print("=" * 60)
|
||||
print("Fine-Tuning Text Classification Model")
|
||||
print("=" * 60)
|
||||
|
||||
# Load and prepare data
|
||||
tokenized_datasets, tokenizer = load_and_prepare_data(
|
||||
dataset_name=DATASET_NAME,
|
||||
model_name=MODEL_NAME,
|
||||
max_samples=MAX_SAMPLES
|
||||
)
|
||||
|
||||
# Create model
|
||||
model = create_model(
|
||||
model_name=MODEL_NAME,
|
||||
num_labels=num_labels,
|
||||
id2label=id2label,
|
||||
label2id=label2id
|
||||
)
|
||||
|
||||
# Train model
|
||||
model, trainer = train_model(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["test"],
|
||||
output_dir=OUTPUT_DIR
|
||||
)
|
||||
|
||||
# Save final model
|
||||
print(f"\nSaving model to {OUTPUT_DIR}/final_model")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final_model")
|
||||
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final_model")
|
||||
|
||||
# Test inference
|
||||
test_inference(model, tokenizer, id2label)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Training completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Text generation with different decoding strategies.
|
||||
|
||||
This script demonstrates various text generation approaches using
|
||||
different sampling and decoding strategies.
|
||||
"""
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
|
||||
def load_model_and_tokenizer(model_name="gpt2"):
|
||||
"""
|
||||
Load model and tokenizer.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to load
|
||||
|
||||
Returns:
|
||||
model, tokenizer
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
# Set pad token if not already set
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate_with_greedy(model, tokenizer, prompt, max_new_tokens=50):
|
||||
"""Greedy decoding - always picks highest probability token."""
|
||||
print("\n=== Greedy Decoding ===")
|
||||
print(f"Prompt: {prompt}")
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(f"Generated: {generated_text}\n")
|
||||
|
||||
|
||||
def generate_with_beam_search(model, tokenizer, prompt, max_new_tokens=50, num_beams=5):
|
||||
"""Beam search - explores multiple hypotheses."""
|
||||
print("\n=== Beam Search ===")
|
||||
print(f"Prompt: {prompt}")
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
num_beams=num_beams,
|
||||
early_stopping=True,
|
||||
no_repeat_ngram_size=2,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(f"Generated: {generated_text}\n")
|
||||
|
||||
|
||||
def generate_with_sampling(model, tokenizer, prompt, max_new_tokens=50,
|
||||
temperature=0.7, top_k=50, top_p=0.9):
|
||||
"""Sampling with temperature, top-k, and nucleus (top-p) sampling."""
|
||||
print("\n=== Sampling (Temperature + Top-K + Top-P) ===")
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Parameters: temperature={temperature}, top_k={top_k}, top_p={top_p}")
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(f"Generated: {generated_text}\n")
|
||||
|
||||
|
||||
def generate_multiple_sequences(model, tokenizer, prompt, max_new_tokens=50,
|
||||
num_return_sequences=3):
|
||||
"""Generate multiple diverse sequences."""
|
||||
print("\n=== Multiple Sequences (with Sampling) ===")
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Generating {num_return_sequences} sequences...")
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
num_return_sequences=num_return_sequences,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
generated_text = tokenizer.decode(output, skip_special_tokens=True)
|
||||
print(f"\nSequence {i+1}: {generated_text}")
|
||||
print()
|
||||
|
||||
|
||||
def generate_with_config(model, tokenizer, prompt):
|
||||
"""Use GenerationConfig for reusable configuration."""
|
||||
print("\n=== Using GenerationConfig ===")
|
||||
print(f"Prompt: {prompt}")
|
||||
|
||||
# Create a generation config
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=50,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.2,
|
||||
no_repeat_ngram_size=3,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(f"Generated: {generated_text}\n")
|
||||
|
||||
|
||||
def compare_temperatures(model, tokenizer, prompt, max_new_tokens=50):
|
||||
"""Compare different temperature settings."""
|
||||
print("\n=== Temperature Comparison ===")
|
||||
print(f"Prompt: {prompt}\n")
|
||||
|
||||
temperatures = [0.3, 0.7, 1.0, 1.5]
|
||||
|
||||
for temp in temperatures:
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temp,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
print(f"Temperature {temp}: {generated_text}\n")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all generation examples."""
|
||||
print("=" * 70)
|
||||
print("Text Generation Examples")
|
||||
print("=" * 70)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load_model_and_tokenizer("gpt2")
|
||||
|
||||
# Example prompts
|
||||
story_prompt = "Once upon a time in a distant galaxy"
|
||||
factual_prompt = "The three branches of the US government are"
|
||||
|
||||
# Demonstrate different strategies
|
||||
generate_with_greedy(model, tokenizer, story_prompt)
|
||||
generate_with_beam_search(model, tokenizer, factual_prompt)
|
||||
generate_with_sampling(model, tokenizer, story_prompt)
|
||||
generate_multiple_sequences(model, tokenizer, story_prompt, num_return_sequences=3)
|
||||
generate_with_config(model, tokenizer, story_prompt)
|
||||
compare_temperatures(model, tokenizer, story_prompt)
|
||||
|
||||
print("=" * 70)
|
||||
print("All generation examples completed!")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick inference using Transformers pipelines.
|
||||
|
||||
This script demonstrates how to quickly use pre-trained models for inference
|
||||
across various tasks using the pipeline API.
|
||||
"""
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
|
||||
def text_classification_example():
|
||||
"""Sentiment analysis example."""
|
||||
print("=== Text Classification ===")
|
||||
classifier = pipeline("text-classification")
|
||||
result = classifier("I love using Transformers! It makes NLP so easy.")
|
||||
print(f"Result: {result}\n")
|
||||
|
||||
|
||||
def named_entity_recognition_example():
|
||||
"""Named Entity Recognition example."""
|
||||
print("=== Named Entity Recognition ===")
|
||||
ner = pipeline("token-classification", aggregation_strategy="simple")
|
||||
text = "My name is Sarah and I work at Microsoft in Seattle"
|
||||
entities = ner(text)
|
||||
for entity in entities:
|
||||
print(f"{entity['word']}: {entity['entity_group']} (score: {entity['score']:.3f})")
|
||||
print()
|
||||
|
||||
|
||||
def question_answering_example():
|
||||
"""Question Answering example."""
|
||||
print("=== Question Answering ===")
|
||||
qa = pipeline("question-answering")
|
||||
context = "Paris is the capital and most populous city of France. It is located in northern France."
|
||||
question = "What is the capital of France?"
|
||||
answer = qa(question=question, context=context)
|
||||
print(f"Question: {question}")
|
||||
print(f"Answer: {answer['answer']} (score: {answer['score']:.3f})\n")
|
||||
|
||||
|
||||
def text_generation_example():
|
||||
"""Text generation example."""
|
||||
print("=== Text Generation ===")
|
||||
generator = pipeline("text-generation", model="gpt2")
|
||||
prompt = "Once upon a time in a land far away"
|
||||
generated = generator(prompt, max_length=50, num_return_sequences=1)
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Generated: {generated[0]['generated_text']}\n")
|
||||
|
||||
|
||||
def summarization_example():
|
||||
"""Text summarization example."""
|
||||
print("=== Summarization ===")
|
||||
summarizer = pipeline("summarization")
|
||||
article = """
|
||||
The Transformers library provides thousands of pretrained models to perform tasks
|
||||
on texts such as classification, information extraction, question answering,
|
||||
summarization, translation, text generation, etc in over 100 languages. Its aim
|
||||
is to make cutting-edge NLP easier to use for everyone. The library provides APIs
|
||||
to quickly download and use pretrained models on a given text, fine-tune them on
|
||||
your own datasets then share them with the community on the model hub.
|
||||
"""
|
||||
summary = summarizer(article, max_length=50, min_length=25, do_sample=False)
|
||||
print(f"Summary: {summary[0]['summary_text']}\n")
|
||||
|
||||
|
||||
def translation_example():
|
||||
"""Translation example."""
|
||||
print("=== Translation ===")
|
||||
translator = pipeline("translation_en_to_fr")
|
||||
text = "Hello, how are you today?"
|
||||
translation = translator(text)
|
||||
print(f"English: {text}")
|
||||
print(f"French: {translation[0]['translation_text']}\n")
|
||||
|
||||
|
||||
def zero_shot_classification_example():
|
||||
"""Zero-shot classification example."""
|
||||
print("=== Zero-Shot Classification ===")
|
||||
classifier = pipeline("zero-shot-classification")
|
||||
text = "This is a breaking news story about a major earthquake."
|
||||
candidate_labels = ["politics", "sports", "science", "breaking news"]
|
||||
result = classifier(text, candidate_labels)
|
||||
print(f"Text: {text}")
|
||||
print("Predictions:")
|
||||
for label, score in zip(result['labels'], result['scores']):
|
||||
print(f" {label}: {score:.3f}")
|
||||
print()
|
||||
|
||||
|
||||
def image_classification_example():
|
||||
"""Image classification example (requires PIL)."""
|
||||
print("=== Image Classification ===")
|
||||
try:
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
classifier = pipeline("image-classification")
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
predictions = classifier(image)
|
||||
print("Top predictions:")
|
||||
for pred in predictions[:3]:
|
||||
print(f" {pred['label']}: {pred['score']:.3f}")
|
||||
print()
|
||||
except ImportError:
|
||||
print("PIL not installed. Skipping image classification example.\n")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all examples."""
|
||||
print("Transformers Quick Inference Examples")
|
||||
print("=" * 50 + "\n")
|
||||
|
||||
# Text tasks
|
||||
text_classification_example()
|
||||
named_entity_recognition_example()
|
||||
question_answering_example()
|
||||
text_generation_example()
|
||||
summarization_example()
|
||||
translation_example()
|
||||
zero_shot_classification_example()
|
||||
|
||||
# Vision task (optional)
|
||||
image_classification_example()
|
||||
|
||||
print("=" * 50)
|
||||
print("All examples completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
117
scientific-skills/adaptyv/SKILL.md
Normal file
117
scientific-skills/adaptyv/SKILL.md
Normal file
@@ -0,0 +1,117 @@
|
||||
---
|
||||
name: adaptyv
|
||||
description: Cloud laboratory platform for automated protein testing and validation. Use when designing proteins and needing experimental validation including binding assays, expression testing, thermostability measurements, enzyme activity assays, or protein sequence optimization. Also use for submitting experiments via API, tracking experiment status, downloading results, optimizing protein sequences for better expression using computational tools (NetSolP, SoluProt, SolubleMPNN, ESM), or managing protein design workflows with wet-lab validation.
|
||||
license: Unknown
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# Adaptyv
|
||||
|
||||
Adaptyv is a cloud laboratory platform that provides automated protein testing and validation services. Submit protein sequences via API or web interface and receive experimental results in approximately 21 days.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Authentication Setup
|
||||
|
||||
Adaptyv requires API authentication. Set up your credentials:
|
||||
|
||||
1. Contact support@adaptyvbio.com to request API access (platform is in alpha/beta)
|
||||
2. Receive your API access token
|
||||
3. Set environment variable:
|
||||
|
||||
```bash
|
||||
export ADAPTYV_API_KEY="your_api_key_here"
|
||||
```
|
||||
|
||||
Or create a `.env` file:
|
||||
|
||||
```
|
||||
ADAPTYV_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
Install the required package using uv:
|
||||
|
||||
```bash
|
||||
uv pip install requests python-dotenv
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
|
||||
Submit protein sequences for testing:
|
||||
|
||||
```python
|
||||
import os
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
api_key = os.getenv("ADAPTYV_API_KEY")
|
||||
base_url = "https://kq5jp7qj7wdqklhsxmovkzn4l40obksv.lambda-url.eu-central-1.on.aws"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Submit experiment
|
||||
response = requests.post(
|
||||
f"{base_url}/experiments",
|
||||
headers=headers,
|
||||
json={
|
||||
"sequences": ">protein1\nMKVLWALLGLLGAA...",
|
||||
"experiment_type": "binding",
|
||||
"webhook_url": "https://your-webhook.com/callback"
|
||||
}
|
||||
)
|
||||
|
||||
experiment_id = response.json()["experiment_id"]
|
||||
```
|
||||
|
||||
## Available Experiment Types
|
||||
|
||||
Adaptyv supports multiple assay types:
|
||||
|
||||
- **Binding assays** - Test protein-target interactions using biolayer interferometry
|
||||
- **Expression testing** - Measure protein expression levels
|
||||
- **Thermostability** - Characterize protein thermal stability
|
||||
- **Enzyme activity** - Assess enzymatic function
|
||||
|
||||
See `reference/experiments.md` for detailed information on each experiment type and workflows.
|
||||
|
||||
## Protein Sequence Optimization
|
||||
|
||||
Before submitting sequences, optimize them for better expression and stability:
|
||||
|
||||
**Common issues to address:**
|
||||
- Unpaired cysteines that create unwanted disulfides
|
||||
- Excessive hydrophobic regions causing aggregation
|
||||
- Poor solubility predictions
|
||||
|
||||
**Recommended tools:**
|
||||
- NetSolP / SoluProt - Initial solubility filtering
|
||||
- SolubleMPNN - Sequence redesign for improved solubility
|
||||
- ESM - Sequence likelihood scoring
|
||||
- ipTM - Interface stability assessment
|
||||
- pSAE - Hydrophobic exposure quantification
|
||||
|
||||
See `reference/protein_optimization.md` for detailed optimization workflows and tool usage.
|
||||
|
||||
## API Reference
|
||||
|
||||
For complete API documentation including all endpoints, request/response formats, and authentication details, see `reference/api_reference.md`.
|
||||
|
||||
## Examples
|
||||
|
||||
For concrete code examples covering common use cases (experiment submission, status tracking, result retrieval, batch processing), see `reference/examples.md`.
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Platform is currently in alpha/beta phase with features subject to change
|
||||
- Not all platform features are available via API yet
|
||||
- Results typically delivered in ~21 days
|
||||
- Contact support@adaptyvbio.com for access requests or questions
|
||||
- Suitable for high-throughput AI-driven protein design workflows
|
||||
308
scientific-skills/adaptyv/reference/api_reference.md
Normal file
308
scientific-skills/adaptyv/reference/api_reference.md
Normal file
@@ -0,0 +1,308 @@
|
||||
# Adaptyv API Reference
|
||||
|
||||
## Base URL
|
||||
|
||||
```
|
||||
https://kq5jp7qj7wdqklhsxmovkzn4l40obksv.lambda-url.eu-central-1.on.aws
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
All API requests require bearer token authentication in the request header:
|
||||
|
||||
```
|
||||
Authorization: Bearer YOUR_API_KEY
|
||||
```
|
||||
|
||||
To obtain API access:
|
||||
1. Contact support@adaptyvbio.com
|
||||
2. Request API access during alpha/beta period
|
||||
3. Receive your personal access token
|
||||
|
||||
Store your API key securely:
|
||||
- Use environment variables: `ADAPTYV_API_KEY`
|
||||
- Never commit API keys to version control
|
||||
- Use `.env` files with `.gitignore` for local development
|
||||
|
||||
## Endpoints
|
||||
|
||||
### Experiments
|
||||
|
||||
#### Create Experiment
|
||||
|
||||
Submit protein sequences for experimental testing.
|
||||
|
||||
**Endpoint:** `POST /experiments`
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"sequences": ">protein1\nMKVLWALLGLLGAA...\n>protein2\nMATGVLWALLG...",
|
||||
"experiment_type": "binding|expression|thermostability|enzyme_activity",
|
||||
"target_id": "optional_target_identifier",
|
||||
"webhook_url": "https://your-webhook.com/callback",
|
||||
"metadata": {
|
||||
"project": "optional_project_name",
|
||||
"notes": "optional_notes"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Sequence Format:**
|
||||
- FASTA format with headers
|
||||
- Multiple sequences supported
|
||||
- Standard amino acid codes
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"experiment_id": "exp_abc123xyz",
|
||||
"status": "submitted",
|
||||
"created_at": "2025-11-24T10:00:00Z",
|
||||
"estimated_completion": "2025-12-15T10:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
#### Get Experiment Status
|
||||
|
||||
Check the current status of an experiment.
|
||||
|
||||
**Endpoint:** `GET /experiments/{experiment_id}`
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"experiment_id": "exp_abc123xyz",
|
||||
"status": "submitted|processing|completed|failed",
|
||||
"created_at": "2025-11-24T10:00:00Z",
|
||||
"updated_at": "2025-11-25T14:30:00Z",
|
||||
"progress": {
|
||||
"stage": "sequencing|expression|assay|analysis",
|
||||
"percentage": 45
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Status Values:**
|
||||
- `submitted` - Experiment received and queued
|
||||
- `processing` - Active testing in progress
|
||||
- `completed` - Results available for download
|
||||
- `failed` - Experiment encountered an error
|
||||
|
||||
#### List Experiments
|
||||
|
||||
Retrieve all experiments for your organization.
|
||||
|
||||
**Endpoint:** `GET /experiments`
|
||||
|
||||
**Query Parameters:**
|
||||
- `status` - Filter by status (optional)
|
||||
- `limit` - Number of results per page (default: 50)
|
||||
- `offset` - Pagination offset (default: 0)
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"experiments": [
|
||||
{
|
||||
"experiment_id": "exp_abc123xyz",
|
||||
"status": "completed",
|
||||
"experiment_type": "binding",
|
||||
"created_at": "2025-11-24T10:00:00Z"
|
||||
}
|
||||
],
|
||||
"total": 150,
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
```
|
||||
|
||||
### Results
|
||||
|
||||
#### Get Experiment Results
|
||||
|
||||
Download results from a completed experiment.
|
||||
|
||||
**Endpoint:** `GET /experiments/{experiment_id}/results`
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"experiment_id": "exp_abc123xyz",
|
||||
"results": [
|
||||
{
|
||||
"sequence_id": "protein1",
|
||||
"measurements": {
|
||||
"kd": 1.2e-9,
|
||||
"kon": 1.5e5,
|
||||
"koff": 1.8e-4
|
||||
},
|
||||
"quality_metrics": {
|
||||
"confidence": "high",
|
||||
"r_squared": 0.98
|
||||
}
|
||||
}
|
||||
],
|
||||
"download_urls": {
|
||||
"raw_data": "https://...",
|
||||
"analysis_package": "https://...",
|
||||
"report": "https://..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Targets
|
||||
|
||||
#### Search Target Catalog
|
||||
|
||||
Search the ACROBiosystems antigen catalog.
|
||||
|
||||
**Endpoint:** `GET /targets`
|
||||
|
||||
**Query Parameters:**
|
||||
- `search` - Search term (protein name, UniProt ID, etc.)
|
||||
- `species` - Filter by species
|
||||
- `category` - Filter by category
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"targets": [
|
||||
{
|
||||
"target_id": "tgt_12345",
|
||||
"name": "Human PD-L1",
|
||||
"species": "Homo sapiens",
|
||||
"uniprot_id": "Q9NZQ7",
|
||||
"availability": "in_stock|custom_order",
|
||||
"price_usd": 450
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Request Custom Target
|
||||
|
||||
Request an antigen not in the standard catalog.
|
||||
|
||||
**Endpoint:** `POST /targets/request`
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"target_name": "Custom target name",
|
||||
"uniprot_id": "optional_uniprot_id",
|
||||
"species": "species_name",
|
||||
"notes": "Additional requirements"
|
||||
}
|
||||
```
|
||||
|
||||
### Organization
|
||||
|
||||
#### Get Credits Balance
|
||||
|
||||
Check your organization's credit balance and usage.
|
||||
|
||||
**Endpoint:** `GET /organization/credits`
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"balance": 10000,
|
||||
"currency": "USD",
|
||||
"usage_this_month": 2500,
|
||||
"experiments_remaining": 22
|
||||
}
|
||||
```
|
||||
|
||||
## Webhooks
|
||||
|
||||
Configure webhook URLs to receive notifications when experiments complete.
|
||||
|
||||
**Webhook Payload:**
|
||||
```json
|
||||
{
|
||||
"event": "experiment.completed",
|
||||
"experiment_id": "exp_abc123xyz",
|
||||
"status": "completed",
|
||||
"timestamp": "2025-12-15T10:00:00Z",
|
||||
"results_url": "/experiments/exp_abc123xyz/results"
|
||||
}
|
||||
```
|
||||
|
||||
**Webhook Events:**
|
||||
- `experiment.submitted` - Experiment received
|
||||
- `experiment.started` - Processing began
|
||||
- `experiment.completed` - Results available
|
||||
- `experiment.failed` - Error occurred
|
||||
|
||||
**Security:**
|
||||
- Verify webhook signatures (details provided during onboarding)
|
||||
- Use HTTPS endpoints only
|
||||
- Respond with 200 OK to acknowledge receipt
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Error Response Format:**
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"code": "invalid_sequence",
|
||||
"message": "Sequence contains invalid amino acid codes",
|
||||
"details": {
|
||||
"sequence_id": "protein1",
|
||||
"position": 45,
|
||||
"character": "X"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Common Error Codes:**
|
||||
- `authentication_failed` - Invalid or missing API key
|
||||
- `invalid_sequence` - Malformed FASTA or invalid amino acids
|
||||
- `insufficient_credits` - Not enough credits for experiment
|
||||
- `target_not_found` - Specified target ID doesn't exist
|
||||
- `rate_limit_exceeded` - Too many requests
|
||||
- `experiment_not_found` - Invalid experiment ID
|
||||
- `internal_error` - Server-side error
|
||||
|
||||
## Rate Limits
|
||||
|
||||
- 100 requests per minute per API key
|
||||
- 1000 experiments per day per organization
|
||||
- Batch submissions encouraged for large-scale testing
|
||||
|
||||
When rate limited, response includes:
|
||||
```
|
||||
HTTP 429 Too Many Requests
|
||||
Retry-After: 60
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use webhooks** for long-running experiments instead of polling
|
||||
2. **Batch sequences** when submitting multiple variants
|
||||
3. **Cache results** to avoid redundant API calls
|
||||
4. **Implement retry logic** with exponential backoff
|
||||
5. **Monitor credits** to avoid experiment failures
|
||||
6. **Validate sequences** locally before submission
|
||||
7. **Use descriptive metadata** for better experiment tracking
|
||||
|
||||
## API Versioning
|
||||
|
||||
The API is currently in alpha/beta. Breaking changes may occur but will be:
|
||||
- Announced via email to registered users
|
||||
- Documented in the changelog
|
||||
- Supported with migration guides
|
||||
|
||||
Current version is reflected in response headers:
|
||||
```
|
||||
X-API-Version: alpha-2025-11
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For API issues or questions:
|
||||
- Email: support@adaptyvbio.com
|
||||
- Documentation updates: https://docs.adaptyvbio.com
|
||||
- Report bugs with experiment IDs and request details
|
||||
913
scientific-skills/adaptyv/reference/examples.md
Normal file
913
scientific-skills/adaptyv/reference/examples.md
Normal file
@@ -0,0 +1,913 @@
|
||||
# Code Examples
|
||||
|
||||
## Setup and Authentication
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import os
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
API_KEY = os.getenv("ADAPTYV_API_KEY")
|
||||
BASE_URL = "https://kq5jp7qj7wdqklhsxmovkzn4l40obksv.lambda-url.eu-central-1.on.aws"
|
||||
|
||||
# Standard headers
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def check_api_connection():
|
||||
"""Verify API connection and credentials"""
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/organization/credits", headers=HEADERS)
|
||||
response.raise_for_status()
|
||||
print("✓ API connection successful")
|
||||
print(f" Credits remaining: {response.json()['balance']}")
|
||||
return True
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print(f"✗ API authentication failed: {e}")
|
||||
return False
|
||||
```
|
||||
|
||||
### Environment Setup
|
||||
|
||||
Create a `.env` file:
|
||||
```bash
|
||||
ADAPTYV_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
uv pip install requests python-dotenv
|
||||
```
|
||||
|
||||
## Experiment Submission
|
||||
|
||||
### Submit Single Sequence
|
||||
|
||||
```python
|
||||
def submit_single_experiment(sequence, experiment_type="binding", target_id=None):
|
||||
"""
|
||||
Submit a single protein sequence for testing
|
||||
|
||||
Args:
|
||||
sequence: Amino acid sequence string
|
||||
experiment_type: Type of experiment (binding, expression, thermostability, enzyme_activity)
|
||||
target_id: Optional target identifier for binding assays
|
||||
|
||||
Returns:
|
||||
Experiment ID and status
|
||||
"""
|
||||
|
||||
# Format as FASTA
|
||||
fasta_content = f">protein_sequence\n{sequence}\n"
|
||||
|
||||
payload = {
|
||||
"sequences": fasta_content,
|
||||
"experiment_type": experiment_type
|
||||
}
|
||||
|
||||
if target_id:
|
||||
payload["target_id"] = target_id
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/experiments",
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
print(f"✓ Experiment submitted")
|
||||
print(f" Experiment ID: {result['experiment_id']}")
|
||||
print(f" Status: {result['status']}")
|
||||
print(f" Estimated completion: {result['estimated_completion']}")
|
||||
|
||||
return result
|
||||
|
||||
# Example usage
|
||||
sequence = "MKVLWAALLGLLGAAAAFPAVTSAVKPYKAAVSAAVSKPYKAAVSAAVSKPYK"
|
||||
experiment = submit_single_experiment(sequence, experiment_type="expression")
|
||||
```
|
||||
|
||||
### Submit Multiple Sequences (Batch)
|
||||
|
||||
```python
|
||||
def submit_batch_experiment(sequences_dict, experiment_type="binding", metadata=None):
|
||||
"""
|
||||
Submit multiple protein sequences in a single batch
|
||||
|
||||
Args:
|
||||
sequences_dict: Dictionary of {name: sequence}
|
||||
experiment_type: Type of experiment
|
||||
metadata: Optional dictionary of additional information
|
||||
|
||||
Returns:
|
||||
Experiment details
|
||||
"""
|
||||
|
||||
# Format all sequences as FASTA
|
||||
fasta_content = ""
|
||||
for name, sequence in sequences_dict.items():
|
||||
fasta_content += f">{name}\n{sequence}\n"
|
||||
|
||||
payload = {
|
||||
"sequences": fasta_content,
|
||||
"experiment_type": experiment_type
|
||||
}
|
||||
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/experiments",
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
print(f"✓ Batch experiment submitted")
|
||||
print(f" Experiment ID: {result['experiment_id']}")
|
||||
print(f" Sequences: {len(sequences_dict)}")
|
||||
print(f" Status: {result['status']}")
|
||||
|
||||
return result
|
||||
|
||||
# Example usage
|
||||
sequences = {
|
||||
"variant_1": "MKVLWAALLGLLGAAA...",
|
||||
"variant_2": "MKVLSAALLGLLGAAA...",
|
||||
"variant_3": "MKVLAAALLGLLGAAA...",
|
||||
"wildtype": "MKVLWAALLGLLGAAA..."
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"project": "antibody_optimization",
|
||||
"round": 3,
|
||||
"notes": "Testing solubility-optimized variants"
|
||||
}
|
||||
|
||||
experiment = submit_batch_experiment(sequences, "expression", metadata)
|
||||
```
|
||||
|
||||
### Submit with Webhook Notification
|
||||
|
||||
```python
|
||||
def submit_with_webhook(sequences_dict, experiment_type, webhook_url):
|
||||
"""
|
||||
Submit experiment with webhook for completion notification
|
||||
|
||||
Args:
|
||||
sequences_dict: Dictionary of {name: sequence}
|
||||
experiment_type: Type of experiment
|
||||
webhook_url: URL to receive notification when complete
|
||||
"""
|
||||
|
||||
fasta_content = ""
|
||||
for name, sequence in sequences_dict.items():
|
||||
fasta_content += f">{name}\n{sequence}\n"
|
||||
|
||||
payload = {
|
||||
"sequences": fasta_content,
|
||||
"experiment_type": experiment_type,
|
||||
"webhook_url": webhook_url
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/experiments",
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
print(f"✓ Experiment submitted with webhook")
|
||||
print(f" Experiment ID: {result['experiment_id']}")
|
||||
print(f" Webhook: {webhook_url}")
|
||||
|
||||
return result
|
||||
|
||||
# Example
|
||||
webhook_url = "https://your-server.com/adaptyv-webhook"
|
||||
experiment = submit_with_webhook(sequences, "binding", webhook_url)
|
||||
```
|
||||
|
||||
## Tracking Experiments
|
||||
|
||||
### Check Experiment Status
|
||||
|
||||
```python
|
||||
def check_experiment_status(experiment_id):
|
||||
"""
|
||||
Get current status of an experiment
|
||||
|
||||
Args:
|
||||
experiment_id: Experiment identifier
|
||||
|
||||
Returns:
|
||||
Status information
|
||||
"""
|
||||
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/experiments/{experiment_id}",
|
||||
headers=HEADERS
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
status = response.json()
|
||||
|
||||
print(f"Experiment: {experiment_id}")
|
||||
print(f" Status: {status['status']}")
|
||||
print(f" Created: {status['created_at']}")
|
||||
print(f" Updated: {status['updated_at']}")
|
||||
|
||||
if 'progress' in status:
|
||||
print(f" Progress: {status['progress']['percentage']}%")
|
||||
print(f" Current stage: {status['progress']['stage']}")
|
||||
|
||||
return status
|
||||
|
||||
# Example
|
||||
status = check_experiment_status("exp_abc123xyz")
|
||||
```
|
||||
|
||||
### List All Experiments
|
||||
|
||||
```python
|
||||
def list_experiments(status_filter=None, limit=50):
|
||||
"""
|
||||
List experiments with optional status filtering
|
||||
|
||||
Args:
|
||||
status_filter: Filter by status (submitted, processing, completed, failed)
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of experiments
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if status_filter:
|
||||
params["status"] = status_filter
|
||||
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/experiments",
|
||||
headers=HEADERS,
|
||||
params=params
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
print(f"Found {result['total']} experiments")
|
||||
for exp in result['experiments']:
|
||||
print(f" {exp['experiment_id']}: {exp['status']} ({exp['experiment_type']})")
|
||||
|
||||
return result['experiments']
|
||||
|
||||
# Example - list all completed experiments
|
||||
completed_experiments = list_experiments(status_filter="completed")
|
||||
```
|
||||
|
||||
### Poll Until Complete
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
def wait_for_completion(experiment_id, check_interval=3600):
|
||||
"""
|
||||
Poll experiment status until completion
|
||||
|
||||
Args:
|
||||
experiment_id: Experiment identifier
|
||||
check_interval: Seconds between status checks (default: 1 hour)
|
||||
|
||||
Returns:
|
||||
Final status
|
||||
"""
|
||||
|
||||
print(f"Monitoring experiment {experiment_id}...")
|
||||
|
||||
while True:
|
||||
status = check_experiment_status(experiment_id)
|
||||
|
||||
if status['status'] == 'completed':
|
||||
print("✓ Experiment completed!")
|
||||
return status
|
||||
elif status['status'] == 'failed':
|
||||
print("✗ Experiment failed")
|
||||
return status
|
||||
|
||||
print(f" Status: {status['status']} - checking again in {check_interval}s")
|
||||
time.sleep(check_interval)
|
||||
|
||||
# Example (not recommended - use webhooks instead!)
|
||||
# status = wait_for_completion("exp_abc123xyz", check_interval=3600)
|
||||
```
|
||||
|
||||
## Retrieving Results
|
||||
|
||||
### Download Experiment Results
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
def download_results(experiment_id, output_dir="results"):
|
||||
"""
|
||||
Download and parse experiment results
|
||||
|
||||
Args:
|
||||
experiment_id: Experiment identifier
|
||||
output_dir: Directory to save results
|
||||
|
||||
Returns:
|
||||
Parsed results data
|
||||
"""
|
||||
|
||||
# Get results
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/experiments/{experiment_id}/results",
|
||||
headers=HEADERS
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
# Save results JSON
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = f"{output_dir}/{experiment_id}_results.json"
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"✓ Results downloaded: {output_file}")
|
||||
print(f" Sequences tested: {len(results['results'])}")
|
||||
|
||||
# Download raw data if available
|
||||
if 'download_urls' in results:
|
||||
for data_type, url in results['download_urls'].items():
|
||||
print(f" {data_type} available at: {url}")
|
||||
|
||||
return results
|
||||
|
||||
# Example
|
||||
results = download_results("exp_abc123xyz")
|
||||
```
|
||||
|
||||
### Parse Binding Results
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
def parse_binding_results(results):
|
||||
"""
|
||||
Parse binding assay results into DataFrame
|
||||
|
||||
Args:
|
||||
results: Results dictionary from API
|
||||
|
||||
Returns:
|
||||
pandas DataFrame with organized results
|
||||
"""
|
||||
|
||||
data = []
|
||||
for result in results['results']:
|
||||
row = {
|
||||
'sequence_id': result['sequence_id'],
|
||||
'kd': result['measurements']['kd'],
|
||||
'kd_error': result['measurements']['kd_error'],
|
||||
'kon': result['measurements']['kon'],
|
||||
'koff': result['measurements']['koff'],
|
||||
'confidence': result['quality_metrics']['confidence'],
|
||||
'r_squared': result['quality_metrics']['r_squared']
|
||||
}
|
||||
data.append(row)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Sort by affinity (lower KD = stronger binding)
|
||||
df = df.sort_values('kd')
|
||||
|
||||
print("Top 5 binders:")
|
||||
print(df.head())
|
||||
|
||||
return df
|
||||
|
||||
# Example
|
||||
experiment_id = "exp_abc123xyz"
|
||||
results = download_results(experiment_id)
|
||||
binding_df = parse_binding_results(results)
|
||||
|
||||
# Export to CSV
|
||||
binding_df.to_csv(f"{experiment_id}_binding_results.csv", index=False)
|
||||
```
|
||||
|
||||
### Parse Expression Results
|
||||
|
||||
```python
|
||||
def parse_expression_results(results):
|
||||
"""
|
||||
Parse expression testing results into DataFrame
|
||||
|
||||
Args:
|
||||
results: Results dictionary from API
|
||||
|
||||
Returns:
|
||||
pandas DataFrame with organized results
|
||||
"""
|
||||
|
||||
data = []
|
||||
for result in results['results']:
|
||||
row = {
|
||||
'sequence_id': result['sequence_id'],
|
||||
'yield_mg_per_l': result['measurements']['total_yield_mg_per_l'],
|
||||
'soluble_fraction': result['measurements']['soluble_fraction_percent'],
|
||||
'purity': result['measurements']['purity_percent'],
|
||||
'percentile': result['ranking']['percentile']
|
||||
}
|
||||
data.append(row)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Sort by yield
|
||||
df = df.sort_values('yield_mg_per_l', ascending=False)
|
||||
|
||||
print(f"Mean yield: {df['yield_mg_per_l'].mean():.2f} mg/L")
|
||||
print(f"Top performer: {df.iloc[0]['sequence_id']} ({df.iloc[0]['yield_mg_per_l']:.2f} mg/L)")
|
||||
|
||||
return df
|
||||
|
||||
# Example
|
||||
results = download_results("exp_expression123")
|
||||
expression_df = parse_expression_results(results)
|
||||
```
|
||||
|
||||
## Target Catalog
|
||||
|
||||
### Search for Targets
|
||||
|
||||
```python
|
||||
def search_targets(query, species=None, category=None):
|
||||
"""
|
||||
Search the antigen catalog
|
||||
|
||||
Args:
|
||||
query: Search term (protein name, UniProt ID, etc.)
|
||||
species: Optional species filter
|
||||
category: Optional category filter
|
||||
|
||||
Returns:
|
||||
List of matching targets
|
||||
"""
|
||||
|
||||
params = {"search": query}
|
||||
if species:
|
||||
params["species"] = species
|
||||
if category:
|
||||
params["category"] = category
|
||||
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/targets",
|
||||
headers=HEADERS,
|
||||
params=params
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
targets = response.json()['targets']
|
||||
|
||||
print(f"Found {len(targets)} targets matching '{query}':")
|
||||
for target in targets:
|
||||
print(f" {target['target_id']}: {target['name']}")
|
||||
print(f" Species: {target['species']}")
|
||||
print(f" Availability: {target['availability']}")
|
||||
print(f" Price: ${target['price_usd']}")
|
||||
|
||||
return targets
|
||||
|
||||
# Example
|
||||
targets = search_targets("PD-L1", species="Homo sapiens")
|
||||
```
|
||||
|
||||
### Request Custom Target
|
||||
|
||||
```python
|
||||
def request_custom_target(target_name, uniprot_id=None, species=None, notes=None):
|
||||
"""
|
||||
Request a custom antigen not in the standard catalog
|
||||
|
||||
Args:
|
||||
target_name: Name of the target protein
|
||||
uniprot_id: Optional UniProt identifier
|
||||
species: Species name
|
||||
notes: Additional requirements or notes
|
||||
|
||||
Returns:
|
||||
Request confirmation
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"target_name": target_name,
|
||||
"species": species
|
||||
}
|
||||
|
||||
if uniprot_id:
|
||||
payload["uniprot_id"] = uniprot_id
|
||||
if notes:
|
||||
payload["notes"] = notes
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/targets/request",
|
||||
headers=HEADERS,
|
||||
json=payload
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
print(f"✓ Custom target request submitted")
|
||||
print(f" Request ID: {result['request_id']}")
|
||||
print(f" Status: {result['status']}")
|
||||
|
||||
return result
|
||||
|
||||
# Example
|
||||
request = request_custom_target(
|
||||
target_name="Novel receptor XYZ",
|
||||
uniprot_id="P12345",
|
||||
species="Mus musculus",
|
||||
notes="Need high purity for structural studies"
|
||||
)
|
||||
```
|
||||
|
||||
## Complete Workflows
|
||||
|
||||
### End-to-End Binding Assay
|
||||
|
||||
```python
|
||||
def complete_binding_workflow(sequences_dict, target_id, project_name):
|
||||
"""
|
||||
Complete workflow: submit sequences, track, and retrieve binding results
|
||||
|
||||
Args:
|
||||
sequences_dict: Dictionary of {name: sequence}
|
||||
target_id: Target identifier from catalog
|
||||
project_name: Project name for metadata
|
||||
|
||||
Returns:
|
||||
DataFrame with binding results
|
||||
"""
|
||||
|
||||
print("=== Starting Binding Assay Workflow ===")
|
||||
|
||||
# Step 1: Submit experiment
|
||||
print("\n1. Submitting experiment...")
|
||||
metadata = {
|
||||
"project": project_name,
|
||||
"target": target_id
|
||||
}
|
||||
|
||||
experiment = submit_batch_experiment(
|
||||
sequences_dict,
|
||||
experiment_type="binding",
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
experiment_id = experiment['experiment_id']
|
||||
|
||||
# Step 2: Save experiment info
|
||||
print("\n2. Saving experiment details...")
|
||||
with open(f"{experiment_id}_info.json", 'w') as f:
|
||||
json.dump(experiment, f, indent=2)
|
||||
|
||||
print(f"✓ Experiment {experiment_id} submitted")
|
||||
print(" Results will be available in ~21 days")
|
||||
print(" Use webhook or poll status for updates")
|
||||
|
||||
# Note: In practice, wait for completion before this step
|
||||
# print("\n3. Waiting for completion...")
|
||||
# status = wait_for_completion(experiment_id)
|
||||
|
||||
# print("\n4. Downloading results...")
|
||||
# results = download_results(experiment_id)
|
||||
|
||||
# print("\n5. Parsing results...")
|
||||
# df = parse_binding_results(results)
|
||||
|
||||
# return df
|
||||
|
||||
return experiment_id
|
||||
|
||||
# Example
|
||||
antibody_variants = {
|
||||
"variant_1": "EVQLVESGGGLVQPGG...",
|
||||
"variant_2": "EVQLVESGGGLVQPGS...",
|
||||
"variant_3": "EVQLVESGGGLVQPGA...",
|
||||
"wildtype": "EVQLVESGGGLVQPGG..."
|
||||
}
|
||||
|
||||
experiment_id = complete_binding_workflow(
|
||||
antibody_variants,
|
||||
target_id="tgt_pdl1_human",
|
||||
project_name="antibody_affinity_maturation"
|
||||
)
|
||||
```
|
||||
|
||||
### Optimization + Testing Pipeline
|
||||
|
||||
```python
|
||||
# Combine computational optimization with experimental testing
|
||||
|
||||
def optimization_and_testing_pipeline(initial_sequences, experiment_type="expression"):
|
||||
"""
|
||||
Complete pipeline: optimize sequences computationally, then submit for testing
|
||||
|
||||
Args:
|
||||
initial_sequences: Dictionary of {name: sequence}
|
||||
experiment_type: Type of experiment
|
||||
|
||||
Returns:
|
||||
Experiment ID for tracking
|
||||
"""
|
||||
|
||||
print("=== Optimization and Testing Pipeline ===")
|
||||
|
||||
# Step 1: Computational optimization
|
||||
print("\n1. Computational optimization...")
|
||||
from protein_optimization import complete_optimization_pipeline
|
||||
|
||||
optimized = complete_optimization_pipeline(initial_sequences)
|
||||
|
||||
print(f"✓ Optimization complete")
|
||||
print(f" Started with: {len(initial_sequences)} sequences")
|
||||
print(f" Optimized to: {len(optimized)} sequences")
|
||||
|
||||
# Step 2: Select top candidates
|
||||
print("\n2. Selecting top candidates for testing...")
|
||||
top_candidates = optimized[:50] # Top 50
|
||||
|
||||
sequences_to_test = {
|
||||
seq_data['name']: seq_data['sequence']
|
||||
for seq_data in top_candidates
|
||||
}
|
||||
|
||||
# Step 3: Submit for experimental validation
|
||||
print("\n3. Submitting to Adaptyv...")
|
||||
metadata = {
|
||||
"optimization_method": "computational_pipeline",
|
||||
"initial_library_size": len(initial_sequences),
|
||||
"computational_scores": [s['combined'] for s in top_candidates]
|
||||
}
|
||||
|
||||
experiment = submit_batch_experiment(
|
||||
sequences_to_test,
|
||||
experiment_type=experiment_type,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
print(f"✓ Pipeline complete")
|
||||
print(f" Experiment ID: {experiment['experiment_id']}")
|
||||
|
||||
return experiment['experiment_id']
|
||||
|
||||
# Example
|
||||
initial_library = {
|
||||
f"variant_{i}": generate_random_sequence()
|
||||
for i in range(1000)
|
||||
}
|
||||
|
||||
experiment_id = optimization_and_testing_pipeline(
|
||||
initial_library,
|
||||
experiment_type="expression"
|
||||
)
|
||||
```
|
||||
|
||||
### Batch Result Analysis
|
||||
|
||||
```python
|
||||
def analyze_multiple_experiments(experiment_ids):
|
||||
"""
|
||||
Download and analyze results from multiple experiments
|
||||
|
||||
Args:
|
||||
experiment_ids: List of experiment identifiers
|
||||
|
||||
Returns:
|
||||
Combined DataFrame with all results
|
||||
"""
|
||||
|
||||
all_results = []
|
||||
|
||||
for exp_id in experiment_ids:
|
||||
print(f"Processing {exp_id}...")
|
||||
|
||||
# Download results
|
||||
results = download_results(exp_id, output_dir=f"results/{exp_id}")
|
||||
|
||||
# Parse based on experiment type
|
||||
exp_type = results.get('experiment_type', 'unknown')
|
||||
|
||||
if exp_type == 'binding':
|
||||
df = parse_binding_results(results)
|
||||
df['experiment_id'] = exp_id
|
||||
all_results.append(df)
|
||||
|
||||
elif exp_type == 'expression':
|
||||
df = parse_expression_results(results)
|
||||
df['experiment_id'] = exp_id
|
||||
all_results.append(df)
|
||||
|
||||
# Combine all results
|
||||
combined_df = pd.concat(all_results, ignore_index=True)
|
||||
|
||||
print(f"\n✓ Analysis complete")
|
||||
print(f" Total experiments: {len(experiment_ids)}")
|
||||
print(f" Total sequences: {len(combined_df)}")
|
||||
|
||||
return combined_df
|
||||
|
||||
# Example
|
||||
experiment_ids = [
|
||||
"exp_round1_abc",
|
||||
"exp_round2_def",
|
||||
"exp_round3_ghi"
|
||||
]
|
||||
|
||||
all_data = analyze_multiple_experiments(experiment_ids)
|
||||
all_data.to_csv("combined_results.csv", index=False)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Robust API Wrapper
|
||||
|
||||
```python
|
||||
import time
|
||||
from requests.exceptions import RequestException, HTTPError
|
||||
|
||||
def api_request_with_retry(method, url, max_retries=3, backoff_factor=2, **kwargs):
|
||||
"""
|
||||
Make API request with retry logic and error handling
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
url: Request URL
|
||||
max_retries: Maximum number of retry attempts
|
||||
backoff_factor: Exponential backoff multiplier
|
||||
**kwargs: Additional arguments for requests
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
|
||||
Raises:
|
||||
RequestException: If all retries fail
|
||||
"""
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 429: # Rate limit
|
||||
wait_time = backoff_factor ** attempt
|
||||
print(f"Rate limited. Waiting {wait_time}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
elif e.response.status_code >= 500: # Server error
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = backoff_factor ** attempt
|
||||
print(f"Server error. Retrying in {wait_time}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
else: # Client error (4xx) - don't retry
|
||||
error_data = e.response.json() if e.response.content else {}
|
||||
print(f"API Error: {error_data.get('error', {}).get('message', str(e))}")
|
||||
raise
|
||||
|
||||
except RequestException as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = backoff_factor ** attempt
|
||||
print(f"Request failed. Retrying in {wait_time}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
raise RequestException(f"Failed after {max_retries} attempts")
|
||||
|
||||
# Example usage
|
||||
response = api_request_with_retry(
|
||||
"POST",
|
||||
f"{BASE_URL}/experiments",
|
||||
headers=HEADERS,
|
||||
json={"sequences": fasta_content, "experiment_type": "binding"}
|
||||
)
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
### Validate FASTA Format
|
||||
|
||||
```python
|
||||
def validate_fasta(fasta_string):
|
||||
"""
|
||||
Validate FASTA format and sequences
|
||||
|
||||
Args:
|
||||
fasta_string: FASTA-formatted string
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
|
||||
lines = fasta_string.strip().split('\n')
|
||||
|
||||
if not lines:
|
||||
return False, "Empty FASTA content"
|
||||
|
||||
if not lines[0].startswith('>'):
|
||||
return False, "FASTA must start with header line (>)"
|
||||
|
||||
valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
|
||||
current_header = None
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith('>'):
|
||||
if not line[1:].strip():
|
||||
return False, f"Line {i+1}: Empty header"
|
||||
current_header = line[1:].strip()
|
||||
|
||||
else:
|
||||
if current_header is None:
|
||||
return False, f"Line {i+1}: Sequence before header"
|
||||
|
||||
sequence = line.strip().upper()
|
||||
invalid = set(sequence) - valid_amino_acids
|
||||
|
||||
if invalid:
|
||||
return False, f"Line {i+1}: Invalid amino acids: {invalid}"
|
||||
|
||||
return True, None
|
||||
|
||||
# Example
|
||||
fasta = ">protein1\nMKVLWAALLG\n>protein2\nMATGVLWALG"
|
||||
is_valid, error = validate_fasta(fasta)
|
||||
|
||||
if is_valid:
|
||||
print("✓ FASTA format valid")
|
||||
else:
|
||||
print(f"✗ FASTA validation failed: {error}")
|
||||
```
|
||||
|
||||
### Format Sequences to FASTA
|
||||
|
||||
```python
|
||||
def sequences_to_fasta(sequences_dict):
|
||||
"""
|
||||
Convert dictionary of sequences to FASTA format
|
||||
|
||||
Args:
|
||||
sequences_dict: Dictionary of {name: sequence}
|
||||
|
||||
Returns:
|
||||
FASTA-formatted string
|
||||
"""
|
||||
|
||||
fasta_content = ""
|
||||
for name, sequence in sequences_dict.items():
|
||||
# Clean sequence (remove whitespace, ensure uppercase)
|
||||
clean_seq = ''.join(sequence.split()).upper()
|
||||
|
||||
# Validate
|
||||
is_valid, error = validate_fasta(f">{name}\n{clean_seq}")
|
||||
if not is_valid:
|
||||
raise ValueError(f"Invalid sequence '{name}': {error}")
|
||||
|
||||
fasta_content += f">{name}\n{clean_seq}\n"
|
||||
|
||||
return fasta_content
|
||||
|
||||
# Example
|
||||
sequences = {
|
||||
"var1": "MKVLWAALLG",
|
||||
"var2": "MATGVLWALG"
|
||||
}
|
||||
|
||||
fasta = sequences_to_fasta(sequences)
|
||||
print(fasta)
|
||||
```
|
||||
360
scientific-skills/adaptyv/reference/experiments.md
Normal file
360
scientific-skills/adaptyv/reference/experiments.md
Normal file
@@ -0,0 +1,360 @@
|
||||
# Experiment Types and Workflows
|
||||
|
||||
## Overview
|
||||
|
||||
Adaptyv provides multiple experimental assay types for comprehensive protein characterization. Each experiment type has specific applications, workflows, and data outputs.
|
||||
|
||||
## Binding Assays
|
||||
|
||||
### Description
|
||||
|
||||
Measure protein-target interactions using biolayer interferometry (BLI), a label-free technique that monitors biomolecular binding in real-time.
|
||||
|
||||
### Use Cases
|
||||
|
||||
- Antibody-antigen binding characterization
|
||||
- Receptor-ligand interaction analysis
|
||||
- Protein-protein interaction studies
|
||||
- Affinity maturation screening
|
||||
- Epitope binning experiments
|
||||
|
||||
### Technology: Biolayer Interferometry (BLI)
|
||||
|
||||
BLI measures the interference pattern of reflected light from two surfaces:
|
||||
- **Reference layer** - Biosensor tip surface
|
||||
- **Biological layer** - Accumulated bound molecules
|
||||
|
||||
As molecules bind, the optical thickness increases, causing a wavelength shift proportional to binding.
|
||||
|
||||
**Advantages:**
|
||||
- Label-free detection
|
||||
- Real-time kinetics
|
||||
- High-throughput compatible
|
||||
- Works in crude samples
|
||||
- Minimal sample consumption
|
||||
|
||||
### Measured Parameters
|
||||
|
||||
**Kinetic constants:**
|
||||
- **KD** - Equilibrium dissociation constant (binding affinity)
|
||||
- **kon** - Association rate constant (binding speed)
|
||||
- **koff** - Dissociation rate constant (unbinding speed)
|
||||
|
||||
**Typical ranges:**
|
||||
- Strong binders: KD < 1 nM
|
||||
- Moderate binders: KD = 1-100 nM
|
||||
- Weak binders: KD > 100 nM
|
||||
|
||||
### Workflow
|
||||
|
||||
1. **Sequence submission** - Provide protein sequences in FASTA format
|
||||
2. **Expression** - Proteins expressed in appropriate host system
|
||||
3. **Purification** - Automated purification protocols
|
||||
4. **BLI assay** - Real-time binding measurements against specified targets
|
||||
5. **Analysis** - Kinetic curve fitting and quality assessment
|
||||
6. **Results delivery** - Binding parameters with confidence metrics
|
||||
|
||||
### Sample Requirements
|
||||
|
||||
- Protein sequence (standard amino acid codes)
|
||||
- Target specification (from catalog or custom request)
|
||||
- Buffer conditions (standard or custom)
|
||||
- Expected concentration range (optional, improves assay design)
|
||||
|
||||
### Results Format
|
||||
|
||||
```json
|
||||
{
|
||||
"sequence_id": "antibody_variant_1",
|
||||
"target": "Human PD-L1",
|
||||
"measurements": {
|
||||
"kd": 2.5e-9,
|
||||
"kd_error": 0.3e-9,
|
||||
"kon": 1.8e5,
|
||||
"kon_error": 0.2e5,
|
||||
"koff": 4.5e-4,
|
||||
"koff_error": 0.5e-4
|
||||
},
|
||||
"quality_metrics": {
|
||||
"confidence": "high|medium|low",
|
||||
"r_squared": 0.97,
|
||||
"chi_squared": 0.02,
|
||||
"flags": []
|
||||
},
|
||||
"raw_data_url": "https://..."
|
||||
}
|
||||
```
|
||||
|
||||
## Expression Testing
|
||||
|
||||
### Description
|
||||
|
||||
Quantify protein expression levels in various host systems to assess producibility and optimize sequences for manufacturing.
|
||||
|
||||
### Use Cases
|
||||
|
||||
- Screening variants for high expression
|
||||
- Optimizing codon usage
|
||||
- Identifying expression bottlenecks
|
||||
- Selecting candidates for scale-up
|
||||
- Comparing expression systems
|
||||
|
||||
### Host Systems
|
||||
|
||||
Available expression platforms:
|
||||
- **E. coli** - Rapid, cost-effective, prokaryotic system
|
||||
- **Mammalian cells** - Native post-translational modifications
|
||||
- **Yeast** - Eukaryotic system with simpler growth requirements
|
||||
- **Insect cells** - Alternative eukaryotic platform
|
||||
|
||||
### Measured Parameters
|
||||
|
||||
- **Total protein yield** (mg/L culture)
|
||||
- **Soluble fraction** (percentage)
|
||||
- **Purity** (after initial purification)
|
||||
- **Expression time course** (optional)
|
||||
|
||||
### Workflow
|
||||
|
||||
1. **Sequence submission** - Provide protein sequences
|
||||
2. **Construct generation** - Cloning into expression vectors
|
||||
3. **Expression** - Culture in specified host system
|
||||
4. **Quantification** - Protein measurement via multiple methods
|
||||
5. **Analysis** - Expression level comparison and ranking
|
||||
6. **Results delivery** - Yield data and recommendations
|
||||
|
||||
### Results Format
|
||||
|
||||
```json
|
||||
{
|
||||
"sequence_id": "variant_1",
|
||||
"host_system": "E. coli",
|
||||
"measurements": {
|
||||
"total_yield_mg_per_l": 25.5,
|
||||
"soluble_fraction_percent": 78,
|
||||
"purity_percent": 92
|
||||
},
|
||||
"ranking": {
|
||||
"percentile": 85,
|
||||
"notes": "High expression, good solubility"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Thermostability Testing
|
||||
|
||||
### Description
|
||||
|
||||
Measure protein thermal stability to assess structural integrity, predict shelf-life, and identify stabilizing mutations.
|
||||
|
||||
### Use Cases
|
||||
|
||||
- Selecting thermally stable variants
|
||||
- Formulation development
|
||||
- Shelf-life prediction
|
||||
- Stability-driven protein engineering
|
||||
- Quality control screening
|
||||
|
||||
### Measurement Techniques
|
||||
|
||||
**Differential Scanning Fluorimetry (DSF):**
|
||||
- Monitors protein unfolding via fluorescent dye binding
|
||||
- Determines melting temperature (Tm)
|
||||
- High-throughput capable
|
||||
|
||||
**Circular Dichroism (CD):**
|
||||
- Secondary structure analysis
|
||||
- Thermal unfolding curves
|
||||
- Reversibility assessment
|
||||
|
||||
### Measured Parameters
|
||||
|
||||
- **Tm** - Melting temperature (midpoint of unfolding)
|
||||
- **ΔH** - Enthalpy of unfolding
|
||||
- **Aggregation temperature** (Tagg)
|
||||
- **Reversibility** - Refolding after heating
|
||||
|
||||
### Workflow
|
||||
|
||||
1. **Sequence submission** - Provide protein sequences
|
||||
2. **Expression and purification** - Standard protocols
|
||||
3. **Thermostability assay** - Temperature gradient analysis
|
||||
4. **Data analysis** - Curve fitting and parameter extraction
|
||||
5. **Results delivery** - Stability metrics with ranking
|
||||
|
||||
### Results Format
|
||||
|
||||
```json
|
||||
{
|
||||
"sequence_id": "variant_1",
|
||||
"measurements": {
|
||||
"tm_celsius": 68.5,
|
||||
"tm_error": 0.5,
|
||||
"tagg_celsius": 72.0,
|
||||
"reversibility_percent": 85
|
||||
},
|
||||
"quality_metrics": {
|
||||
"curve_quality": "excellent",
|
||||
"cooperativity": "two-state"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Enzyme Activity Assays
|
||||
|
||||
### Description
|
||||
|
||||
Measure enzymatic function including substrate turnover, catalytic efficiency, and inhibitor sensitivity.
|
||||
|
||||
### Use Cases
|
||||
|
||||
- Screening enzyme variants for improved activity
|
||||
- Substrate specificity profiling
|
||||
- Inhibitor testing
|
||||
- pH and temperature optimization
|
||||
- Mechanistic studies
|
||||
|
||||
### Assay Types
|
||||
|
||||
**Continuous assays:**
|
||||
- Chromogenic substrates
|
||||
- Fluorogenic substrates
|
||||
- Real-time monitoring
|
||||
|
||||
**Endpoint assays:**
|
||||
- HPLC quantification
|
||||
- Mass spectrometry
|
||||
- Colorimetric detection
|
||||
|
||||
### Measured Parameters
|
||||
|
||||
**Kinetic parameters:**
|
||||
- **kcat** - Turnover number (catalytic rate constant)
|
||||
- **KM** - Michaelis constant (substrate affinity)
|
||||
- **kcat/KM** - Catalytic efficiency
|
||||
- **IC50** - Inhibitor concentration for 50% inhibition
|
||||
|
||||
**Activity metrics:**
|
||||
- Specific activity (units/mg protein)
|
||||
- Relative activity vs. reference
|
||||
- Substrate specificity profile
|
||||
|
||||
### Workflow
|
||||
|
||||
1. **Sequence submission** - Provide enzyme sequences
|
||||
2. **Expression and purification** - Optimized for activity retention
|
||||
3. **Activity assay** - Substrate turnover measurements
|
||||
4. **Kinetic analysis** - Michaelis-Menten fitting
|
||||
5. **Results delivery** - Kinetic parameters and rankings
|
||||
|
||||
### Results Format
|
||||
|
||||
```json
|
||||
{
|
||||
"sequence_id": "enzyme_variant_1",
|
||||
"substrate": "substrate_name",
|
||||
"measurements": {
|
||||
"kcat_per_second": 125,
|
||||
"km_micromolar": 45,
|
||||
"kcat_km": 2.8,
|
||||
"specific_activity": 180
|
||||
},
|
||||
"quality_metrics": {
|
||||
"confidence": "high",
|
||||
"r_squared": 0.99
|
||||
},
|
||||
"ranking": {
|
||||
"relative_activity": 1.8,
|
||||
"improvement_vs_wildtype": "80%"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Experiment Design Best Practices
|
||||
|
||||
### Sequence Submission
|
||||
|
||||
1. **Use clear identifiers** - Name sequences descriptively
|
||||
2. **Include controls** - Submit wild-type or reference sequences
|
||||
3. **Batch similar variants** - Group related sequences in single submission
|
||||
4. **Validate sequences** - Check for errors before submission
|
||||
|
||||
### Sample Size
|
||||
|
||||
- **Pilot studies** - 5-10 sequences to test feasibility
|
||||
- **Library screening** - 50-500 sequences for variant exploration
|
||||
- **Focused optimization** - 10-50 sequences for fine-tuning
|
||||
- **Large-scale campaigns** - 500+ sequences for ML-driven design
|
||||
|
||||
### Quality Control
|
||||
|
||||
Adaptyv includes automated QC steps:
|
||||
- Expression verification before assay
|
||||
- Replicate measurements for reliability
|
||||
- Positive/negative controls in each batch
|
||||
- Statistical validation of results
|
||||
|
||||
### Timeline Expectations
|
||||
|
||||
**Standard turnaround:** ~21 days from submission to results
|
||||
|
||||
**Timeline breakdown:**
|
||||
- Construct generation: 3-5 days
|
||||
- Expression: 5-7 days
|
||||
- Purification: 2-3 days
|
||||
- Assay execution: 3-5 days
|
||||
- Analysis and QC: 2-3 days
|
||||
|
||||
**Factors affecting timeline:**
|
||||
- Custom targets (add 1-2 weeks)
|
||||
- Novel assay development (add 2-4 weeks)
|
||||
- Large batch sizes (may add 1 week)
|
||||
|
||||
### Cost Optimization
|
||||
|
||||
1. **Batch submissions** - Lower per-sequence cost
|
||||
2. **Standard targets** - Catalog antigens are faster/cheaper
|
||||
3. **Standard conditions** - Custom buffers add cost
|
||||
4. **Computational pre-filtering** - Submit only promising candidates
|
||||
|
||||
## Combining Experiment Types
|
||||
|
||||
For comprehensive protein characterization, combine multiple assays:
|
||||
|
||||
**Therapeutic antibody development:**
|
||||
1. Binding assay → Identify high-affinity binders
|
||||
2. Expression testing → Select manufacturable candidates
|
||||
3. Thermostability → Ensure formulation stability
|
||||
|
||||
**Enzyme engineering:**
|
||||
1. Activity assay → Screen for improved catalysis
|
||||
2. Expression testing → Ensure producibility
|
||||
3. Thermostability → Validate industrial robustness
|
||||
|
||||
**Sequential vs. Parallel:**
|
||||
- **Sequential** - Use results from early assays to filter candidates
|
||||
- **Parallel** - Run all assays simultaneously for faster results
|
||||
|
||||
## Data Integration
|
||||
|
||||
Results integrate with computational workflows:
|
||||
|
||||
1. **Download raw data** via API
|
||||
2. **Parse results** into standardized format
|
||||
3. **Feed into ML models** for next-round design
|
||||
4. **Track experiments** with metadata tags
|
||||
5. **Visualize trends** across design iterations
|
||||
|
||||
## Support and Troubleshooting
|
||||
|
||||
**Common issues:**
|
||||
- Low expression → Consider sequence optimization (see protein_optimization.md)
|
||||
- Poor binding → Verify target specification and expected range
|
||||
- Variable results → Check sequence quality and controls
|
||||
- Incomplete data → Contact support with experiment ID
|
||||
|
||||
**Getting help:**
|
||||
- Email: support@adaptyvbio.com
|
||||
- Include experiment ID and specific question
|
||||
- Provide context (design goals, expected results)
|
||||
- Response time: <24 hours for active experiments
|
||||
637
scientific-skills/adaptyv/reference/protein_optimization.md
Normal file
637
scientific-skills/adaptyv/reference/protein_optimization.md
Normal file
@@ -0,0 +1,637 @@
|
||||
# Protein Sequence Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
Before submitting protein sequences for experimental testing, use computational tools to optimize sequences for improved expression, solubility, and stability. This pre-screening reduces experimental costs and increases success rates.
|
||||
|
||||
## Common Protein Expression Problems
|
||||
|
||||
### 1. Unpaired Cysteines
|
||||
|
||||
**Problem:**
|
||||
- Unpaired cysteines form unwanted disulfide bonds
|
||||
- Leads to aggregation and misfolding
|
||||
- Reduces expression yield and stability
|
||||
|
||||
**Solution:**
|
||||
- Remove unpaired cysteines unless functionally necessary
|
||||
- Pair cysteines appropriately for structural disulfides
|
||||
- Replace with serine or alanine in non-critical positions
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
# Check for cysteine pairs
|
||||
from Bio.Seq import Seq
|
||||
|
||||
def check_cysteines(sequence):
|
||||
cys_count = sequence.count('C')
|
||||
if cys_count % 2 != 0:
|
||||
print(f"Warning: Odd number of cysteines ({cys_count})")
|
||||
return cys_count
|
||||
```
|
||||
|
||||
### 2. Excessive Hydrophobicity
|
||||
|
||||
**Problem:**
|
||||
- Long hydrophobic patches promote aggregation
|
||||
- Exposed hydrophobic residues drive protein clumping
|
||||
- Poor solubility in aqueous buffers
|
||||
|
||||
**Solution:**
|
||||
- Maintain balanced hydropathy profiles
|
||||
- Use short, flexible linkers between domains
|
||||
- Reduce surface-exposed hydrophobic residues
|
||||
|
||||
**Metrics:**
|
||||
- Kyte-Doolittle hydropathy plots
|
||||
- GRAVY score (Grand Average of Hydropathy)
|
||||
- pSAE (percent Solvent-Accessible hydrophobic residues)
|
||||
|
||||
### 3. Low Solubility
|
||||
|
||||
**Problem:**
|
||||
- Proteins precipitate during expression or purification
|
||||
- Inclusion body formation
|
||||
- Difficult downstream processing
|
||||
|
||||
**Solution:**
|
||||
- Use solubility prediction tools for pre-screening
|
||||
- Apply sequence optimization algorithms
|
||||
- Add solubilizing tags if needed
|
||||
|
||||
## Computational Tools for Optimization
|
||||
|
||||
### NetSolP - Initial Solubility Screening
|
||||
|
||||
**Purpose:** Fast solubility prediction for filtering sequences.
|
||||
|
||||
**Method:** Machine learning model trained on E. coli expression data.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Install: uv pip install requests
|
||||
import requests
|
||||
|
||||
def predict_solubility_netsolp(sequence):
|
||||
"""Predict protein solubility using NetSolP web service"""
|
||||
url = "https://services.healthtech.dtu.dk/services/NetSolP-1.0/api/predict"
|
||||
|
||||
data = {
|
||||
"sequence": sequence,
|
||||
"format": "fasta"
|
||||
}
|
||||
|
||||
response = requests.post(url, data=data)
|
||||
return response.json()
|
||||
|
||||
# Example
|
||||
sequence = "MKVLWAALLGLLGAAA..."
|
||||
result = predict_solubility_netsolp(sequence)
|
||||
print(f"Solubility score: {result['score']}")
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- Score > 0.5: Likely soluble
|
||||
- Score < 0.5: Likely insoluble
|
||||
- Use for initial filtering before more expensive predictions
|
||||
|
||||
**When to use:**
|
||||
- First-pass filtering of large libraries
|
||||
- Quick validation of designed sequences
|
||||
- Prioritizing sequences for experimental testing
|
||||
|
||||
### SoluProt - Comprehensive Solubility Prediction
|
||||
|
||||
**Purpose:** Advanced solubility prediction with higher accuracy.
|
||||
|
||||
**Method:** Deep learning model incorporating sequence and structural features.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Install: uv pip install soluprot
|
||||
from soluprot import predict_solubility
|
||||
|
||||
def screen_variants_soluprot(sequences):
|
||||
"""Screen multiple sequences for solubility"""
|
||||
results = []
|
||||
for name, seq in sequences.items():
|
||||
score = predict_solubility(seq)
|
||||
results.append({
|
||||
'name': name,
|
||||
'sequence': seq,
|
||||
'solubility_score': score,
|
||||
'predicted_soluble': score > 0.6
|
||||
})
|
||||
return results
|
||||
|
||||
# Example
|
||||
sequences = {
|
||||
'variant_1': 'MKVLW...',
|
||||
'variant_2': 'MATGV...'
|
||||
}
|
||||
|
||||
results = screen_variants_soluprot(sequences)
|
||||
soluble_variants = [r for r in results if r['predicted_soluble']]
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- Score > 0.6: High solubility confidence
|
||||
- Score 0.4-0.6: Uncertain, may need optimization
|
||||
- Score < 0.4: Likely problematic
|
||||
|
||||
**When to use:**
|
||||
- After initial NetSolP filtering
|
||||
- When higher prediction accuracy is needed
|
||||
- Before committing to expensive synthesis/testing
|
||||
|
||||
### SolubleMPNN - Sequence Redesign
|
||||
|
||||
**Purpose:** Redesign protein sequences to improve solubility while maintaining function.
|
||||
|
||||
**Method:** Graph neural network that suggests mutations to increase solubility.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Install: uv pip install soluble-mpnn
|
||||
from soluble_mpnn import optimize_sequence
|
||||
|
||||
def optimize_for_solubility(sequence, structure_pdb=None):
|
||||
"""
|
||||
Redesign sequence for improved solubility
|
||||
|
||||
Args:
|
||||
sequence: Original amino acid sequence
|
||||
structure_pdb: Optional PDB file for structure-aware design
|
||||
|
||||
Returns:
|
||||
Optimized sequence variants ranked by predicted solubility
|
||||
"""
|
||||
|
||||
variants = optimize_sequence(
|
||||
sequence=sequence,
|
||||
structure=structure_pdb,
|
||||
num_variants=10,
|
||||
temperature=0.1 # Lower = more conservative mutations
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
# Example
|
||||
original_seq = "MKVLWAALLGLLGAAA..."
|
||||
optimized_variants = optimize_for_solubility(original_seq)
|
||||
|
||||
for i, variant in enumerate(optimized_variants):
|
||||
print(f"Variant {i+1}:")
|
||||
print(f" Sequence: {variant['sequence']}")
|
||||
print(f" Solubility score: {variant['solubility_score']}")
|
||||
print(f" Mutations: {variant['mutations']}")
|
||||
```
|
||||
|
||||
**Design strategy:**
|
||||
- **Conservative** (temperature=0.1): Minimal changes, safer
|
||||
- **Moderate** (temperature=0.3): Balance between change and safety
|
||||
- **Aggressive** (temperature=0.5): More mutations, higher risk
|
||||
|
||||
**When to use:**
|
||||
- Primary tool for sequence optimization
|
||||
- Default starting point for improving problematic sequences
|
||||
- Generating diverse soluble variants
|
||||
|
||||
**Best practices:**
|
||||
- Generate 10-50 variants per sequence
|
||||
- Use structure information when available (improves accuracy)
|
||||
- Validate key functional residues are preserved
|
||||
- Test multiple temperature settings
|
||||
|
||||
### ESM (Evolutionary Scale Modeling) - Sequence Likelihood
|
||||
|
||||
**Purpose:** Assess how "natural" a protein sequence appears based on evolutionary patterns.
|
||||
|
||||
**Method:** Protein language model trained on millions of natural sequences.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Install: uv pip install fair-esm
|
||||
import torch
|
||||
from esm import pretrained
|
||||
|
||||
def score_sequence_esm(sequence):
|
||||
"""
|
||||
Calculate ESM likelihood score for sequence
|
||||
Higher scores indicate more natural/stable sequences
|
||||
"""
|
||||
|
||||
model, alphabet = pretrained.esm2_t33_650M_UR50D()
|
||||
batch_converter = alphabet.get_batch_converter()
|
||||
|
||||
data = [("protein", sequence)]
|
||||
_, _, batch_tokens = batch_converter(data)
|
||||
|
||||
with torch.no_grad():
|
||||
results = model(batch_tokens, repr_layers=[33])
|
||||
token_logprobs = results["logits"].log_softmax(dim=-1)
|
||||
|
||||
# Calculate perplexity as sequence quality metric
|
||||
sequence_score = token_logprobs.mean().item()
|
||||
|
||||
return sequence_score
|
||||
|
||||
# Example - Compare variants
|
||||
sequences = {
|
||||
'original': 'MKVLW...',
|
||||
'optimized_1': 'MKVLS...',
|
||||
'optimized_2': 'MKVLA...'
|
||||
}
|
||||
|
||||
for name, seq in sequences.items():
|
||||
score = score_sequence_esm(seq)
|
||||
print(f"{name}: ESM score = {score:.3f}")
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- Higher scores → More "natural" sequence
|
||||
- Use to avoid unlikely mutations
|
||||
- Balance with functional requirements
|
||||
|
||||
**When to use:**
|
||||
- Filtering synthetic designs
|
||||
- Comparing SolubleMPNN variants
|
||||
- Ensuring sequences aren't too artificial
|
||||
- Avoiding expression bottlenecks
|
||||
|
||||
**Integration with design:**
|
||||
```python
|
||||
def rank_variants_by_esm(variants):
|
||||
"""Rank protein variants by ESM likelihood"""
|
||||
scored = []
|
||||
for v in variants:
|
||||
esm_score = score_sequence_esm(v['sequence'])
|
||||
v['esm_score'] = esm_score
|
||||
scored.append(v)
|
||||
|
||||
# Sort by combined solubility and ESM score
|
||||
scored.sort(
|
||||
key=lambda x: x['solubility_score'] * x['esm_score'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return scored
|
||||
```
|
||||
|
||||
### ipTM - Interface Stability (AlphaFold-Multimer)
|
||||
|
||||
**Purpose:** Assess protein-protein interface stability and binding confidence.
|
||||
|
||||
**Method:** Interface predicted TM-score from AlphaFold-Multimer predictions.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Requires AlphaFold-Multimer installation
|
||||
# Or use ColabFold for easier access
|
||||
|
||||
def predict_interface_stability(protein_a_seq, protein_b_seq):
|
||||
"""
|
||||
Predict interface stability using AlphaFold-Multimer
|
||||
|
||||
Returns ipTM score: higher = more stable interface
|
||||
"""
|
||||
from colabfold import run_alphafold_multimer
|
||||
|
||||
sequences = {
|
||||
'chainA': protein_a_seq,
|
||||
'chainB': protein_b_seq
|
||||
}
|
||||
|
||||
result = run_alphafold_multimer(sequences)
|
||||
|
||||
return {
|
||||
'ipTM': result['iptm'],
|
||||
'pTM': result['ptm'],
|
||||
'pLDDT': result['plddt']
|
||||
}
|
||||
|
||||
# Example for antibody-antigen binding
|
||||
antibody_seq = "EVQLVESGGGLVQPGG..."
|
||||
antigen_seq = "MKVLWAALLGLLGAAA..."
|
||||
|
||||
stability = predict_interface_stability(antibody_seq, antigen_seq)
|
||||
print(f"Interface pTM: {stability['ipTM']:.3f}")
|
||||
|
||||
# Interpretation
|
||||
if stability['ipTM'] > 0.7:
|
||||
print("High confidence interface")
|
||||
elif stability['ipTM'] > 0.5:
|
||||
print("Moderate confidence interface")
|
||||
else:
|
||||
print("Low confidence interface - may need redesign")
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- ipTM > 0.7: Strong predicted interface
|
||||
- ipTM 0.5-0.7: Moderate interface confidence
|
||||
- ipTM < 0.5: Weak interface, consider redesign
|
||||
|
||||
**When to use:**
|
||||
- Antibody-antigen design
|
||||
- Protein-protein interaction engineering
|
||||
- Validating binding interfaces
|
||||
- Comparing interface variants
|
||||
|
||||
### pSAE - Solvent-Accessible Hydrophobic Residues
|
||||
|
||||
**Purpose:** Quantify exposed hydrophobic residues that promote aggregation.
|
||||
|
||||
**Method:** Calculates percentage of solvent-accessible surface area (SASA) occupied by hydrophobic residues.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
# Requires structure (PDB file or AlphaFold prediction)
|
||||
# Install: uv pip install biopython
|
||||
|
||||
from Bio.PDB import PDBParser, DSSP
|
||||
import numpy as np
|
||||
|
||||
def calculate_psae(pdb_file):
|
||||
"""
|
||||
Calculate percent Solvent-Accessible hydrophobic residues (pSAE)
|
||||
|
||||
Lower pSAE = better solubility
|
||||
"""
|
||||
|
||||
parser = PDBParser(QUIET=True)
|
||||
structure = parser.get_structure('protein', pdb_file)
|
||||
|
||||
# Run DSSP to get solvent accessibility
|
||||
model = structure[0]
|
||||
dssp = DSSP(model, pdb_file, acc_array='Wilke')
|
||||
|
||||
hydrophobic = ['ALA', 'VAL', 'ILE', 'LEU', 'MET', 'PHE', 'TRP', 'PRO']
|
||||
|
||||
total_sasa = 0
|
||||
hydrophobic_sasa = 0
|
||||
|
||||
for residue in dssp:
|
||||
res_name = residue[1]
|
||||
rel_accessibility = residue[3]
|
||||
|
||||
total_sasa += rel_accessibility
|
||||
if res_name in hydrophobic:
|
||||
hydrophobic_sasa += rel_accessibility
|
||||
|
||||
psae = (hydrophobic_sasa / total_sasa) * 100
|
||||
|
||||
return psae
|
||||
|
||||
# Example
|
||||
pdb_file = "protein_structure.pdb"
|
||||
psae_score = calculate_psae(pdb_file)
|
||||
print(f"pSAE: {psae_score:.2f}%")
|
||||
|
||||
# Interpretation
|
||||
if psae_score < 25:
|
||||
print("Good solubility expected")
|
||||
elif psae_score < 35:
|
||||
print("Moderate solubility")
|
||||
else:
|
||||
print("High aggregation risk")
|
||||
```
|
||||
|
||||
**Interpretation:**
|
||||
- pSAE < 25%: Low aggregation risk
|
||||
- pSAE 25-35%: Moderate risk
|
||||
- pSAE > 35%: High aggregation risk
|
||||
|
||||
**When to use:**
|
||||
- Analyzing designed structures
|
||||
- Post-AlphaFold validation
|
||||
- Identifying aggregation hotspots
|
||||
- Guiding surface mutations
|
||||
|
||||
## Recommended Optimization Workflow
|
||||
|
||||
### Step 1: Initial Screening (Fast)
|
||||
|
||||
```python
|
||||
def initial_screening(sequences):
|
||||
"""
|
||||
Quick first-pass filtering using NetSolP
|
||||
Filters out obviously problematic sequences
|
||||
"""
|
||||
passed = []
|
||||
for name, seq in sequences.items():
|
||||
netsolp_score = predict_solubility_netsolp(seq)
|
||||
if netsolp_score > 0.5:
|
||||
passed.append((name, seq))
|
||||
|
||||
return passed
|
||||
```
|
||||
|
||||
### Step 2: Detailed Assessment (Moderate)
|
||||
|
||||
```python
|
||||
def detailed_assessment(filtered_sequences):
|
||||
"""
|
||||
More thorough analysis with SoluProt and ESM
|
||||
Ranks sequences by multiple criteria
|
||||
"""
|
||||
results = []
|
||||
for name, seq in filtered_sequences:
|
||||
soluprot_score = predict_solubility(seq)
|
||||
esm_score = score_sequence_esm(seq)
|
||||
|
||||
combined_score = soluprot_score * 0.7 + esm_score * 0.3
|
||||
|
||||
results.append({
|
||||
'name': name,
|
||||
'sequence': seq,
|
||||
'soluprot': soluprot_score,
|
||||
'esm': esm_score,
|
||||
'combined': combined_score
|
||||
})
|
||||
|
||||
results.sort(key=lambda x: x['combined'], reverse=True)
|
||||
return results
|
||||
```
|
||||
|
||||
### Step 3: Sequence Optimization (If needed)
|
||||
|
||||
```python
|
||||
def optimize_problematic_sequences(sequences_needing_optimization):
|
||||
"""
|
||||
Use SolubleMPNN to redesign problematic sequences
|
||||
Returns improved variants
|
||||
"""
|
||||
optimized = []
|
||||
for name, seq in sequences_needing_optimization:
|
||||
# Generate multiple variants
|
||||
variants = optimize_sequence(
|
||||
sequence=seq,
|
||||
num_variants=10,
|
||||
temperature=0.2
|
||||
)
|
||||
|
||||
# Score variants with ESM
|
||||
for variant in variants:
|
||||
variant['esm_score'] = score_sequence_esm(variant['sequence'])
|
||||
|
||||
# Keep best variants
|
||||
variants.sort(
|
||||
key=lambda x: x['solubility_score'] * x['esm_score'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
optimized.extend(variants[:3]) # Top 3 variants per sequence
|
||||
|
||||
return optimized
|
||||
```
|
||||
|
||||
### Step 4: Structure-Based Validation (For critical sequences)
|
||||
|
||||
```python
|
||||
def structure_validation(top_candidates):
|
||||
"""
|
||||
Predict structures and calculate pSAE for top candidates
|
||||
Final validation before experimental testing
|
||||
"""
|
||||
validated = []
|
||||
for candidate in top_candidates:
|
||||
# Predict structure with AlphaFold
|
||||
structure_pdb = predict_structure_alphafold(candidate['sequence'])
|
||||
|
||||
# Calculate pSAE
|
||||
psae = calculate_psae(structure_pdb)
|
||||
|
||||
candidate['psae'] = psae
|
||||
candidate['pass_structure_check'] = psae < 30
|
||||
|
||||
validated.append(candidate)
|
||||
|
||||
return validated
|
||||
```
|
||||
|
||||
### Complete Workflow Example
|
||||
|
||||
```python
|
||||
def complete_optimization_pipeline(initial_sequences):
|
||||
"""
|
||||
End-to-end optimization pipeline
|
||||
|
||||
Input: Dictionary of {name: sequence}
|
||||
Output: Ranked list of optimized, validated sequences
|
||||
"""
|
||||
|
||||
print("Step 1: Initial screening with NetSolP...")
|
||||
filtered = initial_screening(initial_sequences)
|
||||
print(f" Passed: {len(filtered)}/{len(initial_sequences)}")
|
||||
|
||||
print("Step 2: Detailed assessment with SoluProt and ESM...")
|
||||
assessed = detailed_assessment(filtered)
|
||||
|
||||
# Split into good and needs-optimization
|
||||
good_sequences = [s for s in assessed if s['soluprot'] > 0.6]
|
||||
needs_optimization = [s for s in assessed if s['soluprot'] <= 0.6]
|
||||
|
||||
print(f" Good sequences: {len(good_sequences)}")
|
||||
print(f" Need optimization: {len(needs_optimization)}")
|
||||
|
||||
if needs_optimization:
|
||||
print("Step 3: Optimizing problematic sequences with SolubleMPNN...")
|
||||
optimized = optimize_problematic_sequences(needs_optimization)
|
||||
all_sequences = good_sequences + optimized
|
||||
else:
|
||||
all_sequences = good_sequences
|
||||
|
||||
print("Step 4: Structure-based validation for top candidates...")
|
||||
top_20 = all_sequences[:20]
|
||||
final_validated = structure_validation(top_20)
|
||||
|
||||
# Final ranking
|
||||
final_validated.sort(
|
||||
key=lambda x: (
|
||||
x['pass_structure_check'],
|
||||
x['combined'],
|
||||
-x['psae']
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return final_validated
|
||||
|
||||
# Usage
|
||||
initial_library = {
|
||||
'variant_1': 'MKVLWAALLGLLGAAA...',
|
||||
'variant_2': 'MATGVLWAALLGLLGA...',
|
||||
# ... more sequences
|
||||
}
|
||||
|
||||
optimized_library = complete_optimization_pipeline(initial_library)
|
||||
|
||||
# Submit top sequences to Adaptyv
|
||||
top_sequences_for_testing = optimized_library[:50]
|
||||
```
|
||||
|
||||
## Best Practices Summary
|
||||
|
||||
1. **Always pre-screen** before experimental testing
|
||||
2. **Use NetSolP first** for fast filtering of large libraries
|
||||
3. **Apply SolubleMPNN** as default optimization tool
|
||||
4. **Validate with ESM** to avoid unnatural sequences
|
||||
5. **Calculate pSAE** for structure-based validation
|
||||
6. **Test multiple variants** per design to account for prediction uncertainty
|
||||
7. **Keep controls** - include wild-type or known-good sequences
|
||||
8. **Iterate** - use experimental results to refine predictions
|
||||
|
||||
## Integration with Adaptyv
|
||||
|
||||
After computational optimization, submit sequences to Adaptyv:
|
||||
|
||||
```python
|
||||
# After optimization pipeline
|
||||
optimized_sequences = complete_optimization_pipeline(initial_library)
|
||||
|
||||
# Prepare FASTA format
|
||||
fasta_content = ""
|
||||
for seq_data in optimized_sequences[:50]: # Top 50
|
||||
fasta_content += f">{seq_data['name']}\n{seq_data['sequence']}\n"
|
||||
|
||||
# Submit to Adaptyv
|
||||
import requests
|
||||
response = requests.post(
|
||||
"https://kq5jp7qj7wdqklhsxmovkzn4l40obksv.lambda-url.eu-central-1.on.aws/experiments",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={
|
||||
"sequences": fasta_content,
|
||||
"experiment_type": "expression",
|
||||
"metadata": {
|
||||
"optimization_method": "SolubleMPNN_ESM_pipeline",
|
||||
"computational_scores": [s['combined'] for s in optimized_sequences[:50]]
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Issue: All sequences score poorly on solubility predictions**
|
||||
- Check if sequences contain unusual amino acids
|
||||
- Verify FASTA format is correct
|
||||
- Consider if protein family is naturally low-solubility
|
||||
- May need experimental validation despite predictions
|
||||
|
||||
**Issue: SolubleMPNN changes functionally important residues**
|
||||
- Provide structure file to preserve spatial constraints
|
||||
- Mask critical residues from mutation
|
||||
- Lower temperature parameter for conservative changes
|
||||
- Manually revert problematic mutations
|
||||
|
||||
**Issue: ESM scores are low after optimization**
|
||||
- Optimization may be too aggressive
|
||||
- Try lower temperature in SolubleMPNN
|
||||
- Balance between solubility and naturalness
|
||||
- Consider that some optimization may require non-natural mutations
|
||||
|
||||
**Issue: Predictions don't match experimental results**
|
||||
- Predictions are probabilistic, not deterministic
|
||||
- Host system and conditions affect expression
|
||||
- Some proteins may need experimental validation
|
||||
- Use predictions as enrichment, not absolute filters
|
||||
371
scientific-skills/aeon/SKILL.md
Normal file
371
scientific-skills/aeon/SKILL.md
Normal file
@@ -0,0 +1,371 @@
|
||||
---
|
||||
name: aeon
|
||||
description: This skill should be used for time series machine learning tasks including classification, regression, clustering, forecasting, anomaly detection, segmentation, and similarity search. Use when working with temporal data, sequential patterns, or time-indexed observations requiring specialized algorithms beyond standard ML approaches. Particularly suited for univariate and multivariate time series analysis with scikit-learn compatible APIs.
|
||||
license: BSD-3-Clause license
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# Aeon Time Series Machine Learning
|
||||
|
||||
## Overview
|
||||
|
||||
Aeon is a scikit-learn compatible Python toolkit for time series machine learning. It provides state-of-the-art algorithms for classification, regression, clustering, forecasting, anomaly detection, segmentation, and similarity search.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Apply this skill when:
|
||||
- Classifying or predicting from time series data
|
||||
- Detecting anomalies or change points in temporal sequences
|
||||
- Clustering similar time series patterns
|
||||
- Forecasting future values
|
||||
- Finding repeated patterns (motifs) or unusual subsequences (discords)
|
||||
- Comparing time series with specialized distance metrics
|
||||
- Extracting features from temporal data
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
uv pip install aeon
|
||||
```
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Time Series Classification
|
||||
|
||||
Categorize time series into predefined classes. See `references/classification.md` for complete algorithm catalog.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.classification.convolution_based import RocketClassifier
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load data
|
||||
X_train, y_train = load_classification("GunPoint", split="train")
|
||||
X_test, y_test = load_classification("GunPoint", split="test")
|
||||
|
||||
# Train classifier
|
||||
clf = RocketClassifier(n_kernels=10000)
|
||||
clf.fit(X_train, y_train)
|
||||
accuracy = clf.score(X_test, y_test)
|
||||
```
|
||||
|
||||
**Algorithm Selection:**
|
||||
- **Speed + Performance**: `MiniRocketClassifier`, `Arsenal`
|
||||
- **Maximum Accuracy**: `HIVECOTEV2`, `InceptionTimeClassifier`
|
||||
- **Interpretability**: `ShapeletTransformClassifier`, `Catch22Classifier`
|
||||
- **Small Datasets**: `KNeighborsTimeSeriesClassifier` with DTW distance
|
||||
|
||||
### 2. Time Series Regression
|
||||
|
||||
Predict continuous values from time series. See `references/regression.md` for algorithms.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.regression.convolution_based import RocketRegressor
|
||||
from aeon.datasets import load_regression
|
||||
|
||||
X_train, y_train = load_regression("Covid3Month", split="train")
|
||||
X_test, y_test = load_regression("Covid3Month", split="test")
|
||||
|
||||
reg = RocketRegressor()
|
||||
reg.fit(X_train, y_train)
|
||||
predictions = reg.predict(X_test)
|
||||
```
|
||||
|
||||
### 3. Time Series Clustering
|
||||
|
||||
Group similar time series without labels. See `references/clustering.md` for methods.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.clustering import TimeSeriesKMeans
|
||||
|
||||
clusterer = TimeSeriesKMeans(
|
||||
n_clusters=3,
|
||||
distance="dtw",
|
||||
averaging_method="ba"
|
||||
)
|
||||
labels = clusterer.fit_predict(X_train)
|
||||
centers = clusterer.cluster_centers_
|
||||
```
|
||||
|
||||
### 4. Forecasting
|
||||
|
||||
Predict future time series values. See `references/forecasting.md` for forecasters.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.forecasting.arima import ARIMA
|
||||
|
||||
forecaster = ARIMA(order=(1, 1, 1))
|
||||
forecaster.fit(y_train)
|
||||
y_pred = forecaster.predict(fh=[1, 2, 3, 4, 5])
|
||||
```
|
||||
|
||||
### 5. Anomaly Detection
|
||||
|
||||
Identify unusual patterns or outliers. See `references/anomaly_detection.md` for detectors.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.anomaly_detection import STOMP
|
||||
|
||||
detector = STOMP(window_size=50)
|
||||
anomaly_scores = detector.fit_predict(y)
|
||||
|
||||
# Higher scores indicate anomalies
|
||||
threshold = np.percentile(anomaly_scores, 95)
|
||||
anomalies = anomaly_scores > threshold
|
||||
```
|
||||
|
||||
### 6. Segmentation
|
||||
|
||||
Partition time series into regions with change points. See `references/segmentation.md`.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.segmentation import ClaSPSegmenter
|
||||
|
||||
segmenter = ClaSPSegmenter()
|
||||
change_points = segmenter.fit_predict(y)
|
||||
```
|
||||
|
||||
### 7. Similarity Search
|
||||
|
||||
Find similar patterns within or across time series. See `references/similarity_search.md`.
|
||||
|
||||
**Quick Start:**
|
||||
```python
|
||||
from aeon.similarity_search import StompMotif
|
||||
|
||||
# Find recurring patterns
|
||||
motif_finder = StompMotif(window_size=50, k=3)
|
||||
motifs = motif_finder.fit_predict(y)
|
||||
```
|
||||
|
||||
## Feature Extraction and Transformations
|
||||
|
||||
Transform time series for feature engineering. See `references/transformations.md`.
|
||||
|
||||
**ROCKET Features:**
|
||||
```python
|
||||
from aeon.transformations.collection.convolution_based import RocketTransformer
|
||||
|
||||
rocket = RocketTransformer()
|
||||
X_features = rocket.fit_transform(X_train)
|
||||
|
||||
# Use features with any sklearn classifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
clf = RandomForestClassifier()
|
||||
clf.fit(X_features, y_train)
|
||||
```
|
||||
|
||||
**Statistical Features:**
|
||||
```python
|
||||
from aeon.transformations.collection.feature_based import Catch22
|
||||
|
||||
catch22 = Catch22()
|
||||
X_features = catch22.fit_transform(X_train)
|
||||
```
|
||||
|
||||
**Preprocessing:**
|
||||
```python
|
||||
from aeon.transformations.collection import MinMaxScaler, Normalizer
|
||||
|
||||
scaler = Normalizer() # Z-normalization
|
||||
X_normalized = scaler.fit_transform(X_train)
|
||||
```
|
||||
|
||||
## Distance Metrics
|
||||
|
||||
Specialized temporal distance measures. See `references/distances.md` for complete catalog.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from aeon.distances import dtw_distance, dtw_pairwise_distance
|
||||
|
||||
# Single distance
|
||||
distance = dtw_distance(x, y, window=0.1)
|
||||
|
||||
# Pairwise distances
|
||||
distance_matrix = dtw_pairwise_distance(X_train)
|
||||
|
||||
# Use with classifiers
|
||||
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
|
||||
|
||||
clf = KNeighborsTimeSeriesClassifier(
|
||||
n_neighbors=5,
|
||||
distance="dtw",
|
||||
distance_params={"window": 0.2}
|
||||
)
|
||||
```
|
||||
|
||||
**Available Distances:**
|
||||
- **Elastic**: DTW, DDTW, WDTW, ERP, EDR, LCSS, TWE, MSM
|
||||
- **Lock-step**: Euclidean, Manhattan, Minkowski
|
||||
- **Shape-based**: Shape DTW, SBD
|
||||
|
||||
## Deep Learning Networks
|
||||
|
||||
Neural architectures for time series. See `references/networks.md`.
|
||||
|
||||
**Architectures:**
|
||||
- Convolutional: `FCNClassifier`, `ResNetClassifier`, `InceptionTimeClassifier`
|
||||
- Recurrent: `RecurrentNetwork`, `TCNNetwork`
|
||||
- Autoencoders: `AEFCNClusterer`, `AEResNetClusterer`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from aeon.classification.deep_learning import InceptionTimeClassifier
|
||||
|
||||
clf = InceptionTimeClassifier(n_epochs=100, batch_size=32)
|
||||
clf.fit(X_train, y_train)
|
||||
predictions = clf.predict(X_test)
|
||||
```
|
||||
|
||||
## Datasets and Benchmarking
|
||||
|
||||
Load standard benchmarks and evaluate performance. See `references/datasets_benchmarking.md`.
|
||||
|
||||
**Load Datasets:**
|
||||
```python
|
||||
from aeon.datasets import load_classification, load_regression
|
||||
|
||||
# Classification
|
||||
X_train, y_train = load_classification("ArrowHead", split="train")
|
||||
|
||||
# Regression
|
||||
X_train, y_train = load_regression("Covid3Month", split="train")
|
||||
```
|
||||
|
||||
**Benchmarking:**
|
||||
```python
|
||||
from aeon.benchmarking import get_estimator_results
|
||||
|
||||
# Compare with published results
|
||||
published = get_estimator_results("ROCKET", "GunPoint")
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Classification Pipeline
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection import Normalizer
|
||||
from aeon.classification.convolution_based import RocketClassifier
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
pipeline = Pipeline([
|
||||
('normalize', Normalizer()),
|
||||
('classify', RocketClassifier())
|
||||
])
|
||||
|
||||
pipeline.fit(X_train, y_train)
|
||||
accuracy = pipeline.score(X_test, y_test)
|
||||
```
|
||||
|
||||
### Feature Extraction + Traditional ML
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection import RocketTransformer
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
|
||||
# Extract features
|
||||
rocket = RocketTransformer()
|
||||
X_train_features = rocket.fit_transform(X_train)
|
||||
X_test_features = rocket.transform(X_test)
|
||||
|
||||
# Train traditional ML
|
||||
clf = GradientBoostingClassifier()
|
||||
clf.fit(X_train_features, y_train)
|
||||
predictions = clf.predict(X_test_features)
|
||||
```
|
||||
|
||||
### Anomaly Detection with Visualization
|
||||
|
||||
```python
|
||||
from aeon.anomaly_detection import STOMP
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
detector = STOMP(window_size=50)
|
||||
scores = detector.fit_predict(y)
|
||||
|
||||
plt.figure(figsize=(15, 5))
|
||||
plt.subplot(2, 1, 1)
|
||||
plt.plot(y, label='Time Series')
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(scores, label='Anomaly Scores', color='red')
|
||||
plt.axhline(np.percentile(scores, 95), color='k', linestyle='--')
|
||||
plt.show()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Data Preparation
|
||||
|
||||
1. **Normalize**: Most algorithms benefit from z-normalization
|
||||
```python
|
||||
from aeon.transformations.collection import Normalizer
|
||||
normalizer = Normalizer()
|
||||
X_train = normalizer.fit_transform(X_train)
|
||||
X_test = normalizer.transform(X_test)
|
||||
```
|
||||
|
||||
2. **Handle Missing Values**: Impute before analysis
|
||||
```python
|
||||
from aeon.transformations.collection import SimpleImputer
|
||||
imputer = SimpleImputer(strategy='mean')
|
||||
X_train = imputer.fit_transform(X_train)
|
||||
```
|
||||
|
||||
3. **Check Data Format**: Aeon expects shape `(n_samples, n_channels, n_timepoints)`
|
||||
|
||||
### Model Selection
|
||||
|
||||
1. **Start Simple**: Begin with ROCKET variants before deep learning
|
||||
2. **Use Validation**: Split training data for hyperparameter tuning
|
||||
3. **Compare Baselines**: Test against simple methods (1-NN Euclidean, Naive)
|
||||
4. **Consider Resources**: ROCKET for speed, deep learning if GPU available
|
||||
|
||||
### Algorithm Selection Guide
|
||||
|
||||
**For Fast Prototyping:**
|
||||
- Classification: `MiniRocketClassifier`
|
||||
- Regression: `MiniRocketRegressor`
|
||||
- Clustering: `TimeSeriesKMeans` with Euclidean
|
||||
|
||||
**For Maximum Accuracy:**
|
||||
- Classification: `HIVECOTEV2`, `InceptionTimeClassifier`
|
||||
- Regression: `InceptionTimeRegressor`
|
||||
- Forecasting: `ARIMA`, `TCNForecaster`
|
||||
|
||||
**For Interpretability:**
|
||||
- Classification: `ShapeletTransformClassifier`, `Catch22Classifier`
|
||||
- Features: `Catch22`, `TSFresh`
|
||||
|
||||
**For Small Datasets:**
|
||||
- Distance-based: `KNeighborsTimeSeriesClassifier` with DTW
|
||||
- Avoid: Deep learning (requires large data)
|
||||
|
||||
## Reference Documentation
|
||||
|
||||
Detailed information available in `references/`:
|
||||
- `classification.md` - All classification algorithms
|
||||
- `regression.md` - Regression methods
|
||||
- `clustering.md` - Clustering algorithms
|
||||
- `forecasting.md` - Forecasting approaches
|
||||
- `anomaly_detection.md` - Anomaly detection methods
|
||||
- `segmentation.md` - Segmentation algorithms
|
||||
- `similarity_search.md` - Pattern matching and motif discovery
|
||||
- `transformations.md` - Feature extraction and preprocessing
|
||||
- `distances.md` - Time series distance metrics
|
||||
- `networks.md` - Deep learning architectures
|
||||
- `datasets_benchmarking.md` - Data loading and evaluation tools
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- Documentation: https://www.aeon-toolkit.org/
|
||||
- GitHub: https://github.com/aeon-toolkit/aeon
|
||||
- Examples: https://www.aeon-toolkit.org/en/stable/examples.html
|
||||
- API Reference: https://www.aeon-toolkit.org/en/stable/api_reference.html
|
||||
154
scientific-skills/aeon/references/anomaly_detection.md
Normal file
154
scientific-skills/aeon/references/anomaly_detection.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Anomaly Detection
|
||||
|
||||
Aeon provides anomaly detection methods for identifying unusual patterns in time series at both series and collection levels.
|
||||
|
||||
## Collection Anomaly Detectors
|
||||
|
||||
Detect anomalous time series within a collection:
|
||||
|
||||
- `ClassificationAdapter` - Adapts classifiers for anomaly detection
|
||||
- Train on normal data, flag outliers during prediction
|
||||
- **Use when**: Have labeled normal data, want classification-based approach
|
||||
|
||||
- `OutlierDetectionAdapter` - Wraps sklearn outlier detectors
|
||||
- Works with IsolationForest, LOF, OneClassSVM
|
||||
- **Use when**: Want to use sklearn anomaly detectors on collections
|
||||
|
||||
## Series Anomaly Detectors
|
||||
|
||||
Detect anomalous points or subsequences within a single time series.
|
||||
|
||||
### Distance-Based Methods
|
||||
|
||||
Use similarity metrics to identify anomalies:
|
||||
|
||||
- `CBLOF` - Cluster-Based Local Outlier Factor
|
||||
- Clusters data, identifies outliers based on cluster properties
|
||||
- **Use when**: Anomalies form sparse clusters
|
||||
|
||||
- `KMeansAD` - K-means based anomaly detection
|
||||
- Distance to nearest cluster center indicates anomaly
|
||||
- **Use when**: Normal patterns cluster well
|
||||
|
||||
- `LeftSTAMPi` - Left STAMP incremental
|
||||
- Matrix profile for online anomaly detection
|
||||
- **Use when**: Streaming data, need online detection
|
||||
|
||||
- `STOMP` - Scalable Time series Ordered-search Matrix Profile
|
||||
- Computes matrix profile for subsequence anomalies
|
||||
- **Use when**: Discord discovery, motif detection
|
||||
|
||||
- `MERLIN` - Matrix profile-based method
|
||||
- Efficient matrix profile computation
|
||||
- **Use when**: Large time series, need scalability
|
||||
|
||||
- `LOF` - Local Outlier Factor adapted for time series
|
||||
- Density-based outlier detection
|
||||
- **Use when**: Anomalies in low-density regions
|
||||
|
||||
- `ROCKAD` - ROCKET-based semi-supervised detection
|
||||
- Uses ROCKET features for anomaly identification
|
||||
- **Use when**: Have some labeled data, want feature-based approach
|
||||
|
||||
### Distribution-Based Methods
|
||||
|
||||
Analyze statistical distributions:
|
||||
|
||||
- `COPOD` - Copula-Based Outlier Detection
|
||||
- Models marginal and joint distributions
|
||||
- **Use when**: Multi-dimensional time series, complex dependencies
|
||||
|
||||
- `DWT_MLEAD` - Discrete Wavelet Transform Multi-Level Anomaly Detection
|
||||
- Decomposes series into frequency bands
|
||||
- **Use when**: Anomalies at specific frequencies
|
||||
|
||||
### Isolation-Based Methods
|
||||
|
||||
Use isolation principles:
|
||||
|
||||
- `IsolationForest` - Random forest-based isolation
|
||||
- Anomalies easier to isolate than normal points
|
||||
- **Use when**: High-dimensional data, no assumptions about distribution
|
||||
|
||||
- `OneClassSVM` - Support vector machine for novelty detection
|
||||
- Learns boundary around normal data
|
||||
- **Use when**: Well-defined normal region, need robust boundary
|
||||
|
||||
- `STRAY` - Streaming Robust Anomaly Detection
|
||||
- Robust to data distribution changes
|
||||
- **Use when**: Streaming data, distribution shifts
|
||||
|
||||
### External Library Integration
|
||||
|
||||
- `PyODAdapter` - Bridges PyOD library to aeon
|
||||
- Access 40+ PyOD anomaly detectors
|
||||
- **Use when**: Need specific PyOD algorithm
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.anomaly_detection import STOMP
|
||||
import numpy as np
|
||||
|
||||
# Create time series with anomaly
|
||||
y = np.concatenate([
|
||||
np.sin(np.linspace(0, 10, 100)),
|
||||
[5.0], # Anomaly spike
|
||||
np.sin(np.linspace(10, 20, 100))
|
||||
])
|
||||
|
||||
# Detect anomalies
|
||||
detector = STOMP(window_size=10)
|
||||
anomaly_scores = detector.fit_predict(y)
|
||||
|
||||
# Higher scores indicate more anomalous points
|
||||
threshold = np.percentile(anomaly_scores, 95)
|
||||
anomalies = anomaly_scores > threshold
|
||||
```
|
||||
|
||||
## Point vs Subsequence Anomalies
|
||||
|
||||
- **Point anomalies**: Single unusual values
|
||||
- Use: COPOD, DWT_MLEAD, IsolationForest
|
||||
|
||||
- **Subsequence anomalies** (discords): Unusual patterns
|
||||
- Use: STOMP, LeftSTAMPi, MERLIN
|
||||
|
||||
- **Collective anomalies**: Groups of points forming unusual pattern
|
||||
- Use: Matrix profile methods, clustering-based
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
Specialized metrics for anomaly detection:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking.metrics.anomaly_detection import (
|
||||
range_precision,
|
||||
range_recall,
|
||||
range_f_score,
|
||||
roc_auc_score
|
||||
)
|
||||
|
||||
# Range-based metrics account for window detection
|
||||
precision = range_precision(y_true, y_pred, alpha=0.5)
|
||||
recall = range_recall(y_true, y_pred, alpha=0.5)
|
||||
f1 = range_f_score(y_true, y_pred, alpha=0.5)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Speed priority**: KMeansAD, IsolationForest
|
||||
- **Accuracy priority**: STOMP, COPOD
|
||||
- **Streaming data**: LeftSTAMPi, STRAY
|
||||
- **Discord discovery**: STOMP, MERLIN
|
||||
- **Multi-dimensional**: COPOD, PyODAdapter
|
||||
- **Semi-supervised**: ROCKAD, OneClassSVM
|
||||
- **No training data**: IsolationForest, STOMP
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Normalize data**: Many methods sensitive to scale
|
||||
2. **Choose window size**: For matrix profile methods, window size critical
|
||||
3. **Set threshold**: Use percentile-based or domain-specific thresholds
|
||||
4. **Validate results**: Visualize detections to verify meaningfulness
|
||||
5. **Handle seasonality**: Detrend/deseasonalize before detection
|
||||
144
scientific-skills/aeon/references/classification.md
Normal file
144
scientific-skills/aeon/references/classification.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Time Series Classification
|
||||
|
||||
Aeon provides 13 categories of time series classifiers with scikit-learn compatible APIs.
|
||||
|
||||
## Convolution-Based Classifiers
|
||||
|
||||
Apply random convolutional transformations for efficient feature extraction:
|
||||
|
||||
- `Arsenal` - Ensemble of ROCKET classifiers with varied kernels
|
||||
- `HydraClassifier` - Multi-resolution convolution with dilation
|
||||
- `RocketClassifier` - Random convolution kernels with ridge regression
|
||||
- `MiniRocketClassifier` - Simplified ROCKET variant for speed
|
||||
- `MultiRocketClassifier` - Combines multiple ROCKET variants
|
||||
|
||||
**Use when**: Need fast, scalable classification with strong performance across diverse datasets.
|
||||
|
||||
## Deep Learning Classifiers
|
||||
|
||||
Neural network architectures optimized for temporal sequences:
|
||||
|
||||
- `FCNClassifier` - Fully convolutional network
|
||||
- `ResNetClassifier` - Residual networks with skip connections
|
||||
- `InceptionTimeClassifier` - Multi-scale inception modules
|
||||
- `TimeCNNClassifier` - Standard CNN for time series
|
||||
- `MLPClassifier` - Multi-layer perceptron baseline
|
||||
- `EncoderClassifier` - Generic encoder wrapper
|
||||
- `DisjointCNNClassifier` - Shapelet-focused architecture
|
||||
|
||||
**Use when**: Large datasets available, need end-to-end learning, or complex temporal patterns.
|
||||
|
||||
## Dictionary-Based Classifiers
|
||||
|
||||
Transform time series into symbolic representations:
|
||||
|
||||
- `BOSSEnsemble` - Bag-of-SFA-Symbols with ensemble voting
|
||||
- `TemporalDictionaryEnsemble` - Multiple dictionary methods combined
|
||||
- `WEASEL` - Word ExtrAction for time SEries cLassification
|
||||
- `MrSEQLClassifier` - Multiple symbolic sequence learning
|
||||
|
||||
**Use when**: Need interpretable models, sparse patterns, or symbolic reasoning.
|
||||
|
||||
## Distance-Based Classifiers
|
||||
|
||||
Leverage specialized time series distance metrics:
|
||||
|
||||
- `KNeighborsTimeSeriesClassifier` - k-NN with temporal distances (DTW, LCSS, ERP, etc.)
|
||||
- `ElasticEnsemble` - Combines multiple elastic distance measures
|
||||
- `ProximityForest` - Tree ensemble using distance-based splits
|
||||
|
||||
**Use when**: Small datasets, need similarity-based classification, or interpretable decisions.
|
||||
|
||||
## Feature-Based Classifiers
|
||||
|
||||
Extract statistical and signature features before classification:
|
||||
|
||||
- `Catch22Classifier` - 22 canonical time-series characteristics
|
||||
- `TSFreshClassifier` - Automated feature extraction via tsfresh
|
||||
- `SignatureClassifier` - Path signature transformations
|
||||
- `SummaryClassifier` - Summary statistics extraction
|
||||
- `FreshPRINCEClassifier` - Combines multiple feature extractors
|
||||
|
||||
**Use when**: Need interpretable features, domain expertise available, or feature engineering approach.
|
||||
|
||||
## Interval-Based Classifiers
|
||||
|
||||
Extract features from random or supervised intervals:
|
||||
|
||||
- `CanonicalIntervalForestClassifier` - Random interval features with decision trees
|
||||
- `DrCIFClassifier` - Diverse Representation CIF with catch22 features
|
||||
- `TimeSeriesForestClassifier` - Random intervals with summary statistics
|
||||
- `RandomIntervalClassifier` - Simple interval-based approach
|
||||
- `RandomIntervalSpectralEnsembleClassifier` - Spectral features from intervals
|
||||
- `SupervisedTimeSeriesForest` - Supervised interval selection
|
||||
|
||||
**Use when**: Discriminative patterns occur in specific time windows.
|
||||
|
||||
## Shapelet-Based Classifiers
|
||||
|
||||
Identify discriminative subsequences (shapelets):
|
||||
|
||||
- `ShapeletTransformClassifier` - Discovers and uses discriminative shapelets
|
||||
- `LearningShapeletClassifier` - Learns shapelets via gradient descent
|
||||
- `SASTClassifier` - Scalable approximate shapelet transform
|
||||
- `RDSTClassifier` - Random dilated shapelet transform
|
||||
|
||||
**Use when**: Need interpretable discriminative patterns or phase-invariant features.
|
||||
|
||||
## Hybrid Classifiers
|
||||
|
||||
Combine multiple classification paradigms:
|
||||
|
||||
- `HIVECOTEV1` - Hierarchical Vote Collective of Transformation-based Ensembles (version 1)
|
||||
- `HIVECOTEV2` - Enhanced version with updated components
|
||||
|
||||
**Use when**: Maximum accuracy required, computational resources available.
|
||||
|
||||
## Early Classification
|
||||
|
||||
Make predictions before observing entire time series:
|
||||
|
||||
- `TEASER` - Two-tier Early and Accurate Series Classifier
|
||||
- `ProbabilityThresholdEarlyClassifier` - Prediction when confidence exceeds threshold
|
||||
|
||||
**Use when**: Real-time decisions needed, or observations have cost.
|
||||
|
||||
## Ordinal Classification
|
||||
|
||||
Handle ordered class labels:
|
||||
|
||||
- `OrdinalTDE` - Temporal dictionary ensemble for ordinal outputs
|
||||
|
||||
**Use when**: Classes have natural ordering (e.g., severity levels).
|
||||
|
||||
## Composition Tools
|
||||
|
||||
Build custom pipelines and ensembles:
|
||||
|
||||
- `ClassifierPipeline` - Chain transformers with classifiers
|
||||
- `WeightedEnsembleClassifier` - Weighted combination of classifiers
|
||||
- `SklearnClassifierWrapper` - Adapt sklearn classifiers for time series
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.classification.convolution_based import RocketClassifier
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load data
|
||||
X_train, y_train = load_classification("GunPoint", split="train")
|
||||
X_test, y_test = load_classification("GunPoint", split="test")
|
||||
|
||||
# Train and predict
|
||||
clf = RocketClassifier()
|
||||
clf.fit(X_train, y_train)
|
||||
accuracy = clf.score(X_test, y_test)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Speed priority**: MiniRocketClassifier, Arsenal
|
||||
- **Accuracy priority**: HIVECOTEV2, InceptionTimeClassifier
|
||||
- **Interpretability**: ShapeletTransformClassifier, Catch22Classifier
|
||||
- **Small data**: KNeighborsTimeSeriesClassifier, Distance-based methods
|
||||
- **Large data**: Deep learning classifiers, ROCKET variants
|
||||
123
scientific-skills/aeon/references/clustering.md
Normal file
123
scientific-skills/aeon/references/clustering.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Time Series Clustering
|
||||
|
||||
Aeon provides clustering algorithms adapted for temporal data with specialized distance metrics and averaging methods.
|
||||
|
||||
## Partitioning Algorithms
|
||||
|
||||
Standard k-means/k-medoids adapted for time series:
|
||||
|
||||
- `TimeSeriesKMeans` - K-means with temporal distance metrics (DTW, Euclidean, etc.)
|
||||
- `TimeSeriesKMedoids` - Uses actual time series as cluster centers
|
||||
- `TimeSeriesKShape` - Shape-based clustering algorithm
|
||||
- `TimeSeriesKernelKMeans` - Kernel-based variant for nonlinear patterns
|
||||
|
||||
**Use when**: Known number of clusters, spherical cluster shapes expected.
|
||||
|
||||
## Large Dataset Methods
|
||||
|
||||
Efficient clustering for large collections:
|
||||
|
||||
- `TimeSeriesCLARA` - Clustering Large Applications with sampling
|
||||
- `TimeSeriesCLARANS` - Randomized search variant of CLARA
|
||||
|
||||
**Use when**: Dataset too large for standard k-medoids, need scalability.
|
||||
|
||||
## Elastic Distance Clustering
|
||||
|
||||
Specialized for alignment-based similarity:
|
||||
|
||||
- `KASBA` - K-means with shift-invariant elastic averaging
|
||||
- `ElasticSOM` - Self-organizing map using elastic distances
|
||||
|
||||
**Use when**: Time series have temporal shifts or warping.
|
||||
|
||||
## Spectral Methods
|
||||
|
||||
Graph-based clustering:
|
||||
|
||||
- `KSpectralCentroid` - Spectral clustering with centroid computation
|
||||
|
||||
**Use when**: Non-convex cluster shapes, need graph-based approach.
|
||||
|
||||
## Deep Learning Clustering
|
||||
|
||||
Neural network-based clustering with auto-encoders:
|
||||
|
||||
- `AEFCNClusterer` - Fully convolutional auto-encoder
|
||||
- `AEResNetClusterer` - Residual network auto-encoder
|
||||
- `AEDCNNClusterer` - Dilated CNN auto-encoder
|
||||
- `AEDRNNClusterer` - Dilated RNN auto-encoder
|
||||
- `AEBiGRUClusterer` - Bidirectional GRU auto-encoder
|
||||
- `AEAttentionBiGRUClusterer` - Attention-enhanced BiGRU auto-encoder
|
||||
|
||||
**Use when**: Large datasets, need learned representations, or complex patterns.
|
||||
|
||||
## Feature-Based Clustering
|
||||
|
||||
Transform to feature space before clustering:
|
||||
|
||||
- `Catch22Clusterer` - Clusters on 22 canonical features
|
||||
- `SummaryClusterer` - Uses summary statistics
|
||||
- `TSFreshClusterer` - Automated tsfresh features
|
||||
|
||||
**Use when**: Raw time series not informative, need interpretable features.
|
||||
|
||||
## Composition
|
||||
|
||||
Build custom clustering pipelines:
|
||||
|
||||
- `ClustererPipeline` - Chain transformers with clusterers
|
||||
|
||||
## Averaging Methods
|
||||
|
||||
Compute cluster centers for time series:
|
||||
|
||||
- `mean_average` - Arithmetic mean
|
||||
- `ba_average` - Barycentric averaging with DTW
|
||||
- `kasba_average` - Shift-invariant averaging
|
||||
- `shift_invariant_average` - General shift-invariant method
|
||||
|
||||
**Use when**: Need representative cluster centers for visualization or initialization.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.clustering import TimeSeriesKMeans
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load data (using classification data for clustering)
|
||||
X_train, _ = load_classification("GunPoint", split="train")
|
||||
|
||||
# Cluster time series
|
||||
clusterer = TimeSeriesKMeans(
|
||||
n_clusters=3,
|
||||
distance="dtw", # Use DTW distance
|
||||
averaging_method="ba" # Barycentric averaging
|
||||
)
|
||||
labels = clusterer.fit_predict(X_train)
|
||||
centers = clusterer.cluster_centers_
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Speed priority**: TimeSeriesKMeans with Euclidean distance
|
||||
- **Temporal alignment**: KASBA, TimeSeriesKMeans with DTW
|
||||
- **Large datasets**: TimeSeriesCLARA, TimeSeriesCLARANS
|
||||
- **Complex patterns**: Deep learning clusterers
|
||||
- **Interpretability**: Catch22Clusterer, SummaryClusterer
|
||||
- **Non-convex clusters**: KSpectralCentroid
|
||||
|
||||
## Distance Metrics
|
||||
|
||||
Compatible distance metrics include:
|
||||
- Euclidean, Manhattan, Minkowski (lock-step)
|
||||
- DTW, DDTW, WDTW (elastic with alignment)
|
||||
- ERP, EDR, LCSS (edit-based)
|
||||
- MSM, TWE (specialized elastic)
|
||||
|
||||
## Evaluation
|
||||
|
||||
Use clustering metrics from sklearn or aeon benchmarking:
|
||||
- Silhouette score
|
||||
- Davies-Bouldin index
|
||||
- Calinski-Harabasz index
|
||||
387
scientific-skills/aeon/references/datasets_benchmarking.md
Normal file
387
scientific-skills/aeon/references/datasets_benchmarking.md
Normal file
@@ -0,0 +1,387 @@
|
||||
# Datasets and Benchmarking
|
||||
|
||||
Aeon provides comprehensive tools for loading datasets and benchmarking time series algorithms.
|
||||
|
||||
## Dataset Loading
|
||||
|
||||
### Task-Specific Loaders
|
||||
|
||||
**Classification Datasets**:
|
||||
```python
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load train/test split
|
||||
X_train, y_train = load_classification("GunPoint", split="train")
|
||||
X_test, y_test = load_classification("GunPoint", split="test")
|
||||
|
||||
# Load entire dataset
|
||||
X, y = load_classification("GunPoint")
|
||||
```
|
||||
|
||||
**Regression Datasets**:
|
||||
```python
|
||||
from aeon.datasets import load_regression
|
||||
|
||||
X_train, y_train = load_regression("Covid3Month", split="train")
|
||||
X_test, y_test = load_regression("Covid3Month", split="test")
|
||||
|
||||
# Bulk download
|
||||
from aeon.datasets import download_all_regression
|
||||
download_all_regression() # Downloads Monash TSER archive
|
||||
```
|
||||
|
||||
**Forecasting Datasets**:
|
||||
```python
|
||||
from aeon.datasets import load_forecasting
|
||||
|
||||
# Load from forecastingdata.org
|
||||
y, X = load_forecasting("airline", return_X_y=True)
|
||||
```
|
||||
|
||||
**Anomaly Detection Datasets**:
|
||||
```python
|
||||
from aeon.datasets import load_anomaly_detection
|
||||
|
||||
X, y = load_anomaly_detection("NAB_realKnownCause")
|
||||
```
|
||||
|
||||
### File Format Loaders
|
||||
|
||||
**Load from .ts files**:
|
||||
```python
|
||||
from aeon.datasets import load_from_ts_file
|
||||
|
||||
X, y = load_from_ts_file("path/to/data.ts")
|
||||
```
|
||||
|
||||
**Load from .tsf files**:
|
||||
```python
|
||||
from aeon.datasets import load_from_tsf_file
|
||||
|
||||
df, metadata = load_from_tsf_file("path/to/data.tsf")
|
||||
```
|
||||
|
||||
**Load from ARFF files**:
|
||||
```python
|
||||
from aeon.datasets import load_from_arff_file
|
||||
|
||||
X, y = load_from_arff_file("path/to/data.arff")
|
||||
```
|
||||
|
||||
**Load from TSV files**:
|
||||
```python
|
||||
from aeon.datasets import load_from_tsv_file
|
||||
|
||||
data = load_from_tsv_file("path/to/data.tsv")
|
||||
```
|
||||
|
||||
**Load TimeEval CSV**:
|
||||
```python
|
||||
from aeon.datasets import load_from_timeeval_csv_file
|
||||
|
||||
X, y = load_from_timeeval_csv_file("path/to/timeeval.csv")
|
||||
```
|
||||
|
||||
### Writing Datasets
|
||||
|
||||
**Write to .ts format**:
|
||||
```python
|
||||
from aeon.datasets import write_to_ts_file
|
||||
|
||||
write_to_ts_file(X, "output.ts", y=y, problem_name="MyDataset")
|
||||
```
|
||||
|
||||
**Write to ARFF format**:
|
||||
```python
|
||||
from aeon.datasets import write_to_arff_file
|
||||
|
||||
write_to_arff_file(X, "output.arff", y=y)
|
||||
```
|
||||
|
||||
## Built-in Datasets
|
||||
|
||||
Aeon includes several benchmark datasets for quick testing:
|
||||
|
||||
### Classification
|
||||
- `ArrowHead` - Shape classification
|
||||
- `GunPoint` - Gesture recognition
|
||||
- `ItalyPowerDemand` - Energy demand
|
||||
- `BasicMotions` - Motion classification
|
||||
- And 100+ more from UCR/UEA archives
|
||||
|
||||
### Regression
|
||||
- `Covid3Month` - COVID forecasting
|
||||
- Various datasets from Monash TSER archive
|
||||
|
||||
### Segmentation
|
||||
- Time series segmentation datasets
|
||||
- Human activity data
|
||||
- Sensor data collections
|
||||
|
||||
### Special Collections
|
||||
- `RehabPile` - Rehabilitation data (classification & regression)
|
||||
|
||||
## Dataset Metadata
|
||||
|
||||
Get information about datasets:
|
||||
|
||||
```python
|
||||
from aeon.datasets import get_dataset_meta_data
|
||||
|
||||
metadata = get_dataset_meta_data("GunPoint")
|
||||
print(metadata)
|
||||
# {'n_train': 50, 'n_test': 150, 'length': 150, 'n_classes': 2, ...}
|
||||
```
|
||||
|
||||
## Benchmarking Tools
|
||||
|
||||
### Loading Published Results
|
||||
|
||||
Access pre-computed benchmark results:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking import get_estimator_results
|
||||
|
||||
# Get results for specific algorithm on dataset
|
||||
results = get_estimator_results(
|
||||
estimator_name="ROCKET",
|
||||
dataset_name="GunPoint"
|
||||
)
|
||||
|
||||
# Get all available estimators for a dataset
|
||||
estimators = get_available_estimators("GunPoint")
|
||||
```
|
||||
|
||||
### Resampling Strategies
|
||||
|
||||
Create reproducible train/test splits:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking import stratified_resample
|
||||
|
||||
# Stratified resampling maintaining class distribution
|
||||
X_train, X_test, y_train, y_test = stratified_resample(
|
||||
X, y,
|
||||
random_state=42,
|
||||
test_size=0.3
|
||||
)
|
||||
```
|
||||
|
||||
### Performance Metrics
|
||||
|
||||
Specialized metrics for time series tasks:
|
||||
|
||||
**Anomaly Detection Metrics**:
|
||||
```python
|
||||
from aeon.benchmarking.metrics.anomaly_detection import (
|
||||
range_precision,
|
||||
range_recall,
|
||||
range_f_score,
|
||||
range_roc_auc_score
|
||||
)
|
||||
|
||||
# Range-based metrics for window detection
|
||||
precision = range_precision(y_true, y_pred, alpha=0.5)
|
||||
recall = range_recall(y_true, y_pred, alpha=0.5)
|
||||
f1 = range_f_score(y_true, y_pred, alpha=0.5)
|
||||
auc = range_roc_auc_score(y_true, y_scores)
|
||||
```
|
||||
|
||||
**Clustering Metrics**:
|
||||
```python
|
||||
from aeon.benchmarking.metrics.clustering import clustering_accuracy
|
||||
|
||||
# Clustering accuracy with label matching
|
||||
accuracy = clustering_accuracy(y_true, y_pred)
|
||||
```
|
||||
|
||||
**Segmentation Metrics**:
|
||||
```python
|
||||
from aeon.benchmarking.metrics.segmentation import (
|
||||
count_error,
|
||||
hausdorff_error
|
||||
)
|
||||
|
||||
# Number of change points difference
|
||||
count_err = count_error(y_true, y_pred)
|
||||
|
||||
# Maximum distance between predicted and true change points
|
||||
hausdorff_err = hausdorff_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
### Statistical Testing
|
||||
|
||||
Post-hoc analysis for algorithm comparison:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking import (
|
||||
nemenyi_test,
|
||||
wilcoxon_test
|
||||
)
|
||||
|
||||
# Nemenyi test for multiple algorithms
|
||||
results = nemenyi_test(scores_matrix, alpha=0.05)
|
||||
|
||||
# Pairwise Wilcoxon signed-rank test
|
||||
stat, p_value = wilcoxon_test(scores_alg1, scores_alg2)
|
||||
```
|
||||
|
||||
## Benchmark Collections
|
||||
|
||||
### UCR/UEA Time Series Archives
|
||||
|
||||
Access to comprehensive benchmark repositories:
|
||||
|
||||
```python
|
||||
# Classification: 112 univariate + 30 multivariate datasets
|
||||
X_train, y_train = load_classification("Chinatown", split="train")
|
||||
|
||||
# Automatically downloads from timeseriesclassification.com
|
||||
```
|
||||
|
||||
### Monash Forecasting Archive
|
||||
|
||||
```python
|
||||
# Load forecasting datasets
|
||||
y = load_forecasting("nn5_daily", return_X_y=False)
|
||||
```
|
||||
|
||||
### Published Benchmark Results
|
||||
|
||||
Pre-computed results from major competitions:
|
||||
|
||||
- 2017 Univariate Bake-off
|
||||
- 2021 Multivariate Classification
|
||||
- 2023 Univariate Bake-off
|
||||
|
||||
## Workflow Example
|
||||
|
||||
Complete benchmarking workflow:
|
||||
|
||||
```python
|
||||
from aeon.datasets import load_classification
|
||||
from aeon.classification.convolution_based import RocketClassifier
|
||||
from aeon.benchmarking import get_estimator_results
|
||||
from sklearn.metrics import accuracy_score
|
||||
import numpy as np
|
||||
|
||||
# Load dataset
|
||||
dataset_name = "GunPoint"
|
||||
X_train, y_train = load_classification(dataset_name, split="train")
|
||||
X_test, y_test = load_classification(dataset_name, split="test")
|
||||
|
||||
# Train model
|
||||
clf = RocketClassifier(n_kernels=10000, random_state=42)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
# Evaluate
|
||||
accuracy = accuracy_score(y_test, y_pred)
|
||||
print(f"Accuracy: {accuracy:.4f}")
|
||||
|
||||
# Compare with published results
|
||||
published = get_estimator_results("ROCKET", dataset_name)
|
||||
print(f"Published ROCKET accuracy: {published['accuracy']:.4f}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Standard Splits
|
||||
|
||||
For reproducibility, use provided train/test splits:
|
||||
|
||||
```python
|
||||
# Good: Use standard splits
|
||||
X_train, y_train = load_classification("GunPoint", split="train")
|
||||
X_test, y_test = load_classification("GunPoint", split="test")
|
||||
|
||||
# Avoid: Creating custom splits
|
||||
X, y = load_classification("GunPoint")
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
```
|
||||
|
||||
### 2. Set Random Seeds
|
||||
|
||||
Ensure reproducibility:
|
||||
|
||||
```python
|
||||
clf = RocketClassifier(random_state=42)
|
||||
results = stratified_resample(X, y, random_state=42)
|
||||
```
|
||||
|
||||
### 3. Report Multiple Metrics
|
||||
|
||||
Don't rely on single metric:
|
||||
|
||||
```python
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score
|
||||
|
||||
accuracy = accuracy_score(y_test, y_pred)
|
||||
f1 = f1_score(y_test, y_pred, average='weighted')
|
||||
precision = precision_score(y_test, y_pred, average='weighted')
|
||||
```
|
||||
|
||||
### 4. Cross-Validation
|
||||
|
||||
For robust evaluation on small datasets:
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import cross_val_score
|
||||
|
||||
scores = cross_val_score(
|
||||
clf, X_train, y_train,
|
||||
cv=5,
|
||||
scoring='accuracy'
|
||||
)
|
||||
print(f"CV Accuracy: {scores.mean():.4f} (+/- {scores.std():.4f})")
|
||||
```
|
||||
|
||||
### 5. Compare Against Baselines
|
||||
|
||||
Always compare with simple baselines:
|
||||
|
||||
```python
|
||||
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
|
||||
|
||||
# Simple baseline: 1-NN with Euclidean distance
|
||||
baseline = KNeighborsTimeSeriesClassifier(n_neighbors=1, distance="euclidean")
|
||||
baseline.fit(X_train, y_train)
|
||||
baseline_acc = baseline.score(X_test, y_test)
|
||||
|
||||
print(f"Baseline: {baseline_acc:.4f}")
|
||||
print(f"Your model: {accuracy:.4f}")
|
||||
```
|
||||
|
||||
### 6. Statistical Significance
|
||||
|
||||
Test if improvements are statistically significant:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking import wilcoxon_test
|
||||
|
||||
# Run on multiple datasets
|
||||
accuracies_alg1 = [0.85, 0.92, 0.78, 0.88]
|
||||
accuracies_alg2 = [0.83, 0.90, 0.76, 0.86]
|
||||
|
||||
stat, p_value = wilcoxon_test(accuracies_alg1, accuracies_alg2)
|
||||
if p_value < 0.05:
|
||||
print("Difference is statistically significant")
|
||||
```
|
||||
|
||||
## Dataset Discovery
|
||||
|
||||
Find datasets matching criteria:
|
||||
|
||||
```python
|
||||
# List all available classification datasets
|
||||
from aeon.datasets import get_available_datasets
|
||||
|
||||
datasets = get_available_datasets("classification")
|
||||
print(f"Found {len(datasets)} classification datasets")
|
||||
|
||||
# Filter by properties
|
||||
univariate_datasets = [
|
||||
d for d in datasets
|
||||
if get_dataset_meta_data(d)['n_channels'] == 1
|
||||
]
|
||||
```
|
||||
256
scientific-skills/aeon/references/distances.md
Normal file
256
scientific-skills/aeon/references/distances.md
Normal file
@@ -0,0 +1,256 @@
|
||||
# Distance Metrics
|
||||
|
||||
Aeon provides specialized distance functions for measuring similarity between time series, compatible with both aeon and scikit-learn estimators.
|
||||
|
||||
## Distance Categories
|
||||
|
||||
### Elastic Distances
|
||||
|
||||
Allow flexible temporal alignment between series:
|
||||
|
||||
**Dynamic Time Warping Family:**
|
||||
- `dtw` - Classic Dynamic Time Warping
|
||||
- `ddtw` - Derivative DTW (compares derivatives)
|
||||
- `wdtw` - Weighted DTW (penalizes warping by location)
|
||||
- `wddtw` - Weighted Derivative DTW
|
||||
- `shape_dtw` - Shape-based DTW
|
||||
|
||||
**Edit-Based:**
|
||||
- `erp` - Edit distance with Real Penalty
|
||||
- `edr` - Edit Distance on Real sequences
|
||||
- `lcss` - Longest Common SubSequence
|
||||
- `twe` - Time Warp Edit distance
|
||||
|
||||
**Specialized:**
|
||||
- `msm` - Move-Split-Merge distance
|
||||
- `adtw` - Amerced DTW
|
||||
- `sbd` - Shape-Based Distance
|
||||
|
||||
**Use when**: Time series may have temporal shifts, speed variations, or phase differences.
|
||||
|
||||
### Lock-Step Distances
|
||||
|
||||
Compare time series point-by-point without alignment:
|
||||
|
||||
- `euclidean` - Euclidean distance (L2 norm)
|
||||
- `manhattan` - Manhattan distance (L1 norm)
|
||||
- `minkowski` - Generalized Minkowski distance (Lp norm)
|
||||
- `squared` - Squared Euclidean distance
|
||||
|
||||
**Use when**: Series already aligned, need computational speed, or no temporal warping expected.
|
||||
|
||||
## Usage Patterns
|
||||
|
||||
### Computing Single Distance
|
||||
|
||||
```python
|
||||
from aeon.distances import dtw_distance
|
||||
|
||||
# Distance between two time series
|
||||
distance = dtw_distance(x, y)
|
||||
|
||||
# With window constraint (Sakoe-Chiba band)
|
||||
distance = dtw_distance(x, y, window=0.1)
|
||||
```
|
||||
|
||||
### Pairwise Distance Matrix
|
||||
|
||||
```python
|
||||
from aeon.distances import dtw_pairwise_distance
|
||||
|
||||
# All pairwise distances in collection
|
||||
X = [series1, series2, series3, series4]
|
||||
distance_matrix = dtw_pairwise_distance(X)
|
||||
|
||||
# Cross-collection distances
|
||||
distance_matrix = dtw_pairwise_distance(X_train, X_test)
|
||||
```
|
||||
|
||||
### Cost Matrix and Alignment Path
|
||||
|
||||
```python
|
||||
from aeon.distances import dtw_cost_matrix, dtw_alignment_path
|
||||
|
||||
# Get full cost matrix
|
||||
cost_matrix = dtw_cost_matrix(x, y)
|
||||
|
||||
# Get optimal alignment path
|
||||
path = dtw_alignment_path(x, y)
|
||||
# Returns indices: [(0,0), (1,1), (2,1), (2,2), ...]
|
||||
```
|
||||
|
||||
### Using with Estimators
|
||||
|
||||
```python
|
||||
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
|
||||
|
||||
# Use DTW distance in classifier
|
||||
clf = KNeighborsTimeSeriesClassifier(
|
||||
n_neighbors=5,
|
||||
distance="dtw",
|
||||
distance_params={"window": 0.2}
|
||||
)
|
||||
clf.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Distance Parameters
|
||||
|
||||
### Window Constraints
|
||||
|
||||
Limit warping path deviation (improves speed and prevents pathological warping):
|
||||
|
||||
```python
|
||||
# Sakoe-Chiba band: window as fraction of series length
|
||||
dtw_distance(x, y, window=0.1) # Allow 10% deviation
|
||||
|
||||
# Itakura parallelogram: slopes constrain path
|
||||
dtw_distance(x, y, itakura_max_slope=2.0)
|
||||
```
|
||||
|
||||
### Normalization
|
||||
|
||||
Control whether to z-normalize series before distance computation:
|
||||
|
||||
```python
|
||||
# Most elastic distances support normalization
|
||||
distance = dtw_distance(x, y, normalize=True)
|
||||
```
|
||||
|
||||
### Distance-Specific Parameters
|
||||
|
||||
```python
|
||||
# ERP: penalty for gaps
|
||||
distance = erp_distance(x, y, g=0.5)
|
||||
|
||||
# TWE: stiffness and penalty parameters
|
||||
distance = twe_distance(x, y, nu=0.001, lmbda=1.0)
|
||||
|
||||
# LCSS: epsilon threshold for matching
|
||||
distance = lcss_distance(x, y, epsilon=0.5)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
### By Use Case:
|
||||
|
||||
**Temporal misalignment**: DTW, DDTW, WDTW
|
||||
**Speed variations**: DTW with window constraint
|
||||
**Shape similarity**: Shape DTW, SBD
|
||||
**Edit operations**: ERP, EDR, LCSS
|
||||
**Derivative matching**: DDTW
|
||||
**Computational speed**: Euclidean, Manhattan
|
||||
**Outlier robustness**: Manhattan, LCSS
|
||||
|
||||
### By Computational Cost:
|
||||
|
||||
**Fastest**: Euclidean (O(n))
|
||||
**Fast**: Constrained DTW (O(nw) where w is window)
|
||||
**Medium**: Full DTW (O(n²))
|
||||
**Slower**: Complex elastic distances (ERP, TWE, MSM)
|
||||
|
||||
## Quick Reference Table
|
||||
|
||||
| Distance | Alignment | Speed | Robustness | Interpretability |
|
||||
|----------|-----------|-------|------------|------------------|
|
||||
| Euclidean | Lock-step | Very Fast | Low | High |
|
||||
| DTW | Elastic | Medium | Medium | Medium |
|
||||
| DDTW | Elastic | Medium | High | Medium |
|
||||
| WDTW | Elastic | Medium | Medium | Medium |
|
||||
| ERP | Edit-based | Slow | High | Low |
|
||||
| LCSS | Edit-based | Slow | Very High | Low |
|
||||
| Shape DTW | Elastic | Medium | Medium | High |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Normalization
|
||||
|
||||
Most distances sensitive to scale; normalize when appropriate:
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection import Normalizer
|
||||
|
||||
normalizer = Normalizer()
|
||||
X_normalized = normalizer.fit_transform(X)
|
||||
```
|
||||
|
||||
### 2. Window Constraints
|
||||
|
||||
For DTW variants, use window constraints for speed and better generalization:
|
||||
|
||||
```python
|
||||
# Start with 10-20% window
|
||||
distance = dtw_distance(x, y, window=0.1)
|
||||
```
|
||||
|
||||
### 3. Series Length
|
||||
|
||||
- Equal-length required: Most lock-step distances
|
||||
- Unequal-length supported: Elastic distances (DTW, ERP, etc.)
|
||||
|
||||
### 4. Multivariate Series
|
||||
|
||||
Most distances support multivariate time series:
|
||||
|
||||
```python
|
||||
# x.shape = (n_channels, n_timepoints)
|
||||
distance = dtw_distance(x_multivariate, y_multivariate)
|
||||
```
|
||||
|
||||
### 5. Performance Optimization
|
||||
|
||||
- Use numba-compiled implementations (default in aeon)
|
||||
- Consider lock-step distances if alignment not needed
|
||||
- Use windowed DTW instead of full DTW
|
||||
- Precompute distance matrices for repeated use
|
||||
|
||||
### 6. Choosing the Right Distance
|
||||
|
||||
```python
|
||||
# Quick decision tree:
|
||||
if series_aligned:
|
||||
use_distance = "euclidean"
|
||||
elif need_speed:
|
||||
use_distance = "dtw" # with window constraint
|
||||
elif temporal_shifts_expected:
|
||||
use_distance = "dtw" or "shape_dtw"
|
||||
elif outliers_present:
|
||||
use_distance = "lcss" or "manhattan"
|
||||
elif derivatives_matter:
|
||||
use_distance = "ddtw" or "wddtw"
|
||||
```
|
||||
|
||||
## Integration with scikit-learn
|
||||
|
||||
Aeon distances work with sklearn estimators:
|
||||
|
||||
```python
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from aeon.distances import dtw_pairwise_distance
|
||||
|
||||
# Precompute distance matrix
|
||||
X_train_distances = dtw_pairwise_distance(X_train)
|
||||
|
||||
# Use with sklearn
|
||||
clf = KNeighborsClassifier(metric='precomputed')
|
||||
clf.fit(X_train_distances, y_train)
|
||||
```
|
||||
|
||||
## Available Distance Functions
|
||||
|
||||
Get list of all available distances:
|
||||
|
||||
```python
|
||||
from aeon.distances import get_distance_function_names
|
||||
|
||||
print(get_distance_function_names())
|
||||
# ['dtw', 'ddtw', 'wdtw', 'euclidean', 'erp', 'edr', ...]
|
||||
```
|
||||
|
||||
Retrieve specific distance function:
|
||||
|
||||
```python
|
||||
from aeon.distances import get_distance_function
|
||||
|
||||
distance_func = get_distance_function("dtw")
|
||||
result = distance_func(x, y, window=0.1)
|
||||
```
|
||||
140
scientific-skills/aeon/references/forecasting.md
Normal file
140
scientific-skills/aeon/references/forecasting.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# Time Series Forecasting
|
||||
|
||||
Aeon provides forecasting algorithms for predicting future time series values.
|
||||
|
||||
## Naive and Baseline Methods
|
||||
|
||||
Simple forecasting strategies for comparison:
|
||||
|
||||
- `NaiveForecaster` - Multiple strategies: last value, mean, seasonal naive
|
||||
- Parameters: `strategy` ("last", "mean", "seasonal"), `sp` (seasonal period)
|
||||
- **Use when**: Establishing baselines or simple patterns
|
||||
|
||||
## Statistical Models
|
||||
|
||||
Classical time series forecasting methods:
|
||||
|
||||
### ARIMA
|
||||
- `ARIMA` - AutoRegressive Integrated Moving Average
|
||||
- Parameters: `p` (AR order), `d` (differencing), `q` (MA order)
|
||||
- **Use when**: Linear patterns, stationary or difference-stationary series
|
||||
|
||||
### Exponential Smoothing
|
||||
- `ETS` - Error-Trend-Seasonal decomposition
|
||||
- Parameters: `error`, `trend`, `seasonal` types
|
||||
- **Use when**: Trend and seasonal patterns present
|
||||
|
||||
### Threshold Autoregressive
|
||||
- `TAR` - Threshold Autoregressive model for regime switching
|
||||
- `AutoTAR` - Automated threshold discovery
|
||||
- **Use when**: Series exhibits different behaviors in different regimes
|
||||
|
||||
### Theta Method
|
||||
- `Theta` - Classical Theta forecasting
|
||||
- Parameters: `theta`, `weights` for decomposition
|
||||
- **Use when**: Simple but effective baseline needed
|
||||
|
||||
### Time-Varying Parameter
|
||||
- `TVP` - Time-varying parameter model with Kalman filtering
|
||||
- **Use when**: Parameters change over time
|
||||
|
||||
## Deep Learning Forecasters
|
||||
|
||||
Neural networks for complex temporal patterns:
|
||||
|
||||
- `TCNForecaster` - Temporal Convolutional Network
|
||||
- Dilated convolutions for large receptive fields
|
||||
- **Use when**: Long sequences, need non-recurrent architecture
|
||||
|
||||
- `DeepARNetwork` - Probabilistic forecasting with RNNs
|
||||
- Provides prediction intervals
|
||||
- **Use when**: Need probabilistic forecasts, uncertainty quantification
|
||||
|
||||
## Regression-Based Forecasting
|
||||
|
||||
Apply regression to lagged features:
|
||||
|
||||
- `RegressionForecaster` - Wraps regressors for forecasting
|
||||
- Parameters: `window_length`, `horizon`
|
||||
- **Use when**: Want to use any regressor as forecaster
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.forecasting.naive import NaiveForecaster
|
||||
from aeon.forecasting.arima import ARIMA
|
||||
import numpy as np
|
||||
|
||||
# Create time series
|
||||
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
# Naive baseline
|
||||
naive = NaiveForecaster(strategy="last")
|
||||
naive.fit(y)
|
||||
forecast_naive = naive.predict(fh=[1, 2, 3])
|
||||
|
||||
# ARIMA model
|
||||
arima = ARIMA(order=(1, 1, 1))
|
||||
arima.fit(y)
|
||||
forecast_arima = arima.predict(fh=[1, 2, 3])
|
||||
```
|
||||
|
||||
## Forecasting Horizon
|
||||
|
||||
The forecasting horizon (`fh`) specifies which future time points to predict:
|
||||
|
||||
```python
|
||||
# Relative horizon (next 3 steps)
|
||||
fh = [1, 2, 3]
|
||||
|
||||
# Absolute horizon (specific time indices)
|
||||
from aeon.forecasting.base import ForecastingHorizon
|
||||
fh = ForecastingHorizon([11, 12, 13], is_relative=False)
|
||||
```
|
||||
|
||||
## Model Selection
|
||||
|
||||
- **Baseline**: NaiveForecaster with seasonal strategy
|
||||
- **Linear patterns**: ARIMA
|
||||
- **Trend + seasonality**: ETS
|
||||
- **Regime changes**: TAR, AutoTAR
|
||||
- **Complex patterns**: TCNForecaster
|
||||
- **Probabilistic**: DeepARNetwork
|
||||
- **Long sequences**: TCNForecaster
|
||||
- **Short sequences**: ARIMA, ETS
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
Use standard forecasting metrics:
|
||||
|
||||
```python
|
||||
from aeon.performance_metrics.forecasting import (
|
||||
mean_absolute_error,
|
||||
mean_squared_error,
|
||||
mean_absolute_percentage_error
|
||||
)
|
||||
|
||||
# Calculate error
|
||||
mae = mean_absolute_error(y_true, y_pred)
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
mape = mean_absolute_percentage_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
## Exogenous Variables
|
||||
|
||||
Many forecasters support exogenous features:
|
||||
|
||||
```python
|
||||
# Train with exogenous variables
|
||||
forecaster.fit(y, X=X_train)
|
||||
|
||||
# Predict requires future exogenous values
|
||||
y_pred = forecaster.predict(fh=[1, 2, 3], X=X_test)
|
||||
```
|
||||
|
||||
## Base Classes
|
||||
|
||||
- `BaseForecaster` - Abstract base for all forecasters
|
||||
- `BaseDeepForecaster` - Base for deep learning forecasters
|
||||
|
||||
Extend these to implement custom forecasting algorithms.
|
||||
289
scientific-skills/aeon/references/networks.md
Normal file
289
scientific-skills/aeon/references/networks.md
Normal file
@@ -0,0 +1,289 @@
|
||||
# Deep Learning Networks
|
||||
|
||||
Aeon provides neural network architectures specifically designed for time series tasks. These networks serve as building blocks for classification, regression, clustering, and forecasting.
|
||||
|
||||
## Core Network Architectures
|
||||
|
||||
### Convolutional Networks
|
||||
|
||||
**FCNNetwork** - Fully Convolutional Network
|
||||
- Three convolutional blocks with batch normalization
|
||||
- Global average pooling for dimensionality reduction
|
||||
- **Use when**: Need simple yet effective CNN baseline
|
||||
|
||||
**ResNetNetwork** - Residual Network
|
||||
- Residual blocks with skip connections
|
||||
- Prevents vanishing gradients in deep networks
|
||||
- **Use when**: Deep networks needed, training stability important
|
||||
|
||||
**InceptionNetwork** - Inception Modules
|
||||
- Multi-scale feature extraction with parallel convolutions
|
||||
- Different kernel sizes capture patterns at various scales
|
||||
- **Use when**: Patterns exist at multiple temporal scales
|
||||
|
||||
**TimeCNNNetwork** - Standard CNN
|
||||
- Basic convolutional architecture
|
||||
- **Use when**: Simple CNN sufficient, interpretability valued
|
||||
|
||||
**DisjointCNNNetwork** - Separate Pathways
|
||||
- Disjoint convolutional pathways
|
||||
- **Use when**: Different feature extraction strategies needed
|
||||
|
||||
**DCNNNetwork** - Dilated CNN
|
||||
- Dilated convolutions for large receptive fields
|
||||
- **Use when**: Long-range dependencies without many layers
|
||||
|
||||
### Recurrent Networks
|
||||
|
||||
**RecurrentNetwork** - RNN/LSTM/GRU
|
||||
- Configurable cell type (RNN, LSTM, GRU)
|
||||
- Sequential modeling of temporal dependencies
|
||||
- **Use when**: Sequential dependencies critical, variable-length series
|
||||
|
||||
### Temporal Convolutional Network
|
||||
|
||||
**TCNNetwork** - Temporal Convolutional Network
|
||||
- Dilated causal convolutions
|
||||
- Large receptive field without recurrence
|
||||
- **Use when**: Long sequences, need parallelizable architecture
|
||||
|
||||
### Multi-Layer Perceptron
|
||||
|
||||
**MLPNetwork** - Basic Feedforward
|
||||
- Simple fully-connected layers
|
||||
- Flattens time series before processing
|
||||
- **Use when**: Baseline needed, computational limits, or simple patterns
|
||||
|
||||
## Encoder-Based Architectures
|
||||
|
||||
Networks designed for representation learning and clustering.
|
||||
|
||||
### Autoencoder Variants
|
||||
|
||||
**EncoderNetwork** - Generic Encoder
|
||||
- Flexible encoder structure
|
||||
- **Use when**: Custom encoding needed
|
||||
|
||||
**AEFCNNetwork** - FCN-based Autoencoder
|
||||
- Fully convolutional encoder-decoder
|
||||
- **Use when**: Need convolutional representation learning
|
||||
|
||||
**AEResNetNetwork** - ResNet Autoencoder
|
||||
- Residual blocks in encoder-decoder
|
||||
- **Use when**: Deep autoencoding with skip connections
|
||||
|
||||
**AEDCNNNetwork** - Dilated CNN Autoencoder
|
||||
- Dilated convolutions for compression
|
||||
- **Use when**: Need large receptive field in autoencoder
|
||||
|
||||
**AEDRNNNetwork** - Dilated RNN Autoencoder
|
||||
- Dilated recurrent connections
|
||||
- **Use when**: Sequential patterns with long-range dependencies
|
||||
|
||||
**AEBiGRUNetwork** - Bidirectional GRU
|
||||
- Bidirectional recurrent encoding
|
||||
- **Use when**: Context from both directions helpful
|
||||
|
||||
**AEAttentionBiGRUNetwork** - Attention + BiGRU
|
||||
- Attention mechanism on BiGRU outputs
|
||||
- **Use when**: Need to focus on important time steps
|
||||
|
||||
## Specialized Architectures
|
||||
|
||||
**LITENetwork** - Lightweight Inception Time Ensemble
|
||||
- Efficient inception-based architecture
|
||||
- LITEMV variant for multivariate series
|
||||
- **Use when**: Need efficiency with strong performance
|
||||
|
||||
**DeepARNetwork** - Probabilistic Forecasting
|
||||
- Autoregressive RNN for forecasting
|
||||
- Produces probabilistic predictions
|
||||
- **Use when**: Need forecast uncertainty quantification
|
||||
|
||||
## Usage with Estimators
|
||||
|
||||
Networks are typically used within estimators, not directly:
|
||||
|
||||
```python
|
||||
from aeon.classification.deep_learning import FCNClassifier
|
||||
from aeon.regression.deep_learning import ResNetRegressor
|
||||
from aeon.clustering.deep_learning import AEFCNClusterer
|
||||
|
||||
# Classification with FCN
|
||||
clf = FCNClassifier(n_epochs=100, batch_size=16)
|
||||
clf.fit(X_train, y_train)
|
||||
|
||||
# Regression with ResNet
|
||||
reg = ResNetRegressor(n_epochs=100)
|
||||
reg.fit(X_train, y_train)
|
||||
|
||||
# Clustering with autoencoder
|
||||
clusterer = AEFCNClusterer(n_clusters=3, n_epochs=100)
|
||||
labels = clusterer.fit_predict(X_train)
|
||||
```
|
||||
|
||||
## Custom Network Configuration
|
||||
|
||||
Many networks accept configuration parameters:
|
||||
|
||||
```python
|
||||
# Configure FCN layers
|
||||
clf = FCNClassifier(
|
||||
n_epochs=200,
|
||||
batch_size=32,
|
||||
kernel_size=[7, 5, 3], # Kernel sizes for each layer
|
||||
n_filters=[128, 256, 128], # Filters per layer
|
||||
learning_rate=0.001
|
||||
)
|
||||
```
|
||||
|
||||
## Base Classes
|
||||
|
||||
- `BaseDeepLearningNetwork` - Abstract base for all networks
|
||||
- `BaseDeepRegressor` - Base for deep regression
|
||||
- `BaseDeepClassifier` - Base for deep classification
|
||||
- `BaseDeepForecaster` - Base for deep forecasting
|
||||
|
||||
Extend these to implement custom architectures.
|
||||
|
||||
## Training Considerations
|
||||
|
||||
### Hyperparameters
|
||||
|
||||
Key hyperparameters to tune:
|
||||
|
||||
- `n_epochs` - Training iterations (50-200 typical)
|
||||
- `batch_size` - Samples per batch (16-64 typical)
|
||||
- `learning_rate` - Step size (0.0001-0.01)
|
||||
- Network-specific: layers, filters, kernel sizes
|
||||
|
||||
### Callbacks
|
||||
|
||||
Many networks support callbacks for training monitoring:
|
||||
|
||||
```python
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
||||
|
||||
clf = FCNClassifier(
|
||||
n_epochs=200,
|
||||
callbacks=[
|
||||
EarlyStopping(patience=20, restore_best_weights=True),
|
||||
ReduceLROnPlateau(patience=10, factor=0.5)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### GPU Acceleration
|
||||
|
||||
Deep learning networks benefit from GPU:
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Use first GPU
|
||||
|
||||
# Networks automatically use GPU if available
|
||||
clf = InceptionTimeClassifier(n_epochs=100)
|
||||
clf.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
## Architecture Selection
|
||||
|
||||
### By Task:
|
||||
|
||||
**Classification**: InceptionNetwork, ResNetNetwork, FCNNetwork
|
||||
**Regression**: InceptionNetwork, ResNetNetwork, TCNNetwork
|
||||
**Forecasting**: TCNNetwork, DeepARNetwork, RecurrentNetwork
|
||||
**Clustering**: AEFCNNetwork, AEResNetNetwork, AEAttentionBiGRUNetwork
|
||||
|
||||
### By Data Characteristics:
|
||||
|
||||
**Long sequences**: TCNNetwork, DCNNNetwork (dilated convolutions)
|
||||
**Short sequences**: MLPNetwork, FCNNetwork
|
||||
**Multivariate**: InceptionNetwork, FCNNetwork, LITENetwork
|
||||
**Variable length**: RecurrentNetwork with masking
|
||||
**Multi-scale patterns**: InceptionNetwork
|
||||
|
||||
### By Computational Resources:
|
||||
|
||||
**Limited compute**: MLPNetwork, LITENetwork
|
||||
**Moderate compute**: FCNNetwork, TimeCNNNetwork
|
||||
**High compute available**: InceptionNetwork, ResNetNetwork
|
||||
**GPU available**: Any deep network (major speedup)
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Data Preparation
|
||||
|
||||
Normalize input data:
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection import Normalizer
|
||||
|
||||
normalizer = Normalizer()
|
||||
X_train_norm = normalizer.fit_transform(X_train)
|
||||
X_test_norm = normalizer.transform(X_test)
|
||||
```
|
||||
|
||||
### 2. Training/Validation Split
|
||||
|
||||
Use validation set for early stopping:
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
X_train_fit, X_val, y_train_fit, y_val = train_test_split(
|
||||
X_train, y_train, test_size=0.2, stratify=y_train
|
||||
)
|
||||
|
||||
clf = FCNClassifier(n_epochs=200)
|
||||
clf.fit(X_train_fit, y_train_fit, validation_data=(X_val, y_val))
|
||||
```
|
||||
|
||||
### 3. Start Simple
|
||||
|
||||
Begin with simpler architectures before complex ones:
|
||||
|
||||
1. Try MLPNetwork or FCNNetwork first
|
||||
2. If insufficient, try ResNetNetwork or InceptionNetwork
|
||||
3. Consider ensembles if single models insufficient
|
||||
|
||||
### 4. Hyperparameter Tuning
|
||||
|
||||
Use grid search or random search:
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
param_grid = {
|
||||
'n_epochs': [100, 200],
|
||||
'batch_size': [16, 32],
|
||||
'learning_rate': [0.001, 0.0001]
|
||||
}
|
||||
|
||||
clf = FCNClassifier()
|
||||
grid = GridSearchCV(clf, param_grid, cv=3)
|
||||
grid.fit(X_train, y_train)
|
||||
```
|
||||
|
||||
### 5. Regularization
|
||||
|
||||
Prevent overfitting:
|
||||
- Use dropout (if network supports)
|
||||
- Early stopping
|
||||
- Data augmentation (if available)
|
||||
- Reduce model complexity
|
||||
|
||||
### 6. Reproducibility
|
||||
|
||||
Set random seeds:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import random
|
||||
import tensorflow as tf
|
||||
|
||||
seed = 42
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
tf.random.set_seed(seed)
|
||||
```
|
||||
118
scientific-skills/aeon/references/regression.md
Normal file
118
scientific-skills/aeon/references/regression.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# Time Series Regression
|
||||
|
||||
Aeon provides time series regressors across 9 categories for predicting continuous values from temporal sequences.
|
||||
|
||||
## Convolution-Based Regressors
|
||||
|
||||
Apply convolutional kernels for feature extraction:
|
||||
|
||||
- `HydraRegressor` - Multi-resolution dilated convolutions
|
||||
- `RocketRegressor` - Random convolutional kernels
|
||||
- `MiniRocketRegressor` - Simplified ROCKET for speed
|
||||
- `MultiRocketRegressor` - Combined ROCKET variants
|
||||
- `MultiRocketHydraRegressor` - Merges ROCKET and Hydra approaches
|
||||
|
||||
**Use when**: Need fast regression with strong baseline performance.
|
||||
|
||||
## Deep Learning Regressors
|
||||
|
||||
Neural architectures for end-to-end temporal regression:
|
||||
|
||||
- `FCNRegressor` - Fully convolutional network
|
||||
- `ResNetRegressor` - Residual blocks with skip connections
|
||||
- `InceptionTimeRegressor` - Multi-scale inception modules
|
||||
- `TimeCNNRegressor` - Standard CNN architecture
|
||||
- `RecurrentRegressor` - RNN/LSTM/GRU variants
|
||||
- `MLPRegressor` - Multi-layer perceptron
|
||||
- `EncoderRegressor` - Generic encoder wrapper
|
||||
- `LITERegressor` - Lightweight inception time ensemble
|
||||
- `DisjointCNNRegressor` - Specialized CNN architecture
|
||||
|
||||
**Use when**: Large datasets, complex patterns, or need feature learning.
|
||||
|
||||
## Distance-Based Regressors
|
||||
|
||||
k-nearest neighbors with temporal distance metrics:
|
||||
|
||||
- `KNeighborsTimeSeriesRegressor` - k-NN with DTW, LCSS, ERP, or other distances
|
||||
|
||||
**Use when**: Small datasets, local similarity patterns, or interpretable predictions.
|
||||
|
||||
## Feature-Based Regressors
|
||||
|
||||
Extract statistical features before regression:
|
||||
|
||||
- `Catch22Regressor` - 22 canonical time-series characteristics
|
||||
- `FreshPRINCERegressor` - Pipeline combining multiple feature extractors
|
||||
- `SummaryRegressor` - Summary statistics features
|
||||
- `TSFreshRegressor` - Automated tsfresh feature extraction
|
||||
|
||||
**Use when**: Need interpretable features or domain-specific feature engineering.
|
||||
|
||||
## Hybrid Regressors
|
||||
|
||||
Combine multiple approaches:
|
||||
|
||||
- `RISTRegressor` - Randomized Interval-Shapelet Transformation
|
||||
|
||||
**Use when**: Benefit from combining interval and shapelet methods.
|
||||
|
||||
## Interval-Based Regressors
|
||||
|
||||
Extract features from time intervals:
|
||||
|
||||
- `CanonicalIntervalForestRegressor` - Random intervals with decision trees
|
||||
- `DrCIFRegressor` - Diverse Representation CIF
|
||||
- `TimeSeriesForestRegressor` - Random interval ensemble
|
||||
- `RandomIntervalRegressor` - Simple interval-based approach
|
||||
- `RandomIntervalSpectralEnsembleRegressor` - Spectral interval features
|
||||
- `QUANTRegressor` - Quantile-based interval features
|
||||
|
||||
**Use when**: Predictive patterns occur in specific time windows.
|
||||
|
||||
## Shapelet-Based Regressors
|
||||
|
||||
Use discriminative subsequences for prediction:
|
||||
|
||||
- `RDSTRegressor` - Random Dilated Shapelet Transform
|
||||
|
||||
**Use when**: Need phase-invariant discriminative patterns.
|
||||
|
||||
## Composition Tools
|
||||
|
||||
Build custom regression pipelines:
|
||||
|
||||
- `RegressorPipeline` - Chain transformers with regressors
|
||||
- `RegressorEnsemble` - Weighted ensemble with learnable weights
|
||||
- `SklearnRegressorWrapper` - Adapt sklearn regressors for time series
|
||||
|
||||
## Utilities
|
||||
|
||||
- `DummyRegressor` - Baseline strategies (mean, median)
|
||||
- `BaseRegressor` - Abstract base for custom regressors
|
||||
- `BaseDeepRegressor` - Base for deep learning regressors
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.regression.convolution_based import RocketRegressor
|
||||
from aeon.datasets import load_regression
|
||||
|
||||
# Load data
|
||||
X_train, y_train = load_regression("Covid3Month", split="train")
|
||||
X_test, y_test = load_regression("Covid3Month", split="test")
|
||||
|
||||
# Train and predict
|
||||
reg = RocketRegressor()
|
||||
reg.fit(X_train, y_train)
|
||||
predictions = reg.predict(X_test)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Speed priority**: MiniRocketRegressor
|
||||
- **Accuracy priority**: InceptionTimeRegressor, MultiRocketHydraRegressor
|
||||
- **Interpretability**: Catch22Regressor, SummaryRegressor
|
||||
- **Small data**: KNeighborsTimeSeriesRegressor
|
||||
- **Large data**: Deep learning regressors, ROCKET variants
|
||||
- **Interval patterns**: DrCIFRegressor, CanonicalIntervalForestRegressor
|
||||
163
scientific-skills/aeon/references/segmentation.md
Normal file
163
scientific-skills/aeon/references/segmentation.md
Normal file
@@ -0,0 +1,163 @@
|
||||
# Time Series Segmentation
|
||||
|
||||
Aeon provides algorithms to partition time series into regions with distinct characteristics, identifying change points and boundaries.
|
||||
|
||||
## Segmentation Algorithms
|
||||
|
||||
### Binary Segmentation
|
||||
- `BinSegmenter` - Recursive binary segmentation
|
||||
- Iteratively splits series at most significant change points
|
||||
- Parameters: `n_segments`, `cost_function`
|
||||
- **Use when**: Known number of segments, hierarchical structure
|
||||
|
||||
### Classification-Based
|
||||
- `ClaSPSegmenter` - Classification Score Profile
|
||||
- Uses classification performance to identify boundaries
|
||||
- Discovers segments where classification distinguishes neighbors
|
||||
- **Use when**: Segments have different temporal patterns
|
||||
|
||||
### Fast Pattern-Based
|
||||
- `FLUSSSegmenter` - Fast Low-cost Unipotent Semantic Segmentation
|
||||
- Efficient semantic segmentation using arc crossings
|
||||
- Based on matrix profile
|
||||
- **Use when**: Large time series, need speed and pattern discovery
|
||||
|
||||
### Information Theory
|
||||
- `InformationGainSegmenter` - Information gain maximization
|
||||
- Finds boundaries maximizing information gain
|
||||
- **Use when**: Statistical differences between segments
|
||||
|
||||
### Gaussian Modeling
|
||||
- `GreedyGaussianSegmenter` - Greedy Gaussian approximation
|
||||
- Models segments as Gaussian distributions
|
||||
- Incrementally adds change points
|
||||
- **Use when**: Segments follow Gaussian distributions
|
||||
|
||||
### Hierarchical Agglomerative
|
||||
- `EAggloSegmenter` - Bottom-up merging approach
|
||||
- Estimates change points via agglomeration
|
||||
- **Use when**: Want hierarchical segmentation structure
|
||||
|
||||
### Hidden Markov Models
|
||||
- `HMMSegmenter` - HMM with Viterbi decoding
|
||||
- Probabilistic state-based segmentation
|
||||
- **Use when**: Segments represent hidden states
|
||||
|
||||
### Dimensionality-Based
|
||||
- `HidalgoSegmenter` - Heterogeneous Intrinsic Dimensionality Algorithm
|
||||
- Detects changes in local dimensionality
|
||||
- **Use when**: Dimensionality shifts between segments
|
||||
|
||||
### Baseline
|
||||
- `RandomSegmenter` - Random change point generation
|
||||
- **Use when**: Need null hypothesis baseline
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from aeon.segmentation import ClaSPSegmenter
|
||||
import numpy as np
|
||||
|
||||
# Create time series with regime changes
|
||||
y = np.concatenate([
|
||||
np.sin(np.linspace(0, 10, 100)), # Segment 1
|
||||
np.cos(np.linspace(0, 10, 100)), # Segment 2
|
||||
np.sin(2 * np.linspace(0, 10, 100)) # Segment 3
|
||||
])
|
||||
|
||||
# Segment the series
|
||||
segmenter = ClaSPSegmenter()
|
||||
change_points = segmenter.fit_predict(y)
|
||||
|
||||
print(f"Detected change points: {change_points}")
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
Segmenters return change point indices:
|
||||
|
||||
```python
|
||||
# change_points = [100, 200] # Boundaries between segments
|
||||
# This divides series into: [0:100], [100:200], [200:end]
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Speed priority**: FLUSSSegmenter, BinSegmenter
|
||||
- **Accuracy priority**: ClaSPSegmenter, HMMSegmenter
|
||||
- **Known segment count**: BinSegmenter with n_segments parameter
|
||||
- **Unknown segment count**: ClaSPSegmenter, InformationGainSegmenter
|
||||
- **Pattern changes**: FLUSSSegmenter, ClaSPSegmenter
|
||||
- **Statistical changes**: InformationGainSegmenter, GreedyGaussianSegmenter
|
||||
- **State transitions**: HMMSegmenter
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Regime Change Detection
|
||||
Identify when time series behavior fundamentally changes:
|
||||
|
||||
```python
|
||||
from aeon.segmentation import InformationGainSegmenter
|
||||
|
||||
segmenter = InformationGainSegmenter(k=3) # Up to 3 change points
|
||||
change_points = segmenter.fit_predict(stock_prices)
|
||||
```
|
||||
|
||||
### Activity Segmentation
|
||||
Segment sensor data into activities:
|
||||
|
||||
```python
|
||||
from aeon.segmentation import ClaSPSegmenter
|
||||
|
||||
segmenter = ClaSPSegmenter()
|
||||
boundaries = segmenter.fit_predict(accelerometer_data)
|
||||
```
|
||||
|
||||
### Seasonal Boundary Detection
|
||||
Find season transitions in time series:
|
||||
|
||||
```python
|
||||
from aeon.segmentation import HMMSegmenter
|
||||
|
||||
segmenter = HMMSegmenter(n_states=4) # 4 seasons
|
||||
segments = segmenter.fit_predict(temperature_data)
|
||||
```
|
||||
|
||||
## Evaluation Metrics
|
||||
|
||||
Use segmentation quality metrics:
|
||||
|
||||
```python
|
||||
from aeon.benchmarking.metrics.segmentation import (
|
||||
count_error,
|
||||
hausdorff_error
|
||||
)
|
||||
|
||||
# Count error: difference in number of change points
|
||||
count_err = count_error(y_true, y_pred)
|
||||
|
||||
# Hausdorff: maximum distance between predicted and true points
|
||||
hausdorff_err = hausdorff_error(y_true, y_pred)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Normalize data**: Ensures change detection not dominated by scale
|
||||
2. **Choose appropriate metric**: Different algorithms optimize different criteria
|
||||
3. **Validate segments**: Visualize to verify meaningful boundaries
|
||||
4. **Handle noise**: Consider smoothing before segmentation
|
||||
5. **Domain knowledge**: Use expected segment count if known
|
||||
6. **Parameter tuning**: Adjust sensitivity parameters (thresholds, penalties)
|
||||
|
||||
## Visualization
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.plot(y, label='Time Series')
|
||||
for cp in change_points:
|
||||
plt.axvline(cp, color='r', linestyle='--', label='Change Point')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
```
|
||||
187
scientific-skills/aeon/references/similarity_search.md
Normal file
187
scientific-skills/aeon/references/similarity_search.md
Normal file
@@ -0,0 +1,187 @@
|
||||
# Similarity Search
|
||||
|
||||
Aeon provides tools for finding similar patterns within and across time series, including subsequence search, motif discovery, and approximate nearest neighbors.
|
||||
|
||||
## Subsequence Nearest Neighbors (SNN)
|
||||
|
||||
Find most similar subsequences within a time series.
|
||||
|
||||
### MASS Algorithm
|
||||
- `MassSNN` - Mueen's Algorithm for Similarity Search
|
||||
- Fast normalized cross-correlation for similarity
|
||||
- Computes distance profile efficiently
|
||||
- **Use when**: Need exact nearest neighbor distances, large series
|
||||
|
||||
### STOMP-Based Motif Discovery
|
||||
- `StompMotif` - Discovers recurring patterns (motifs)
|
||||
- Finds top-k most similar subsequence pairs
|
||||
- Based on matrix profile computation
|
||||
- **Use when**: Want to discover repeated patterns
|
||||
|
||||
### Brute Force Baseline
|
||||
- `DummySNN` - Exhaustive distance computation
|
||||
- Computes all pairwise distances
|
||||
- **Use when**: Small series, need exact baseline
|
||||
|
||||
## Collection-Level Search
|
||||
|
||||
Find similar time series across collections.
|
||||
|
||||
### Approximate Nearest Neighbors (ANN)
|
||||
- `RandomProjectionIndexANN` - Locality-sensitive hashing
|
||||
- Uses random projections with cosine similarity
|
||||
- Builds index for fast approximate search
|
||||
- **Use when**: Large collection, speed more important than exactness
|
||||
|
||||
## Quick Start: Motif Discovery
|
||||
|
||||
```python
|
||||
from aeon.similarity_search import StompMotif
|
||||
import numpy as np
|
||||
|
||||
# Create time series with repeated patterns
|
||||
pattern = np.sin(np.linspace(0, 2*np.pi, 50))
|
||||
y = np.concatenate([
|
||||
pattern + np.random.normal(0, 0.1, 50),
|
||||
np.random.normal(0, 1, 100),
|
||||
pattern + np.random.normal(0, 0.1, 50),
|
||||
np.random.normal(0, 1, 100)
|
||||
])
|
||||
|
||||
# Find top-3 motifs
|
||||
motif_finder = StompMotif(window_size=50, k=3)
|
||||
motifs = motif_finder.fit_predict(y)
|
||||
|
||||
# motifs contains indices of motif occurrences
|
||||
for i, (idx1, idx2) in enumerate(motifs):
|
||||
print(f"Motif {i+1} at positions {idx1} and {idx2}")
|
||||
```
|
||||
|
||||
## Quick Start: Subsequence Search
|
||||
|
||||
```python
|
||||
from aeon.similarity_search import MassSNN
|
||||
import numpy as np
|
||||
|
||||
# Time series to search within
|
||||
y = np.sin(np.linspace(0, 20, 500))
|
||||
|
||||
# Query subsequence
|
||||
query = np.sin(np.linspace(0, 2, 50))
|
||||
|
||||
# Find nearest subsequences
|
||||
searcher = MassSNN()
|
||||
distances = searcher.fit_transform(y, query)
|
||||
|
||||
# Find best match
|
||||
best_match_idx = np.argmin(distances)
|
||||
print(f"Best match at index {best_match_idx}")
|
||||
```
|
||||
|
||||
## Quick Start: Approximate NN on Collections
|
||||
|
||||
```python
|
||||
from aeon.similarity_search import RandomProjectionIndexANN
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load time series collection
|
||||
X_train, _ = load_classification("GunPoint", split="train")
|
||||
|
||||
# Build index
|
||||
ann = RandomProjectionIndexANN(n_projections=8, n_bits=4)
|
||||
ann.fit(X_train)
|
||||
|
||||
# Find approximate nearest neighbors
|
||||
query = X_train[0]
|
||||
neighbors, distances = ann.kneighbors(query, k=5)
|
||||
```
|
||||
|
||||
## Matrix Profile
|
||||
|
||||
The matrix profile is a fundamental data structure for many similarity search tasks:
|
||||
|
||||
- **Distance Profile**: Distances from a query to all subsequences
|
||||
- **Matrix Profile**: Minimum distance for each subsequence to any other
|
||||
- **Motif**: Pair of subsequences with minimum distance
|
||||
- **Discord**: Subsequence with maximum minimum distance (anomaly)
|
||||
|
||||
```python
|
||||
from aeon.similarity_search import StompMotif
|
||||
|
||||
# Compute matrix profile and find motifs/discords
|
||||
mp = StompMotif(window_size=50)
|
||||
mp.fit(y)
|
||||
|
||||
# Access matrix profile
|
||||
profile = mp.matrix_profile_
|
||||
profile_indices = mp.matrix_profile_index_
|
||||
|
||||
# Find discords (anomalies)
|
||||
discord_idx = np.argmax(profile)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
- **Exact subsequence search**: MassSNN
|
||||
- **Motif discovery**: StompMotif
|
||||
- **Anomaly detection**: Matrix profile (see anomaly_detection.md)
|
||||
- **Fast approximate search**: RandomProjectionIndexANN
|
||||
- **Small data**: DummySNN for exact results
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Pattern Matching
|
||||
Find where a pattern occurs in a long series:
|
||||
|
||||
```python
|
||||
# Find heartbeat pattern in ECG data
|
||||
searcher = MassSNN()
|
||||
distances = searcher.fit_transform(ecg_data, heartbeat_pattern)
|
||||
occurrences = np.where(distances < threshold)[0]
|
||||
```
|
||||
|
||||
### Motif Discovery
|
||||
Identify recurring patterns:
|
||||
|
||||
```python
|
||||
# Find repeated behavioral patterns
|
||||
motif_finder = StompMotif(window_size=100, k=5)
|
||||
motifs = motif_finder.fit_predict(activity_data)
|
||||
```
|
||||
|
||||
### Time Series Retrieval
|
||||
Find similar time series in database:
|
||||
|
||||
```python
|
||||
# Build searchable index
|
||||
ann = RandomProjectionIndexANN()
|
||||
ann.fit(time_series_database)
|
||||
|
||||
# Query for similar series
|
||||
neighbors = ann.kneighbors(query_series, k=10)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Window size**: Critical parameter for subsequence methods
|
||||
- Too small: Captures noise
|
||||
- Too large: Misses fine-grained patterns
|
||||
- Rule of thumb: 10-20% of series length
|
||||
|
||||
2. **Normalization**: Most methods assume z-normalized subsequences
|
||||
- Handles amplitude variations
|
||||
- Focus on shape similarity
|
||||
|
||||
3. **Distance metrics**: Different metrics for different needs
|
||||
- Euclidean: Fast, shape-based
|
||||
- DTW: Handles temporal warping
|
||||
- Cosine: Scale-invariant
|
||||
|
||||
4. **Exclusion zone**: For motif discovery, exclude trivial matches
|
||||
- Typically set to 0.5-1.0 × window_size
|
||||
- Prevents finding overlapping occurrences
|
||||
|
||||
5. **Performance**:
|
||||
- MASS is O(n log n) vs O(n²) brute force
|
||||
- ANN trades accuracy for speed
|
||||
- GPU acceleration available for some methods
|
||||
246
scientific-skills/aeon/references/transformations.md
Normal file
246
scientific-skills/aeon/references/transformations.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Transformations
|
||||
|
||||
Aeon provides extensive transformation capabilities for preprocessing, feature extraction, and representation learning from time series data.
|
||||
|
||||
## Transformation Types
|
||||
|
||||
Aeon distinguishes between:
|
||||
- **CollectionTransformers**: Transform multiple time series (collections)
|
||||
- **SeriesTransformers**: Transform individual time series
|
||||
|
||||
## Collection Transformers
|
||||
|
||||
### Convolution-Based Feature Extraction
|
||||
|
||||
Fast, scalable feature generation using random kernels:
|
||||
|
||||
- `RocketTransformer` - Random convolutional kernels
|
||||
- `MiniRocketTransformer` - Simplified ROCKET for speed
|
||||
- `MultiRocketTransformer` - Enhanced ROCKET variant
|
||||
- `HydraTransformer` - Multi-resolution dilated convolutions
|
||||
- `MultiRocketHydraTransformer` - Combines ROCKET and Hydra
|
||||
- `ROCKETGPU` - GPU-accelerated variant
|
||||
|
||||
**Use when**: Need fast, scalable features for any ML algorithm, strong baseline performance.
|
||||
|
||||
### Statistical Feature Extraction
|
||||
|
||||
Domain-agnostic features based on time series characteristics:
|
||||
|
||||
- `Catch22` - 22 canonical time-series characteristics
|
||||
- `TSFresh` - Comprehensive automated feature extraction (100+ features)
|
||||
- `TSFreshRelevant` - Feature extraction with relevance filtering
|
||||
- `SevenNumberSummary` - Descriptive statistics (mean, std, quantiles)
|
||||
|
||||
**Use when**: Need interpretable features, domain-agnostic approach, or feeding traditional ML.
|
||||
|
||||
### Dictionary-Based Representations
|
||||
|
||||
Symbolic approximations for discrete representations:
|
||||
|
||||
- `SAX` - Symbolic Aggregate approXimation
|
||||
- `PAA` - Piecewise Aggregate Approximation
|
||||
- `SFA` - Symbolic Fourier Approximation
|
||||
- `SFAFast` - Optimized SFA
|
||||
- `SFAWhole` - SFA on entire series (no windowing)
|
||||
- `BORF` - Bag-of-Receptive-Fields
|
||||
|
||||
**Use when**: Need discrete/symbolic representation, dimensionality reduction, interpretability.
|
||||
|
||||
### Shapelet-Based Features
|
||||
|
||||
Discriminative subsequence extraction:
|
||||
|
||||
- `RandomShapeletTransform` - Random discriminative shapelets
|
||||
- `RandomDilatedShapeletTransform` - Dilated shapelets for multi-scale
|
||||
- `SAST` - Scalable And Accurate Subsequence Transform
|
||||
- `RSAST` - Randomized SAST
|
||||
|
||||
**Use when**: Need interpretable discriminative patterns, phase-invariant features.
|
||||
|
||||
### Interval-Based Features
|
||||
|
||||
Statistical summaries from time intervals:
|
||||
|
||||
- `RandomIntervals` - Features from random intervals
|
||||
- `SupervisedIntervals` - Supervised interval selection
|
||||
- `QUANTTransformer` - Quantile-based interval features
|
||||
|
||||
**Use when**: Predictive patterns localized to specific windows.
|
||||
|
||||
### Preprocessing Transformations
|
||||
|
||||
Data preparation and normalization:
|
||||
|
||||
- `MinMaxScaler` - Scale to [0, 1] range
|
||||
- `Normalizer` - Z-normalization (zero mean, unit variance)
|
||||
- `Centerer` - Center to zero mean
|
||||
- `SimpleImputer` - Fill missing values
|
||||
- `DownsampleTransformer` - Reduce temporal resolution
|
||||
- `Tabularizer` - Convert time series to tabular format
|
||||
|
||||
**Use when**: Need standardization, missing value handling, format conversion.
|
||||
|
||||
### Specialized Transformations
|
||||
|
||||
Advanced analysis methods:
|
||||
|
||||
- `MatrixProfile` - Computes distance profiles for pattern discovery
|
||||
- `DWTTransformer` - Discrete Wavelet Transform
|
||||
- `AutocorrelationFunctionTransformer` - ACF computation
|
||||
- `Dobin` - Distance-based Outlier BasIs using Neighbors
|
||||
- `SignatureTransformer` - Path signature methods
|
||||
- `PLATransformer` - Piecewise Linear Approximation
|
||||
|
||||
### Class Imbalance Handling
|
||||
|
||||
- `ADASYN` - Adaptive Synthetic Sampling
|
||||
- `SMOTE` - Synthetic Minority Over-sampling
|
||||
- `OHIT` - Over-sampling with Highly Imbalanced Time series
|
||||
|
||||
**Use when**: Classification with imbalanced classes.
|
||||
|
||||
### Pipeline Composition
|
||||
|
||||
- `CollectionTransformerPipeline` - Chain multiple transformers
|
||||
|
||||
## Series Transformers
|
||||
|
||||
Transform individual time series (e.g., for preprocessing in forecasting).
|
||||
|
||||
### Statistical Analysis
|
||||
|
||||
- `AutoCorrelationSeriesTransformer` - Autocorrelation
|
||||
- `StatsModelsACF` - ACF using statsmodels
|
||||
- `StatsModelsPACF` - Partial autocorrelation
|
||||
|
||||
### Smoothing and Filtering
|
||||
|
||||
- `ExponentialSmoothing` - Exponentially weighted moving average
|
||||
- `MovingAverage` - Simple or weighted moving average
|
||||
- `SavitzkyGolayFilter` - Polynomial smoothing
|
||||
- `GaussianFilter` - Gaussian kernel smoothing
|
||||
- `BKFilter` - Baxter-King bandpass filter
|
||||
- `DiscreteFourierApproximation` - Fourier-based filtering
|
||||
|
||||
**Use when**: Need noise reduction, trend extraction, or frequency filtering.
|
||||
|
||||
### Dimensionality Reduction
|
||||
|
||||
- `PCASeriesTransformer` - Principal component analysis
|
||||
- `PlASeriesTransformer` - Piecewise Linear Approximation
|
||||
|
||||
### Transformations
|
||||
|
||||
- `BoxCoxTransformer` - Variance stabilization
|
||||
- `LogTransformer` - Logarithmic scaling
|
||||
- `ClaSPTransformer` - Classification Score Profile
|
||||
|
||||
### Pipeline Composition
|
||||
|
||||
- `SeriesTransformerPipeline` - Chain series transformers
|
||||
|
||||
## Quick Start: Feature Extraction
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection.convolution_based import RocketTransformer
|
||||
from aeon.classification.sklearn import RotationForest
|
||||
from aeon.datasets import load_classification
|
||||
|
||||
# Load data
|
||||
X_train, y_train = load_classification("GunPoint", split="train")
|
||||
X_test, y_test = load_classification("GunPoint", split="test")
|
||||
|
||||
# Extract ROCKET features
|
||||
rocket = RocketTransformer()
|
||||
X_train_features = rocket.fit_transform(X_train)
|
||||
X_test_features = rocket.transform(X_test)
|
||||
|
||||
# Use with any sklearn classifier
|
||||
clf = RotationForest()
|
||||
clf.fit(X_train_features, y_train)
|
||||
accuracy = clf.score(X_test_features, y_test)
|
||||
```
|
||||
|
||||
## Quick Start: Preprocessing Pipeline
|
||||
|
||||
```python
|
||||
from aeon.transformations.collection import (
|
||||
MinMaxScaler,
|
||||
SimpleImputer,
|
||||
CollectionTransformerPipeline
|
||||
)
|
||||
|
||||
# Build preprocessing pipeline
|
||||
pipeline = CollectionTransformerPipeline([
|
||||
('imputer', SimpleImputer(strategy='mean')),
|
||||
('scaler', MinMaxScaler())
|
||||
])
|
||||
|
||||
X_transformed = pipeline.fit_transform(X_train)
|
||||
```
|
||||
|
||||
## Quick Start: Series Smoothing
|
||||
|
||||
```python
|
||||
from aeon.transformations.series import MovingAverage
|
||||
|
||||
# Smooth individual time series
|
||||
smoother = MovingAverage(window_size=5)
|
||||
y_smoothed = smoother.fit_transform(y)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
### For Feature Extraction:
|
||||
- **Speed + Performance**: MiniRocketTransformer
|
||||
- **Interpretability**: Catch22, TSFresh
|
||||
- **Dimensionality reduction**: PAA, SAX, PCA
|
||||
- **Discriminative patterns**: Shapelet transforms
|
||||
- **Comprehensive features**: TSFresh (with longer runtime)
|
||||
|
||||
### For Preprocessing:
|
||||
- **Normalization**: Normalizer, MinMaxScaler
|
||||
- **Smoothing**: MovingAverage, SavitzkyGolayFilter
|
||||
- **Missing values**: SimpleImputer
|
||||
- **Frequency analysis**: DWTTransformer, Fourier methods
|
||||
|
||||
### For Symbolic Representation:
|
||||
- **Fast approximation**: PAA
|
||||
- **Alphabet-based**: SAX
|
||||
- **Frequency-based**: SFA, SFAFast
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Fit on training data only**: Avoid data leakage
|
||||
```python
|
||||
transformer.fit(X_train)
|
||||
X_train_tf = transformer.transform(X_train)
|
||||
X_test_tf = transformer.transform(X_test)
|
||||
```
|
||||
|
||||
2. **Pipeline composition**: Chain transformers for complex workflows
|
||||
```python
|
||||
pipeline = CollectionTransformerPipeline([
|
||||
('imputer', SimpleImputer()),
|
||||
('scaler', Normalizer()),
|
||||
('features', RocketTransformer())
|
||||
])
|
||||
```
|
||||
|
||||
3. **Feature selection**: TSFresh can generate many features; consider selection
|
||||
```python
|
||||
from sklearn.feature_selection import SelectKBest
|
||||
selector = SelectKBest(k=100)
|
||||
X_selected = selector.fit_transform(X_features, y)
|
||||
```
|
||||
|
||||
4. **Memory considerations**: Some transformers memory-intensive on large datasets
|
||||
- Use MiniRocket instead of ROCKET for speed
|
||||
- Consider downsampling for very long series
|
||||
- Use ROCKETGPU for GPU acceleration
|
||||
|
||||
5. **Domain knowledge**: Choose transformations matching domain:
|
||||
- Periodic data: Fourier-based methods
|
||||
- Noisy data: Smoothing filters
|
||||
- Spike detection: Wavelet transforms
|
||||
@@ -1,6 +1,9 @@
|
||||
---
|
||||
name: alphafold-database
|
||||
description: "Access AlphaFold's 200M+ AI-predicted protein structures. Retrieve structures by UniProt ID, download PDB/mmCIF files, analyze confidence metrics (pLDDT, PAE), for drug discovery and structural biology."
|
||||
license: Unknown
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# AlphaFold Database
|
||||
@@ -195,7 +198,7 @@ For large-scale analyses, use Google Cloud datasets:
|
||||
|
||||
```bash
|
||||
# Install gsutil
|
||||
pip install gsutil
|
||||
uv pip install gsutil
|
||||
|
||||
# List available data
|
||||
gsutil ls gs://public-datasets-deepmind-alphafold-v4/
|
||||
@@ -359,16 +362,16 @@ print(df)
|
||||
|
||||
```bash
|
||||
# Install Biopython for structure access
|
||||
pip install biopython
|
||||
uv pip install biopython
|
||||
|
||||
# Install requests for API access
|
||||
pip install requests
|
||||
uv pip install requests
|
||||
|
||||
# For visualization and analysis
|
||||
pip install numpy matplotlib pandas scipy
|
||||
uv pip install numpy matplotlib pandas scipy
|
||||
|
||||
# For Google Cloud access (optional)
|
||||
pip install google-cloud-bigquery gsutil
|
||||
uv pip install google-cloud-bigquery gsutil
|
||||
```
|
||||
|
||||
### 3D-Beacons API Alternative
|
||||
397
scientific-skills/anndata/SKILL.md
Normal file
397
scientific-skills/anndata/SKILL.md
Normal file
@@ -0,0 +1,397 @@
|
||||
---
|
||||
name: anndata
|
||||
description: This skill should be used when working with annotated data matrices in Python, particularly for single-cell genomics analysis, managing experimental measurements with metadata, or handling large-scale biological datasets. Use when tasks involve AnnData objects, h5ad files, single-cell RNA-seq data, or integration with scanpy/scverse tools.
|
||||
license: BSD-3-Clause license
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# AnnData
|
||||
|
||||
## Overview
|
||||
|
||||
AnnData is a Python package for handling annotated data matrices, storing experimental measurements (X) alongside observation metadata (obs), variable metadata (var), and multi-dimensional annotations (obsm, varm, obsp, varp, uns). Originally designed for single-cell genomics through Scanpy, it now serves as a general-purpose framework for any annotated data requiring efficient storage, manipulation, and analysis.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use this skill when:
|
||||
- Creating, reading, or writing AnnData objects
|
||||
- Working with h5ad, zarr, or other genomics data formats
|
||||
- Performing single-cell RNA-seq analysis
|
||||
- Managing large datasets with sparse matrices or backed mode
|
||||
- Concatenating multiple datasets or experimental batches
|
||||
- Subsetting, filtering, or transforming annotated data
|
||||
- Integrating with scanpy, scvi-tools, or other scverse ecosystem tools
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
uv pip install anndata
|
||||
|
||||
# With optional dependencies
|
||||
uv pip install anndata[dev,test,doc]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Creating an AnnData object
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Minimal creation
|
||||
X = np.random.rand(100, 2000) # 100 cells × 2000 genes
|
||||
adata = ad.AnnData(X)
|
||||
|
||||
# With metadata
|
||||
obs = pd.DataFrame({
|
||||
'cell_type': ['T cell', 'B cell'] * 50,
|
||||
'sample': ['A', 'B'] * 50
|
||||
}, index=[f'cell_{i}' for i in range(100)])
|
||||
|
||||
var = pd.DataFrame({
|
||||
'gene_name': [f'Gene_{i}' for i in range(2000)]
|
||||
}, index=[f'ENSG{i:05d}' for i in range(2000)])
|
||||
|
||||
adata = ad.AnnData(X=X, obs=obs, var=var)
|
||||
```
|
||||
|
||||
### Reading data
|
||||
```python
|
||||
# Read h5ad file
|
||||
adata = ad.read_h5ad('data.h5ad')
|
||||
|
||||
# Read with backed mode (for large files)
|
||||
adata = ad.read_h5ad('large_data.h5ad', backed='r')
|
||||
|
||||
# Read other formats
|
||||
adata = ad.read_csv('data.csv')
|
||||
adata = ad.read_loom('data.loom')
|
||||
adata = ad.read_10x_h5('filtered_feature_bc_matrix.h5')
|
||||
```
|
||||
|
||||
### Writing data
|
||||
```python
|
||||
# Write h5ad file
|
||||
adata.write_h5ad('output.h5ad')
|
||||
|
||||
# Write with compression
|
||||
adata.write_h5ad('output.h5ad', compression='gzip')
|
||||
|
||||
# Write other formats
|
||||
adata.write_zarr('output.zarr')
|
||||
adata.write_csvs('output_dir/')
|
||||
```
|
||||
|
||||
### Basic operations
|
||||
```python
|
||||
# Subset by conditions
|
||||
t_cells = adata[adata.obs['cell_type'] == 'T cell']
|
||||
|
||||
# Subset by indices
|
||||
subset = adata[0:50, 0:100]
|
||||
|
||||
# Add metadata
|
||||
adata.obs['quality_score'] = np.random.rand(adata.n_obs)
|
||||
adata.var['highly_variable'] = np.random.rand(adata.n_vars) > 0.8
|
||||
|
||||
# Access dimensions
|
||||
print(f"{adata.n_obs} observations × {adata.n_vars} variables")
|
||||
```
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Data Structure
|
||||
|
||||
Understand the AnnData object structure including X, obs, var, layers, obsm, varm, obsp, varp, uns, and raw components.
|
||||
|
||||
**See**: `references/data_structure.md` for comprehensive information on:
|
||||
- Core components (X, obs, var, layers, obsm, varm, obsp, varp, uns, raw)
|
||||
- Creating AnnData objects from various sources
|
||||
- Accessing and manipulating data components
|
||||
- Memory-efficient practices
|
||||
|
||||
### 2. Input/Output Operations
|
||||
|
||||
Read and write data in various formats with support for compression, backed mode, and cloud storage.
|
||||
|
||||
**See**: `references/io_operations.md` for details on:
|
||||
- Native formats (h5ad, zarr)
|
||||
- Alternative formats (CSV, MTX, Loom, 10X, Excel)
|
||||
- Backed mode for large datasets
|
||||
- Remote data access
|
||||
- Format conversion
|
||||
- Performance optimization
|
||||
|
||||
Common commands:
|
||||
```python
|
||||
# Read/write h5ad
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r')
|
||||
adata.write_h5ad('output.h5ad', compression='gzip')
|
||||
|
||||
# Read 10X data
|
||||
adata = ad.read_10x_h5('filtered_feature_bc_matrix.h5')
|
||||
|
||||
# Read MTX format
|
||||
adata = ad.read_mtx('matrix.mtx').T
|
||||
```
|
||||
|
||||
### 3. Concatenation
|
||||
|
||||
Combine multiple AnnData objects along observations or variables with flexible join strategies.
|
||||
|
||||
**See**: `references/concatenation.md` for comprehensive coverage of:
|
||||
- Basic concatenation (axis=0 for observations, axis=1 for variables)
|
||||
- Join types (inner, outer)
|
||||
- Merge strategies (same, unique, first, only)
|
||||
- Tracking data sources with labels
|
||||
- Lazy concatenation (AnnCollection)
|
||||
- On-disk concatenation for large datasets
|
||||
|
||||
Common commands:
|
||||
```python
|
||||
# Concatenate observations (combine samples)
|
||||
adata = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
axis=0,
|
||||
join='inner',
|
||||
label='batch',
|
||||
keys=['batch1', 'batch2', 'batch3']
|
||||
)
|
||||
|
||||
# Concatenate variables (combine modalities)
|
||||
adata = ad.concat([adata_rna, adata_protein], axis=1)
|
||||
|
||||
# Lazy concatenation
|
||||
from anndata.experimental import AnnCollection
|
||||
collection = AnnCollection(
|
||||
['data1.h5ad', 'data2.h5ad'],
|
||||
join_obs='outer',
|
||||
label='dataset'
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Data Manipulation
|
||||
|
||||
Transform, subset, filter, and reorganize data efficiently.
|
||||
|
||||
**See**: `references/manipulation.md` for detailed guidance on:
|
||||
- Subsetting (by indices, names, boolean masks, metadata conditions)
|
||||
- Transposition
|
||||
- Copying (full copies vs views)
|
||||
- Renaming (observations, variables, categories)
|
||||
- Type conversions (strings to categoricals, sparse/dense)
|
||||
- Adding/removing data components
|
||||
- Reordering
|
||||
- Quality control filtering
|
||||
|
||||
Common commands:
|
||||
```python
|
||||
# Subset by metadata
|
||||
filtered = adata[adata.obs['quality_score'] > 0.8]
|
||||
hv_genes = adata[:, adata.var['highly_variable']]
|
||||
|
||||
# Transpose
|
||||
adata_T = adata.T
|
||||
|
||||
# Copy vs view
|
||||
view = adata[0:100, :] # View (lightweight reference)
|
||||
copy = adata[0:100, :].copy() # Independent copy
|
||||
|
||||
# Convert strings to categoricals
|
||||
adata.strings_to_categoricals()
|
||||
```
|
||||
|
||||
### 5. Best Practices
|
||||
|
||||
Follow recommended patterns for memory efficiency, performance, and reproducibility.
|
||||
|
||||
**See**: `references/best_practices.md` for guidelines on:
|
||||
- Memory management (sparse matrices, categoricals, backed mode)
|
||||
- Views vs copies
|
||||
- Data storage optimization
|
||||
- Performance optimization
|
||||
- Working with raw data
|
||||
- Metadata management
|
||||
- Reproducibility
|
||||
- Error handling
|
||||
- Integration with other tools
|
||||
- Common pitfalls and solutions
|
||||
|
||||
Key recommendations:
|
||||
```python
|
||||
# Use sparse matrices for sparse data
|
||||
from scipy.sparse import csr_matrix
|
||||
adata.X = csr_matrix(adata.X)
|
||||
|
||||
# Convert strings to categoricals
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# Use backed mode for large files
|
||||
adata = ad.read_h5ad('large.h5ad', backed='r')
|
||||
|
||||
# Store raw before filtering
|
||||
adata.raw = adata.copy()
|
||||
adata = adata[:, adata.var['highly_variable']]
|
||||
```
|
||||
|
||||
## Integration with Scverse Ecosystem
|
||||
|
||||
AnnData serves as the foundational data structure for the scverse ecosystem:
|
||||
|
||||
### Scanpy (Single-cell analysis)
|
||||
```python
|
||||
import scanpy as sc
|
||||
|
||||
# Preprocessing
|
||||
sc.pp.filter_cells(adata, min_genes=200)
|
||||
sc.pp.normalize_total(adata, target_sum=1e4)
|
||||
sc.pp.log1p(adata)
|
||||
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
|
||||
|
||||
# Dimensionality reduction
|
||||
sc.pp.pca(adata, n_comps=50)
|
||||
sc.pp.neighbors(adata, n_neighbors=15)
|
||||
sc.tl.umap(adata)
|
||||
sc.tl.leiden(adata)
|
||||
|
||||
# Visualization
|
||||
sc.pl.umap(adata, color=['cell_type', 'leiden'])
|
||||
```
|
||||
|
||||
### Muon (Multimodal data)
|
||||
```python
|
||||
import muon as mu
|
||||
|
||||
# Combine RNA and protein data
|
||||
mdata = mu.MuData({'rna': adata_rna, 'protein': adata_protein})
|
||||
```
|
||||
|
||||
### PyTorch integration
|
||||
```python
|
||||
from anndata.experimental import AnnLoader
|
||||
|
||||
# Create DataLoader for deep learning
|
||||
dataloader = AnnLoader(adata, batch_size=128, shuffle=True)
|
||||
|
||||
for batch in dataloader:
|
||||
X = batch.X
|
||||
# Train model
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Single-cell RNA-seq analysis
|
||||
```python
|
||||
import anndata as ad
|
||||
import scanpy as sc
|
||||
|
||||
# 1. Load data
|
||||
adata = ad.read_10x_h5('filtered_feature_bc_matrix.h5')
|
||||
|
||||
# 2. Quality control
|
||||
adata.obs['n_genes'] = (adata.X > 0).sum(axis=1)
|
||||
adata.obs['n_counts'] = adata.X.sum(axis=1)
|
||||
adata = adata[adata.obs['n_genes'] > 200]
|
||||
adata = adata[adata.obs['n_counts'] < 50000]
|
||||
|
||||
# 3. Store raw
|
||||
adata.raw = adata.copy()
|
||||
|
||||
# 4. Normalize and filter
|
||||
sc.pp.normalize_total(adata, target_sum=1e4)
|
||||
sc.pp.log1p(adata)
|
||||
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
|
||||
adata = adata[:, adata.var['highly_variable']]
|
||||
|
||||
# 5. Save processed data
|
||||
adata.write_h5ad('processed.h5ad')
|
||||
```
|
||||
|
||||
### Batch integration
|
||||
```python
|
||||
# Load multiple batches
|
||||
adata1 = ad.read_h5ad('batch1.h5ad')
|
||||
adata2 = ad.read_h5ad('batch2.h5ad')
|
||||
adata3 = ad.read_h5ad('batch3.h5ad')
|
||||
|
||||
# Concatenate with batch labels
|
||||
adata = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='batch',
|
||||
keys=['batch1', 'batch2', 'batch3'],
|
||||
join='inner'
|
||||
)
|
||||
|
||||
# Apply batch correction
|
||||
import scanpy as sc
|
||||
sc.pp.combat(adata, key='batch')
|
||||
|
||||
# Continue analysis
|
||||
sc.pp.pca(adata)
|
||||
sc.pp.neighbors(adata)
|
||||
sc.tl.umap(adata)
|
||||
```
|
||||
|
||||
### Working with large datasets
|
||||
```python
|
||||
# Open in backed mode
|
||||
adata = ad.read_h5ad('100GB_dataset.h5ad', backed='r')
|
||||
|
||||
# Filter based on metadata (no data loading)
|
||||
high_quality = adata[adata.obs['quality_score'] > 0.8]
|
||||
|
||||
# Load filtered subset
|
||||
adata_subset = high_quality.to_memory()
|
||||
|
||||
# Process subset
|
||||
process(adata_subset)
|
||||
|
||||
# Or process in chunks
|
||||
chunk_size = 1000
|
||||
for i in range(0, adata.n_obs, chunk_size):
|
||||
chunk = adata[i:i+chunk_size, :].to_memory()
|
||||
process(chunk)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Out of memory errors
|
||||
Use backed mode or convert to sparse matrices:
|
||||
```python
|
||||
# Backed mode
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r')
|
||||
|
||||
# Sparse matrices
|
||||
from scipy.sparse import csr_matrix
|
||||
adata.X = csr_matrix(adata.X)
|
||||
```
|
||||
|
||||
### Slow file reading
|
||||
Use compression and appropriate formats:
|
||||
```python
|
||||
# Optimize for storage
|
||||
adata.strings_to_categoricals()
|
||||
adata.write_h5ad('file.h5ad', compression='gzip')
|
||||
|
||||
# Use Zarr for cloud storage
|
||||
adata.write_zarr('file.zarr', chunks=(1000, 1000))
|
||||
```
|
||||
|
||||
### Index alignment issues
|
||||
Always align external data on index:
|
||||
```python
|
||||
# Wrong
|
||||
adata.obs['new_col'] = external_data['values']
|
||||
|
||||
# Correct
|
||||
adata.obs['new_col'] = external_data.set_index('cell_id').loc[adata.obs_names, 'values']
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **Official documentation**: https://anndata.readthedocs.io/
|
||||
- **Scanpy tutorials**: https://scanpy.readthedocs.io/
|
||||
- **Scverse ecosystem**: https://scverse.org/
|
||||
- **GitHub repository**: https://github.com/scverse/anndata
|
||||
525
scientific-skills/anndata/references/best_practices.md
Normal file
525
scientific-skills/anndata/references/best_practices.md
Normal file
@@ -0,0 +1,525 @@
|
||||
# Best Practices
|
||||
|
||||
Guidelines for efficient and effective use of AnnData.
|
||||
|
||||
## Memory Management
|
||||
|
||||
### Use sparse matrices for sparse data
|
||||
```python
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import anndata as ad
|
||||
|
||||
# Check data sparsity
|
||||
data = np.random.rand(1000, 2000)
|
||||
sparsity = 1 - np.count_nonzero(data) / data.size
|
||||
print(f"Sparsity: {sparsity:.2%}")
|
||||
|
||||
# Convert to sparse if >50% zeros
|
||||
if sparsity > 0.5:
|
||||
adata = ad.AnnData(X=csr_matrix(data))
|
||||
else:
|
||||
adata = ad.AnnData(X=data)
|
||||
|
||||
# Benefits: 10-100x memory reduction for sparse genomics data
|
||||
```
|
||||
|
||||
### Convert strings to categoricals
|
||||
```python
|
||||
# Inefficient: string columns use lots of memory
|
||||
adata.obs['cell_type'] = ['Type_A', 'Type_B', 'Type_C'] * 333 + ['Type_A']
|
||||
|
||||
# Efficient: convert to categorical
|
||||
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category')
|
||||
|
||||
# Convert all string columns
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# Benefits: 10-50x memory reduction for repeated strings
|
||||
```
|
||||
|
||||
### Use backed mode for large datasets
|
||||
```python
|
||||
# Don't load entire dataset into memory
|
||||
adata = ad.read_h5ad('large_dataset.h5ad', backed='r')
|
||||
|
||||
# Work with metadata
|
||||
filtered = adata[adata.obs['quality'] > 0.8]
|
||||
|
||||
# Load only filtered subset
|
||||
adata_subset = filtered.to_memory()
|
||||
|
||||
# Benefits: Work with datasets larger than RAM
|
||||
```
|
||||
|
||||
## Views vs Copies
|
||||
|
||||
### Understanding views
|
||||
```python
|
||||
# Subsetting creates a view by default
|
||||
subset = adata[0:100, :]
|
||||
print(subset.is_view) # True
|
||||
|
||||
# Views don't copy data (memory efficient)
|
||||
# But modifications can affect original
|
||||
|
||||
# Check if object is a view
|
||||
if adata.is_view:
|
||||
adata = adata.copy() # Make independent
|
||||
```
|
||||
|
||||
### When to use views
|
||||
```python
|
||||
# Good: Read-only operations on subsets
|
||||
mean_expr = adata[adata.obs['cell_type'] == 'T cell'].X.mean()
|
||||
|
||||
# Good: Temporary analysis
|
||||
temp_subset = adata[:100, :]
|
||||
result = analyze(temp_subset.X)
|
||||
```
|
||||
|
||||
### When to use copies
|
||||
```python
|
||||
# Create independent copy for modifications
|
||||
adata_filtered = adata[keep_cells, :].copy()
|
||||
|
||||
# Safe to modify without affecting original
|
||||
adata_filtered.obs['new_column'] = values
|
||||
|
||||
# Always copy when:
|
||||
# - Storing subset for later use
|
||||
# - Modifying subset data
|
||||
# - Passing to function that modifies data
|
||||
```
|
||||
|
||||
## Data Storage Best Practices
|
||||
|
||||
### Choose the right format
|
||||
|
||||
**H5AD (HDF5) - Default choice**
|
||||
```python
|
||||
adata.write_h5ad('data.h5ad', compression='gzip')
|
||||
```
|
||||
- Fast random access
|
||||
- Supports backed mode
|
||||
- Good compression
|
||||
- Best for: Most use cases
|
||||
|
||||
**Zarr - Cloud and parallel access**
|
||||
```python
|
||||
adata.write_zarr('data.zarr', chunks=(100, 100))
|
||||
```
|
||||
- Excellent for cloud storage (S3, GCS)
|
||||
- Supports parallel I/O
|
||||
- Good compression
|
||||
- Best for: Large datasets, cloud workflows, parallel processing
|
||||
|
||||
**CSV - Interoperability**
|
||||
```python
|
||||
adata.write_csvs('output_dir/')
|
||||
```
|
||||
- Human readable
|
||||
- Compatible with all tools
|
||||
- Large file sizes, slow
|
||||
- Best for: Sharing with non-Python tools, small datasets
|
||||
|
||||
### Optimize file size
|
||||
```python
|
||||
# Before saving, optimize:
|
||||
|
||||
# 1. Convert to sparse if appropriate
|
||||
from scipy.sparse import csr_matrix, issparse
|
||||
if not issparse(adata.X):
|
||||
density = np.count_nonzero(adata.X) / adata.X.size
|
||||
if density < 0.5:
|
||||
adata.X = csr_matrix(adata.X)
|
||||
|
||||
# 2. Convert strings to categoricals
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# 3. Use compression
|
||||
adata.write_h5ad('data.h5ad', compression='gzip', compression_opts=9)
|
||||
|
||||
# Typical results: 5-20x file size reduction
|
||||
```
|
||||
|
||||
## Backed Mode Strategies
|
||||
|
||||
### Read-only analysis
|
||||
```python
|
||||
# Open in read-only backed mode
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r')
|
||||
|
||||
# Perform filtering without loading data
|
||||
high_quality = adata[adata.obs['quality_score'] > 0.8]
|
||||
|
||||
# Load only filtered data
|
||||
adata_filtered = high_quality.to_memory()
|
||||
```
|
||||
|
||||
### Read-write modifications
|
||||
```python
|
||||
# Open in read-write backed mode
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r+')
|
||||
|
||||
# Modify metadata (written to disk)
|
||||
adata.obs['new_annotation'] = values
|
||||
|
||||
# X remains on disk, modifications saved immediately
|
||||
```
|
||||
|
||||
### Chunked processing
|
||||
```python
|
||||
# Process large dataset in chunks
|
||||
adata = ad.read_h5ad('huge_dataset.h5ad', backed='r')
|
||||
|
||||
results = []
|
||||
chunk_size = 1000
|
||||
|
||||
for i in range(0, adata.n_obs, chunk_size):
|
||||
chunk = adata[i:i+chunk_size, :].to_memory()
|
||||
result = process(chunk)
|
||||
results.append(result)
|
||||
|
||||
final_result = combine(results)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Subsetting performance
|
||||
```python
|
||||
# Fast: Boolean indexing with arrays
|
||||
mask = np.array(adata.obs['quality'] > 0.5)
|
||||
subset = adata[mask, :]
|
||||
|
||||
# Slow: Boolean indexing with Series (creates view chain)
|
||||
subset = adata[adata.obs['quality'] > 0.5, :]
|
||||
|
||||
# Fastest: Integer indices
|
||||
indices = np.where(adata.obs['quality'] > 0.5)[0]
|
||||
subset = adata[indices, :]
|
||||
```
|
||||
|
||||
### Avoid repeated subsetting
|
||||
```python
|
||||
# Inefficient: Multiple subset operations
|
||||
for cell_type in ['A', 'B', 'C']:
|
||||
subset = adata[adata.obs['cell_type'] == cell_type]
|
||||
process(subset)
|
||||
|
||||
# Efficient: Group and process
|
||||
groups = adata.obs.groupby('cell_type').groups
|
||||
for cell_type, indices in groups.items():
|
||||
subset = adata[indices, :]
|
||||
process(subset)
|
||||
```
|
||||
|
||||
### Use chunked operations for large matrices
|
||||
```python
|
||||
# Process X in chunks
|
||||
for chunk in adata.chunked_X(chunk_size=1000):
|
||||
result = compute(chunk)
|
||||
|
||||
# More memory efficient than loading full X
|
||||
```
|
||||
|
||||
## Working with Raw Data
|
||||
|
||||
### Store raw before filtering
|
||||
```python
|
||||
# Original data with all genes
|
||||
adata = ad.AnnData(X=counts)
|
||||
|
||||
# Store raw before filtering
|
||||
adata.raw = adata.copy()
|
||||
|
||||
# Filter to highly variable genes
|
||||
adata = adata[:, adata.var['highly_variable']]
|
||||
|
||||
# Later: access original data
|
||||
original_expression = adata.raw.X
|
||||
all_genes = adata.raw.var_names
|
||||
```
|
||||
|
||||
### When to use raw
|
||||
```python
|
||||
# Use raw for:
|
||||
# - Differential expression on filtered genes
|
||||
# - Visualization of specific genes not in filtered set
|
||||
# - Accessing original counts after normalization
|
||||
|
||||
# Access raw data
|
||||
if adata.raw is not None:
|
||||
gene_expr = adata.raw[:, 'GENE_NAME'].X
|
||||
else:
|
||||
gene_expr = adata[:, 'GENE_NAME'].X
|
||||
```
|
||||
|
||||
## Metadata Management
|
||||
|
||||
### Naming conventions
|
||||
```python
|
||||
# Consistent naming improves usability
|
||||
|
||||
# Observation metadata (obs):
|
||||
# - cell_id, sample_id
|
||||
# - cell_type, tissue, condition
|
||||
# - n_genes, n_counts, percent_mito
|
||||
# - cluster, leiden, louvain
|
||||
|
||||
# Variable metadata (var):
|
||||
# - gene_id, gene_name
|
||||
# - highly_variable, n_cells
|
||||
# - mean_expression, dispersion
|
||||
|
||||
# Embeddings (obsm):
|
||||
# - X_pca, X_umap, X_tsne
|
||||
# - X_diffmap, X_draw_graph_fr
|
||||
|
||||
# Follow conventions from scanpy/scverse ecosystem
|
||||
```
|
||||
|
||||
### Document metadata
|
||||
```python
|
||||
# Store metadata descriptions in uns
|
||||
adata.uns['metadata_descriptions'] = {
|
||||
'cell_type': 'Cell type annotation from automated clustering',
|
||||
'quality_score': 'QC score from scrublet (0-1, higher is better)',
|
||||
'batch': 'Experimental batch identifier'
|
||||
}
|
||||
|
||||
# Store processing history
|
||||
adata.uns['processing_steps'] = [
|
||||
'Raw counts loaded from 10X',
|
||||
'Filtered: n_genes > 200, n_counts < 50000',
|
||||
'Normalized to 10000 counts per cell',
|
||||
'Log transformed'
|
||||
]
|
||||
```
|
||||
|
||||
## Reproducibility
|
||||
|
||||
### Set random seeds
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
# Set seed for reproducible results
|
||||
np.random.seed(42)
|
||||
|
||||
# Document in uns
|
||||
adata.uns['random_seed'] = 42
|
||||
```
|
||||
|
||||
### Store parameters
|
||||
```python
|
||||
# Store analysis parameters in uns
|
||||
adata.uns['pca'] = {
|
||||
'n_comps': 50,
|
||||
'svd_solver': 'arpack',
|
||||
'random_state': 42
|
||||
}
|
||||
|
||||
adata.uns['neighbors'] = {
|
||||
'n_neighbors': 15,
|
||||
'n_pcs': 50,
|
||||
'metric': 'euclidean',
|
||||
'method': 'umap'
|
||||
}
|
||||
```
|
||||
|
||||
### Version tracking
|
||||
```python
|
||||
import anndata
|
||||
import scanpy
|
||||
import numpy
|
||||
|
||||
# Store versions
|
||||
adata.uns['versions'] = {
|
||||
'anndata': anndata.__version__,
|
||||
'scanpy': scanpy.__version__,
|
||||
'numpy': numpy.__version__,
|
||||
'python': sys.version
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Check data validity
|
||||
```python
|
||||
# Verify dimensions
|
||||
assert adata.n_obs == len(adata.obs)
|
||||
assert adata.n_vars == len(adata.var)
|
||||
assert adata.X.shape == (adata.n_obs, adata.n_vars)
|
||||
|
||||
# Check for NaN values
|
||||
has_nan = np.isnan(adata.X.data).any() if issparse(adata.X) else np.isnan(adata.X).any()
|
||||
if has_nan:
|
||||
print("Warning: Data contains NaN values")
|
||||
|
||||
# Check for negative values (if counts expected)
|
||||
has_negative = (adata.X.data < 0).any() if issparse(adata.X) else (adata.X < 0).any()
|
||||
if has_negative:
|
||||
print("Warning: Data contains negative values")
|
||||
```
|
||||
|
||||
### Validate metadata
|
||||
```python
|
||||
# Check for missing values
|
||||
missing_obs = adata.obs.isnull().sum()
|
||||
if missing_obs.any():
|
||||
print("Missing values in obs:")
|
||||
print(missing_obs[missing_obs > 0])
|
||||
|
||||
# Verify indices are unique
|
||||
assert adata.obs_names.is_unique, "Observation names not unique"
|
||||
assert adata.var_names.is_unique, "Variable names not unique"
|
||||
|
||||
# Check metadata alignment
|
||||
assert len(adata.obs) == adata.n_obs
|
||||
assert len(adata.var) == adata.n_vars
|
||||
```
|
||||
|
||||
## Integration with Other Tools
|
||||
|
||||
### Scanpy integration
|
||||
```python
|
||||
import scanpy as sc
|
||||
|
||||
# AnnData is native format for scanpy
|
||||
sc.pp.filter_cells(adata, min_genes=200)
|
||||
sc.pp.filter_genes(adata, min_cells=3)
|
||||
sc.pp.normalize_total(adata, target_sum=1e4)
|
||||
sc.pp.log1p(adata)
|
||||
sc.pp.highly_variable_genes(adata)
|
||||
sc.pp.pca(adata)
|
||||
sc.pp.neighbors(adata)
|
||||
sc.tl.umap(adata)
|
||||
```
|
||||
|
||||
### Pandas integration
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Convert to DataFrame
|
||||
df = adata.to_df()
|
||||
|
||||
# Create from DataFrame
|
||||
adata = ad.AnnData(df)
|
||||
|
||||
# Work with metadata as DataFrames
|
||||
adata.obs = adata.obs.merge(external_metadata, left_index=True, right_index=True)
|
||||
```
|
||||
|
||||
### PyTorch integration
|
||||
```python
|
||||
from anndata.experimental import AnnLoader
|
||||
|
||||
# Create PyTorch DataLoader
|
||||
dataloader = AnnLoader(adata, batch_size=128, shuffle=True)
|
||||
|
||||
# Iterate in training loop
|
||||
for batch in dataloader:
|
||||
X = batch.X
|
||||
# Train model on batch
|
||||
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### Pitfall 1: Modifying views
|
||||
```python
|
||||
# Wrong: Modifying view can affect original
|
||||
subset = adata[:100, :]
|
||||
subset.X = new_data # May modify adata.X!
|
||||
|
||||
# Correct: Copy before modifying
|
||||
subset = adata[:100, :].copy()
|
||||
subset.X = new_data # Independent copy
|
||||
```
|
||||
|
||||
### Pitfall 2: Index misalignment
|
||||
```python
|
||||
# Wrong: Assuming order matches
|
||||
external_data = pd.read_csv('data.csv')
|
||||
adata.obs['new_col'] = external_data['values'] # May misalign!
|
||||
|
||||
# Correct: Align on index
|
||||
adata.obs['new_col'] = external_data.set_index('cell_id').loc[adata.obs_names, 'values']
|
||||
```
|
||||
|
||||
### Pitfall 3: Mixing sparse and dense
|
||||
```python
|
||||
# Wrong: Converting sparse to dense uses huge memory
|
||||
result = adata.X + 1 # Converts sparse to dense!
|
||||
|
||||
# Correct: Use sparse operations
|
||||
from scipy.sparse import issparse
|
||||
if issparse(adata.X):
|
||||
result = adata.X.copy()
|
||||
result.data += 1
|
||||
```
|
||||
|
||||
### Pitfall 4: Not handling views
|
||||
```python
|
||||
# Wrong: Assuming subset is independent
|
||||
subset = adata[mask, :]
|
||||
del adata # subset may become invalid!
|
||||
|
||||
# Correct: Copy when needed
|
||||
subset = adata[mask, :].copy()
|
||||
del adata # subset remains valid
|
||||
```
|
||||
|
||||
### Pitfall 5: Ignoring memory constraints
|
||||
```python
|
||||
# Wrong: Loading huge dataset into memory
|
||||
adata = ad.read_h5ad('100GB_file.h5ad') # OOM error!
|
||||
|
||||
# Correct: Use backed mode
|
||||
adata = ad.read_h5ad('100GB_file.h5ad', backed='r')
|
||||
subset = adata[adata.obs['keep']].to_memory()
|
||||
```
|
||||
|
||||
## Workflow Example
|
||||
|
||||
Complete best-practices workflow:
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# 1. Load with backed mode if large
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r')
|
||||
|
||||
# 2. Quick metadata check without loading data
|
||||
print(f"Dataset: {adata.n_obs} cells × {adata.n_vars} genes")
|
||||
|
||||
# 3. Filter based on metadata
|
||||
high_quality = adata[adata.obs['quality_score'] > 0.8]
|
||||
|
||||
# 4. Load filtered subset to memory
|
||||
adata = high_quality.to_memory()
|
||||
|
||||
# 5. Convert to optimal storage types
|
||||
adata.strings_to_categoricals()
|
||||
if not issparse(adata.X):
|
||||
density = np.count_nonzero(adata.X) / adata.X.size
|
||||
if density < 0.5:
|
||||
adata.X = csr_matrix(adata.X)
|
||||
|
||||
# 6. Store raw before filtering genes
|
||||
adata.raw = adata.copy()
|
||||
|
||||
# 7. Filter to highly variable genes
|
||||
adata = adata[:, adata.var['highly_variable']].copy()
|
||||
|
||||
# 8. Document processing
|
||||
adata.uns['processing'] = {
|
||||
'filtered': 'quality_score > 0.8',
|
||||
'n_hvg': adata.n_vars,
|
||||
'date': '2025-11-03'
|
||||
}
|
||||
|
||||
# 9. Save optimized
|
||||
adata.write_h5ad('processed.h5ad', compression='gzip')
|
||||
```
|
||||
396
scientific-skills/anndata/references/concatenation.md
Normal file
396
scientific-skills/anndata/references/concatenation.md
Normal file
@@ -0,0 +1,396 @@
|
||||
# Concatenating AnnData Objects
|
||||
|
||||
Combine multiple AnnData objects along either observations or variables axis.
|
||||
|
||||
## Basic Concatenation
|
||||
|
||||
### Concatenate along observations (stack cells/samples)
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
|
||||
# Create multiple AnnData objects
|
||||
adata1 = ad.AnnData(X=np.random.rand(100, 50))
|
||||
adata2 = ad.AnnData(X=np.random.rand(150, 50))
|
||||
adata3 = ad.AnnData(X=np.random.rand(200, 50))
|
||||
|
||||
# Concatenate along observations (axis=0, default)
|
||||
adata_combined = ad.concat([adata1, adata2, adata3], axis=0)
|
||||
|
||||
print(adata_combined.shape) # (450, 50)
|
||||
```
|
||||
|
||||
### Concatenate along variables (stack genes/features)
|
||||
```python
|
||||
# Create objects with same observations, different variables
|
||||
adata1 = ad.AnnData(X=np.random.rand(100, 50))
|
||||
adata2 = ad.AnnData(X=np.random.rand(100, 30))
|
||||
adata3 = ad.AnnData(X=np.random.rand(100, 70))
|
||||
|
||||
# Concatenate along variables (axis=1)
|
||||
adata_combined = ad.concat([adata1, adata2, adata3], axis=1)
|
||||
|
||||
print(adata_combined.shape) # (100, 150)
|
||||
```
|
||||
|
||||
## Join Types
|
||||
|
||||
### Inner join (intersection)
|
||||
Keep only variables/observations present in all objects.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create objects with different variables
|
||||
adata1 = ad.AnnData(
|
||||
X=np.random.rand(100, 50),
|
||||
var=pd.DataFrame(index=[f'Gene_{i}' for i in range(50)])
|
||||
)
|
||||
adata2 = ad.AnnData(
|
||||
X=np.random.rand(150, 60),
|
||||
var=pd.DataFrame(index=[f'Gene_{i}' for i in range(10, 70)])
|
||||
)
|
||||
|
||||
# Inner join: only genes 10-49 are kept (overlap)
|
||||
adata_inner = ad.concat([adata1, adata2], join='inner')
|
||||
print(adata_inner.n_vars) # 40 genes (overlap)
|
||||
```
|
||||
|
||||
### Outer join (union)
|
||||
Keep all variables/observations, filling missing values.
|
||||
|
||||
```python
|
||||
# Outer join: all genes are kept
|
||||
adata_outer = ad.concat([adata1, adata2], join='outer')
|
||||
print(adata_outer.n_vars) # 70 genes (union)
|
||||
|
||||
# Missing values are filled with appropriate defaults:
|
||||
# - 0 for sparse matrices
|
||||
# - NaN for dense matrices
|
||||
```
|
||||
|
||||
### Fill values for outer joins
|
||||
```python
|
||||
# Specify fill value for missing data
|
||||
adata_filled = ad.concat([adata1, adata2], join='outer', fill_value=0)
|
||||
```
|
||||
|
||||
## Tracking Data Sources
|
||||
|
||||
### Add batch labels
|
||||
```python
|
||||
# Label which object each observation came from
|
||||
adata_combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='batch', # Column name for labels
|
||||
keys=['batch1', 'batch2', 'batch3'] # Labels for each object
|
||||
)
|
||||
|
||||
print(adata_combined.obs['batch'].value_counts())
|
||||
# batch1 100
|
||||
# batch2 150
|
||||
# batch3 200
|
||||
```
|
||||
|
||||
### Automatic batch labels
|
||||
```python
|
||||
# If keys not provided, uses integer indices
|
||||
adata_combined = ad.concat(
|
||||
[adata1, adata2, adata3],
|
||||
label='dataset'
|
||||
)
|
||||
# dataset column contains: 0, 1, 2
|
||||
```
|
||||
|
||||
## Merge Strategies
|
||||
|
||||
Control how metadata from different objects is combined using the `merge` parameter.
|
||||
|
||||
### merge=None (default for observations)
|
||||
Exclude metadata on non-concatenation axis.
|
||||
|
||||
```python
|
||||
# When concatenating observations, var metadata must match
|
||||
adata1.var['gene_type'] = 'protein_coding'
|
||||
adata2.var['gene_type'] = 'protein_coding'
|
||||
|
||||
# var is kept only if identical across all objects
|
||||
adata_combined = ad.concat([adata1, adata2], merge=None)
|
||||
```
|
||||
|
||||
### merge='same'
|
||||
Keep metadata that is identical across all objects.
|
||||
|
||||
```python
|
||||
adata1.var['chromosome'] = ['chr1'] * 25 + ['chr2'] * 25
|
||||
adata2.var['chromosome'] = ['chr1'] * 25 + ['chr2'] * 25
|
||||
adata1.var['type'] = 'protein_coding'
|
||||
adata2.var['type'] = 'lncRNA' # Different
|
||||
|
||||
# 'chromosome' is kept (same), 'type' is excluded (different)
|
||||
adata_combined = ad.concat([adata1, adata2], merge='same')
|
||||
```
|
||||
|
||||
### merge='unique'
|
||||
Keep metadata columns where each key has exactly one value.
|
||||
|
||||
```python
|
||||
adata1.var['gene_id'] = [f'ENSG{i:05d}' for i in range(50)]
|
||||
adata2.var['gene_id'] = [f'ENSG{i:05d}' for i in range(50)]
|
||||
|
||||
# gene_id is kept (unique values for each key)
|
||||
adata_combined = ad.concat([adata1, adata2], merge='unique')
|
||||
```
|
||||
|
||||
### merge='first'
|
||||
Take values from the first object containing each key.
|
||||
|
||||
```python
|
||||
adata1.var['description'] = ['Desc1'] * 50
|
||||
adata2.var['description'] = ['Desc2'] * 50
|
||||
|
||||
# Uses descriptions from adata1
|
||||
adata_combined = ad.concat([adata1, adata2], merge='first')
|
||||
```
|
||||
|
||||
### merge='only'
|
||||
Keep metadata that appears in only one object.
|
||||
|
||||
```python
|
||||
adata1.var['adata1_specific'] = [1] * 50
|
||||
adata2.var['adata2_specific'] = [2] * 50
|
||||
|
||||
# Both metadata columns are kept
|
||||
adata_combined = ad.concat([adata1, adata2], merge='only')
|
||||
```
|
||||
|
||||
## Handling Index Conflicts
|
||||
|
||||
### Make indices unique
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create objects with overlapping observation names
|
||||
adata1 = ad.AnnData(
|
||||
X=np.random.rand(3, 10),
|
||||
obs=pd.DataFrame(index=['cell_1', 'cell_2', 'cell_3'])
|
||||
)
|
||||
adata2 = ad.AnnData(
|
||||
X=np.random.rand(3, 10),
|
||||
obs=pd.DataFrame(index=['cell_1', 'cell_2', 'cell_3'])
|
||||
)
|
||||
|
||||
# Make indices unique by appending batch keys
|
||||
adata_combined = ad.concat(
|
||||
[adata1, adata2],
|
||||
label='batch',
|
||||
keys=['batch1', 'batch2'],
|
||||
index_unique='_' # Separator for making indices unique
|
||||
)
|
||||
|
||||
print(adata_combined.obs_names)
|
||||
# ['cell_1_batch1', 'cell_2_batch1', 'cell_3_batch1',
|
||||
# 'cell_1_batch2', 'cell_2_batch2', 'cell_3_batch2']
|
||||
```
|
||||
|
||||
## Concatenating Layers
|
||||
|
||||
```python
|
||||
# Objects with layers
|
||||
adata1 = ad.AnnData(X=np.random.rand(100, 50))
|
||||
adata1.layers['normalized'] = np.random.rand(100, 50)
|
||||
adata1.layers['scaled'] = np.random.rand(100, 50)
|
||||
|
||||
adata2 = ad.AnnData(X=np.random.rand(150, 50))
|
||||
adata2.layers['normalized'] = np.random.rand(150, 50)
|
||||
adata2.layers['scaled'] = np.random.rand(150, 50)
|
||||
|
||||
# Layers are concatenated automatically if present in all objects
|
||||
adata_combined = ad.concat([adata1, adata2])
|
||||
|
||||
print(adata_combined.layers.keys())
|
||||
# dict_keys(['normalized', 'scaled'])
|
||||
```
|
||||
|
||||
## Concatenating Multi-dimensional Annotations
|
||||
|
||||
### obsm/varm
|
||||
```python
|
||||
# Objects with embeddings
|
||||
adata1.obsm['X_pca'] = np.random.rand(100, 50)
|
||||
adata2.obsm['X_pca'] = np.random.rand(150, 50)
|
||||
|
||||
# obsm is concatenated along observation axis
|
||||
adata_combined = ad.concat([adata1, adata2])
|
||||
print(adata_combined.obsm['X_pca'].shape) # (250, 50)
|
||||
```
|
||||
|
||||
### obsp/varp (pairwise annotations)
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Pairwise matrices
|
||||
adata1.obsp['connectivities'] = csr_matrix((100, 100))
|
||||
adata2.obsp['connectivities'] = csr_matrix((150, 150))
|
||||
|
||||
# By default, obsp is NOT concatenated (set pairwise=True to include)
|
||||
adata_combined = ad.concat([adata1, adata2])
|
||||
# adata_combined.obsp is empty
|
||||
|
||||
# Include pairwise data (creates block diagonal matrix)
|
||||
adata_combined = ad.concat([adata1, adata2], pairwise=True)
|
||||
print(adata_combined.obsp['connectivities'].shape) # (250, 250)
|
||||
```
|
||||
|
||||
## Concatenating uns (unstructured)
|
||||
|
||||
Unstructured metadata is merged recursively:
|
||||
|
||||
```python
|
||||
adata1.uns['experiment'] = {'date': '2025-01-01', 'batch': 'A'}
|
||||
adata2.uns['experiment'] = {'date': '2025-01-01', 'batch': 'B'}
|
||||
|
||||
# Using merge='unique' for uns
|
||||
adata_combined = ad.concat([adata1, adata2], uns_merge='unique')
|
||||
# 'date' is kept (same value), 'batch' might be excluded (different values)
|
||||
```
|
||||
|
||||
## Lazy Concatenation (AnnCollection)
|
||||
|
||||
For very large datasets, use lazy concatenation that doesn't load all data:
|
||||
|
||||
```python
|
||||
from anndata.experimental import AnnCollection
|
||||
|
||||
# Create collection from file paths (doesn't load data)
|
||||
files = ['data1.h5ad', 'data2.h5ad', 'data3.h5ad']
|
||||
collection = AnnCollection(
|
||||
files,
|
||||
join_obs='outer',
|
||||
join_vars='inner',
|
||||
label='dataset',
|
||||
keys=['dataset1', 'dataset2', 'dataset3']
|
||||
)
|
||||
|
||||
# Access data lazily
|
||||
print(collection.n_obs) # Total observations
|
||||
print(collection.obs.head()) # Metadata loaded, not X
|
||||
|
||||
# Convert to regular AnnData when needed (loads all data)
|
||||
adata = collection.to_adata()
|
||||
```
|
||||
|
||||
### Working with AnnCollection
|
||||
```python
|
||||
# Subset without loading data
|
||||
subset = collection[collection.obs['cell_type'] == 'T cell']
|
||||
|
||||
# Iterate through datasets
|
||||
for adata in collection:
|
||||
print(adata.shape)
|
||||
|
||||
# Access specific dataset
|
||||
first_dataset = collection[0]
|
||||
```
|
||||
|
||||
## Concatenation on Disk
|
||||
|
||||
For datasets too large for memory, concatenate directly on disk:
|
||||
|
||||
```python
|
||||
from anndata.experimental import concat_on_disk
|
||||
|
||||
# Concatenate without loading into memory
|
||||
concat_on_disk(
|
||||
['data1.h5ad', 'data2.h5ad', 'data3.h5ad'],
|
||||
'combined.h5ad',
|
||||
join='outer'
|
||||
)
|
||||
|
||||
# Load result in backed mode
|
||||
adata = ad.read_h5ad('combined.h5ad', backed='r')
|
||||
```
|
||||
|
||||
## Common Concatenation Patterns
|
||||
|
||||
### Combine technical replicates
|
||||
```python
|
||||
# Multiple runs of the same samples
|
||||
replicates = [adata_run1, adata_run2, adata_run3]
|
||||
adata_combined = ad.concat(
|
||||
replicates,
|
||||
label='technical_replicate',
|
||||
keys=['rep1', 'rep2', 'rep3'],
|
||||
join='inner' # Keep only genes measured in all runs
|
||||
)
|
||||
```
|
||||
|
||||
### Combine batches from experiment
|
||||
```python
|
||||
# Different experimental batches
|
||||
batches = [adata_batch1, adata_batch2, adata_batch3]
|
||||
adata_combined = ad.concat(
|
||||
batches,
|
||||
label='batch',
|
||||
keys=['batch1', 'batch2', 'batch3'],
|
||||
join='outer' # Keep all genes
|
||||
)
|
||||
|
||||
# Later: apply batch correction
|
||||
```
|
||||
|
||||
### Merge multi-modal data
|
||||
```python
|
||||
# Different measurement modalities (e.g., RNA + protein)
|
||||
adata_rna = ad.AnnData(X=np.random.rand(100, 2000))
|
||||
adata_protein = ad.AnnData(X=np.random.rand(100, 50))
|
||||
|
||||
# Concatenate along variables to combine modalities
|
||||
adata_multimodal = ad.concat([adata_rna, adata_protein], axis=1)
|
||||
|
||||
# Add labels to distinguish modalities
|
||||
adata_multimodal.var['modality'] = ['RNA'] * 2000 + ['protein'] * 50
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Check compatibility before concatenating**
|
||||
```python
|
||||
# Verify shapes are compatible
|
||||
print([adata.n_vars for adata in [adata1, adata2, adata3]])
|
||||
|
||||
# Check variable names match
|
||||
print([set(adata.var_names) for adata in [adata1, adata2, adata3]])
|
||||
```
|
||||
|
||||
2. **Use appropriate join type**
|
||||
- `inner`: When you need the same features across all samples (most stringent)
|
||||
- `outer`: When you want to preserve all features (most inclusive)
|
||||
|
||||
3. **Track data sources**
|
||||
Always use `label` and `keys` to track which observations came from which dataset.
|
||||
|
||||
4. **Consider memory usage**
|
||||
- For large datasets, use `AnnCollection` or `concat_on_disk`
|
||||
- Consider backed mode for the result
|
||||
|
||||
5. **Handle batch effects**
|
||||
Concatenation combines data but doesn't correct for batch effects. Apply batch correction after concatenation:
|
||||
```python
|
||||
# After concatenation, apply batch correction
|
||||
import scanpy as sc
|
||||
sc.pp.combat(adata_combined, key='batch')
|
||||
```
|
||||
|
||||
6. **Validate results**
|
||||
```python
|
||||
# Check dimensions
|
||||
print(adata_combined.shape)
|
||||
|
||||
# Check batch distribution
|
||||
print(adata_combined.obs['batch'].value_counts())
|
||||
|
||||
# Verify metadata integrity
|
||||
print(adata_combined.var.head())
|
||||
print(adata_combined.obs.head())
|
||||
```
|
||||
314
scientific-skills/anndata/references/data_structure.md
Normal file
314
scientific-skills/anndata/references/data_structure.md
Normal file
@@ -0,0 +1,314 @@
|
||||
# AnnData Object Structure
|
||||
|
||||
The AnnData object stores a data matrix with associated annotations, providing a flexible framework for managing experimental data and metadata.
|
||||
|
||||
## Core Components
|
||||
|
||||
### X (Data Matrix)
|
||||
The primary data matrix with shape (n_obs, n_vars) storing experimental measurements.
|
||||
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
|
||||
# Create with dense array
|
||||
adata = ad.AnnData(X=np.random.rand(100, 2000))
|
||||
|
||||
# Create with sparse matrix (recommended for large, sparse data)
|
||||
from scipy.sparse import csr_matrix
|
||||
sparse_data = csr_matrix(np.random.rand(100, 2000))
|
||||
adata = ad.AnnData(X=sparse_data)
|
||||
```
|
||||
|
||||
Access data:
|
||||
```python
|
||||
# Full matrix (caution with large datasets)
|
||||
full_data = adata.X
|
||||
|
||||
# Single observation
|
||||
obs_data = adata.X[0, :]
|
||||
|
||||
# Single variable across all observations
|
||||
var_data = adata.X[:, 0]
|
||||
```
|
||||
|
||||
### obs (Observation Annotations)
|
||||
DataFrame storing metadata about observations (rows). Each row corresponds to one observation in X.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create AnnData with observation metadata
|
||||
obs_df = pd.DataFrame({
|
||||
'cell_type': ['T cell', 'B cell', 'Monocyte'],
|
||||
'treatment': ['control', 'treated', 'control'],
|
||||
'timepoint': [0, 24, 24]
|
||||
}, index=['cell_1', 'cell_2', 'cell_3'])
|
||||
|
||||
adata = ad.AnnData(X=np.random.rand(3, 100), obs=obs_df)
|
||||
|
||||
# Access observation metadata
|
||||
print(adata.obs['cell_type'])
|
||||
print(adata.obs.loc['cell_1'])
|
||||
```
|
||||
|
||||
### var (Variable Annotations)
|
||||
DataFrame storing metadata about variables (columns). Each row corresponds to one variable in X.
|
||||
|
||||
```python
|
||||
# Create AnnData with variable metadata
|
||||
var_df = pd.DataFrame({
|
||||
'gene_name': ['ACTB', 'GAPDH', 'TP53'],
|
||||
'chromosome': ['7', '12', '17'],
|
||||
'highly_variable': [True, False, True]
|
||||
}, index=['ENSG00001', 'ENSG00002', 'ENSG00003'])
|
||||
|
||||
adata = ad.AnnData(X=np.random.rand(100, 3), var=var_df)
|
||||
|
||||
# Access variable metadata
|
||||
print(adata.var['gene_name'])
|
||||
print(adata.var.loc['ENSG00001'])
|
||||
```
|
||||
|
||||
### layers (Alternative Data Representations)
|
||||
Dictionary storing alternative matrices with the same dimensions as X.
|
||||
|
||||
```python
|
||||
# Store raw counts, normalized data, and scaled data
|
||||
adata = ad.AnnData(X=np.random.rand(100, 2000))
|
||||
adata.layers['raw_counts'] = np.random.randint(0, 100, (100, 2000))
|
||||
adata.layers['normalized'] = adata.X / np.sum(adata.X, axis=1, keepdims=True)
|
||||
adata.layers['scaled'] = (adata.X - adata.X.mean()) / adata.X.std()
|
||||
|
||||
# Access layers
|
||||
raw_data = adata.layers['raw_counts']
|
||||
normalized_data = adata.layers['normalized']
|
||||
```
|
||||
|
||||
Common layer uses:
|
||||
- `raw_counts`: Original count data before normalization
|
||||
- `normalized`: Log-normalized or TPM values
|
||||
- `scaled`: Z-scored values for analysis
|
||||
- `imputed`: Data after imputation
|
||||
|
||||
### obsm (Multi-dimensional Observation Annotations)
|
||||
Dictionary storing multi-dimensional arrays aligned to observations.
|
||||
|
||||
```python
|
||||
# Store PCA coordinates and UMAP embeddings
|
||||
adata.obsm['X_pca'] = np.random.rand(100, 50) # 50 principal components
|
||||
adata.obsm['X_umap'] = np.random.rand(100, 2) # 2D UMAP coordinates
|
||||
adata.obsm['X_tsne'] = np.random.rand(100, 2) # 2D t-SNE coordinates
|
||||
|
||||
# Access embeddings
|
||||
pca_coords = adata.obsm['X_pca']
|
||||
umap_coords = adata.obsm['X_umap']
|
||||
```
|
||||
|
||||
Common obsm uses:
|
||||
- `X_pca`: Principal component coordinates
|
||||
- `X_umap`: UMAP embedding coordinates
|
||||
- `X_tsne`: t-SNE embedding coordinates
|
||||
- `X_diffmap`: Diffusion map coordinates
|
||||
- `protein_expression`: Protein abundance measurements (CITE-seq)
|
||||
|
||||
### varm (Multi-dimensional Variable Annotations)
|
||||
Dictionary storing multi-dimensional arrays aligned to variables.
|
||||
|
||||
```python
|
||||
# Store PCA loadings
|
||||
adata.varm['PCs'] = np.random.rand(2000, 50) # Loadings for 50 components
|
||||
adata.varm['gene_modules'] = np.random.rand(2000, 10) # Gene module scores
|
||||
|
||||
# Access loadings
|
||||
pc_loadings = adata.varm['PCs']
|
||||
```
|
||||
|
||||
Common varm uses:
|
||||
- `PCs`: Principal component loadings
|
||||
- `gene_modules`: Gene co-expression module assignments
|
||||
|
||||
### obsp (Pairwise Observation Relationships)
|
||||
Dictionary storing sparse matrices representing relationships between observations.
|
||||
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Store k-nearest neighbor graph
|
||||
n_obs = 100
|
||||
knn_graph = csr_matrix(np.random.rand(n_obs, n_obs) > 0.95)
|
||||
adata.obsp['connectivities'] = knn_graph
|
||||
adata.obsp['distances'] = csr_matrix(np.random.rand(n_obs, n_obs))
|
||||
|
||||
# Access graphs
|
||||
knn_connections = adata.obsp['connectivities']
|
||||
distances = adata.obsp['distances']
|
||||
```
|
||||
|
||||
Common obsp uses:
|
||||
- `connectivities`: Cell-cell neighborhood graph
|
||||
- `distances`: Pairwise distances between cells
|
||||
|
||||
### varp (Pairwise Variable Relationships)
|
||||
Dictionary storing sparse matrices representing relationships between variables.
|
||||
|
||||
```python
|
||||
# Store gene-gene correlation matrix
|
||||
n_vars = 2000
|
||||
gene_corr = csr_matrix(np.random.rand(n_vars, n_vars) > 0.99)
|
||||
adata.varp['correlations'] = gene_corr
|
||||
|
||||
# Access correlations
|
||||
gene_correlations = adata.varp['correlations']
|
||||
```
|
||||
|
||||
### uns (Unstructured Annotations)
|
||||
Dictionary storing arbitrary unstructured metadata.
|
||||
|
||||
```python
|
||||
# Store analysis parameters and results
|
||||
adata.uns['experiment_date'] = '2025-11-03'
|
||||
adata.uns['pca'] = {
|
||||
'variance_ratio': [0.15, 0.10, 0.08],
|
||||
'params': {'n_comps': 50}
|
||||
}
|
||||
adata.uns['neighbors'] = {
|
||||
'params': {'n_neighbors': 15, 'method': 'umap'},
|
||||
'connectivities_key': 'connectivities'
|
||||
}
|
||||
|
||||
# Access unstructured data
|
||||
exp_date = adata.uns['experiment_date']
|
||||
pca_params = adata.uns['pca']['params']
|
||||
```
|
||||
|
||||
Common uns uses:
|
||||
- Analysis parameters and settings
|
||||
- Color palettes for plotting
|
||||
- Cluster information
|
||||
- Tool-specific metadata
|
||||
|
||||
### raw (Original Data Snapshot)
|
||||
Optional attribute preserving the original data matrix and variable annotations before filtering.
|
||||
|
||||
```python
|
||||
# Create AnnData and store raw state
|
||||
adata = ad.AnnData(X=np.random.rand(100, 5000))
|
||||
adata.var['gene_name'] = [f'Gene_{i}' for i in range(5000)]
|
||||
|
||||
# Store raw state before filtering
|
||||
adata.raw = adata.copy()
|
||||
|
||||
# Filter to highly variable genes
|
||||
highly_variable_mask = np.random.rand(5000) > 0.5
|
||||
adata = adata[:, highly_variable_mask]
|
||||
|
||||
# Access original data
|
||||
original_matrix = adata.raw.X
|
||||
original_var = adata.raw.var
|
||||
```
|
||||
|
||||
## Object Properties
|
||||
|
||||
```python
|
||||
# Dimensions
|
||||
n_observations = adata.n_obs
|
||||
n_variables = adata.n_vars
|
||||
shape = adata.shape # (n_obs, n_vars)
|
||||
|
||||
# Index information
|
||||
obs_names = adata.obs_names # Observation identifiers
|
||||
var_names = adata.var_names # Variable identifiers
|
||||
|
||||
# Storage mode
|
||||
is_view = adata.is_view # True if this is a view of another object
|
||||
is_backed = adata.isbacked # True if backed by on-disk storage
|
||||
filename = adata.filename # Path to backing file (if backed)
|
||||
```
|
||||
|
||||
## Creating AnnData Objects
|
||||
|
||||
### From arrays and DataFrames
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Minimal creation
|
||||
X = np.random.rand(100, 2000)
|
||||
adata = ad.AnnData(X)
|
||||
|
||||
# With metadata
|
||||
obs = pd.DataFrame({'cell_type': ['A', 'B'] * 50}, index=[f'cell_{i}' for i in range(100)])
|
||||
var = pd.DataFrame({'gene_name': [f'Gene_{i}' for i in range(2000)]}, index=[f'ENSG{i:05d}' for i in range(2000)])
|
||||
adata = ad.AnnData(X=X, obs=obs, var=var)
|
||||
|
||||
# With all components
|
||||
adata = ad.AnnData(
|
||||
X=X,
|
||||
obs=obs,
|
||||
var=var,
|
||||
layers={'raw': np.random.randint(0, 100, (100, 2000))},
|
||||
obsm={'X_pca': np.random.rand(100, 50)},
|
||||
uns={'experiment': 'test'}
|
||||
)
|
||||
```
|
||||
|
||||
### From DataFrame
|
||||
```python
|
||||
# Create from pandas DataFrame (genes as columns, cells as rows)
|
||||
df = pd.DataFrame(
|
||||
np.random.rand(100, 50),
|
||||
columns=[f'Gene_{i}' for i in range(50)],
|
||||
index=[f'Cell_{i}' for i in range(100)]
|
||||
)
|
||||
adata = ad.AnnData(df)
|
||||
```
|
||||
|
||||
## Data Access Patterns
|
||||
|
||||
### Vector extraction
|
||||
```python
|
||||
# Get observation annotation as array
|
||||
cell_types = adata.obs_vector('cell_type')
|
||||
|
||||
# Get variable values across observations
|
||||
gene_expression = adata.obs_vector('ACTB') # If ACTB is in var_names
|
||||
|
||||
# Get variable annotation as array
|
||||
gene_names = adata.var_vector('gene_name')
|
||||
```
|
||||
|
||||
### Subsetting
|
||||
```python
|
||||
# By index
|
||||
subset = adata[0:10, 0:100] # First 10 obs, first 100 vars
|
||||
|
||||
# By name
|
||||
subset = adata[['cell_1', 'cell_2'], ['ACTB', 'GAPDH']]
|
||||
|
||||
# By boolean mask
|
||||
high_count_cells = adata.obs['total_counts'] > 1000
|
||||
subset = adata[high_count_cells, :]
|
||||
|
||||
# By observation metadata
|
||||
t_cells = adata[adata.obs['cell_type'] == 'T cell']
|
||||
```
|
||||
|
||||
## Memory Considerations
|
||||
|
||||
The AnnData structure is designed for memory efficiency:
|
||||
- Sparse matrices reduce memory for sparse data
|
||||
- Views avoid copying data when possible
|
||||
- Backed mode enables working with data larger than RAM
|
||||
- Categorical annotations reduce memory for discrete values
|
||||
|
||||
```python
|
||||
# Convert strings to categoricals (more memory efficient)
|
||||
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category')
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
# Check if object is a view (doesn't own data)
|
||||
if adata.is_view:
|
||||
adata = adata.copy() # Create independent copy
|
||||
```
|
||||
404
scientific-skills/anndata/references/io_operations.md
Normal file
404
scientific-skills/anndata/references/io_operations.md
Normal file
@@ -0,0 +1,404 @@
|
||||
# Input/Output Operations
|
||||
|
||||
AnnData provides comprehensive I/O functionality for reading and writing data in various formats.
|
||||
|
||||
## Native Formats
|
||||
|
||||
### H5AD (HDF5-based)
|
||||
The recommended native format for AnnData objects, providing efficient storage and fast access.
|
||||
|
||||
#### Writing H5AD files
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Write to file
|
||||
adata.write_h5ad('data.h5ad')
|
||||
|
||||
# Write with compression
|
||||
adata.write_h5ad('data.h5ad', compression='gzip')
|
||||
|
||||
# Write with specific compression level (0-9, higher = more compression)
|
||||
adata.write_h5ad('data.h5ad', compression='gzip', compression_opts=9)
|
||||
```
|
||||
|
||||
#### Reading H5AD files
|
||||
```python
|
||||
# Read entire file into memory
|
||||
adata = ad.read_h5ad('data.h5ad')
|
||||
|
||||
# Read in backed mode (lazy loading for large files)
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r') # Read-only
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r+') # Read-write
|
||||
|
||||
# Backed mode enables working with datasets larger than RAM
|
||||
# Only accessed data is loaded into memory
|
||||
```
|
||||
|
||||
#### Backed mode operations
|
||||
```python
|
||||
# Open in backed mode
|
||||
adata = ad.read_h5ad('large_dataset.h5ad', backed='r')
|
||||
|
||||
# Access metadata without loading X into memory
|
||||
print(adata.obs.head())
|
||||
print(adata.var.head())
|
||||
|
||||
# Subset operations create views
|
||||
subset = adata[:100, :500] # View, no data loaded
|
||||
|
||||
# Load specific data into memory
|
||||
X_subset = subset.X[:] # Now loads this subset
|
||||
|
||||
# Convert entire backed object to memory
|
||||
adata_memory = adata.to_memory()
|
||||
```
|
||||
|
||||
### Zarr
|
||||
Hierarchical array storage format, optimized for cloud storage and parallel I/O.
|
||||
|
||||
#### Writing Zarr
|
||||
```python
|
||||
# Write to Zarr store
|
||||
adata.write_zarr('data.zarr')
|
||||
|
||||
# Write with specific chunks (important for performance)
|
||||
adata.write_zarr('data.zarr', chunks=(100, 100))
|
||||
```
|
||||
|
||||
#### Reading Zarr
|
||||
```python
|
||||
# Read Zarr store
|
||||
adata = ad.read_zarr('data.zarr')
|
||||
```
|
||||
|
||||
#### Remote Zarr access
|
||||
```python
|
||||
import fsspec
|
||||
|
||||
# Access Zarr from S3
|
||||
store = fsspec.get_mapper('s3://bucket-name/data.zarr')
|
||||
adata = ad.read_zarr(store)
|
||||
|
||||
# Access Zarr from URL
|
||||
store = fsspec.get_mapper('https://example.com/data.zarr')
|
||||
adata = ad.read_zarr(store)
|
||||
```
|
||||
|
||||
## Alternative Input Formats
|
||||
|
||||
### CSV/TSV
|
||||
```python
|
||||
# Read CSV (genes as columns, cells as rows)
|
||||
adata = ad.read_csv('data.csv')
|
||||
|
||||
# Read with custom delimiter
|
||||
adata = ad.read_csv('data.tsv', delimiter='\t')
|
||||
|
||||
# Specify that first column is row names
|
||||
adata = ad.read_csv('data.csv', first_column_names=True)
|
||||
```
|
||||
|
||||
### Excel
|
||||
```python
|
||||
# Read Excel file
|
||||
adata = ad.read_excel('data.xlsx')
|
||||
|
||||
# Read specific sheet
|
||||
adata = ad.read_excel('data.xlsx', sheet='Sheet1')
|
||||
```
|
||||
|
||||
### Matrix Market (MTX)
|
||||
Common format for sparse matrices in genomics.
|
||||
|
||||
```python
|
||||
# Read MTX with associated files
|
||||
# Requires: matrix.mtx, genes.tsv, barcodes.tsv
|
||||
adata = ad.read_mtx('matrix.mtx')
|
||||
|
||||
# Read with custom gene and barcode files
|
||||
adata = ad.read_mtx(
|
||||
'matrix.mtx',
|
||||
var_names='genes.tsv',
|
||||
obs_names='barcodes.tsv'
|
||||
)
|
||||
|
||||
# Transpose if needed (MTX often has genes as rows)
|
||||
adata = adata.T
|
||||
```
|
||||
|
||||
### 10X Genomics formats
|
||||
```python
|
||||
# Read 10X h5 format
|
||||
adata = ad.read_10x_h5('filtered_feature_bc_matrix.h5')
|
||||
|
||||
# Read 10X MTX directory
|
||||
adata = ad.read_10x_mtx('filtered_feature_bc_matrix/')
|
||||
|
||||
# Specify genome if multiple present
|
||||
adata = ad.read_10x_h5('data.h5', genome='GRCh38')
|
||||
```
|
||||
|
||||
### Loom
|
||||
```python
|
||||
# Read Loom file
|
||||
adata = ad.read_loom('data.loom')
|
||||
|
||||
# Read with specific observation and variable annotations
|
||||
adata = ad.read_loom(
|
||||
'data.loom',
|
||||
obs_names='CellID',
|
||||
var_names='Gene'
|
||||
)
|
||||
```
|
||||
|
||||
### Text files
|
||||
```python
|
||||
# Read generic text file
|
||||
adata = ad.read_text('data.txt', delimiter='\t')
|
||||
|
||||
# Read with custom parameters
|
||||
adata = ad.read_text(
|
||||
'data.txt',
|
||||
delimiter=',',
|
||||
first_column_names=True,
|
||||
dtype='float32'
|
||||
)
|
||||
```
|
||||
|
||||
### UMI tools
|
||||
```python
|
||||
# Read UMI tools format
|
||||
adata = ad.read_umi_tools('counts.tsv')
|
||||
```
|
||||
|
||||
### HDF5 (generic)
|
||||
```python
|
||||
# Read from HDF5 file (not h5ad format)
|
||||
adata = ad.read_hdf('data.h5', key='dataset')
|
||||
```
|
||||
|
||||
## Alternative Output Formats
|
||||
|
||||
### CSV
|
||||
```python
|
||||
# Write to CSV files (creates multiple files)
|
||||
adata.write_csvs('output_dir/')
|
||||
|
||||
# This creates:
|
||||
# - output_dir/X.csv (expression matrix)
|
||||
# - output_dir/obs.csv (observation annotations)
|
||||
# - output_dir/var.csv (variable annotations)
|
||||
# - output_dir/uns.csv (unstructured annotations, if possible)
|
||||
|
||||
# Skip certain components
|
||||
adata.write_csvs('output_dir/', skip_data=True) # Skip X matrix
|
||||
```
|
||||
|
||||
### Loom
|
||||
```python
|
||||
# Write to Loom format
|
||||
adata.write_loom('output.loom')
|
||||
```
|
||||
|
||||
## Reading Specific Elements
|
||||
|
||||
For fine-grained control, read specific elements from storage:
|
||||
|
||||
```python
|
||||
from anndata import read_elem
|
||||
|
||||
# Read just observation annotations
|
||||
obs = read_elem('data.h5ad/obs')
|
||||
|
||||
# Read specific layer
|
||||
layer = read_elem('data.h5ad/layers/normalized')
|
||||
|
||||
# Read unstructured data element
|
||||
params = read_elem('data.h5ad/uns/pca_params')
|
||||
```
|
||||
|
||||
## Writing Specific Elements
|
||||
|
||||
```python
|
||||
from anndata import write_elem
|
||||
import h5py
|
||||
|
||||
# Write element to existing file
|
||||
with h5py.File('data.h5ad', 'a') as f:
|
||||
write_elem(f, 'new_layer', adata.X.copy())
|
||||
```
|
||||
|
||||
## Lazy Operations
|
||||
|
||||
For very large datasets, use lazy reading to avoid loading entire datasets:
|
||||
|
||||
```python
|
||||
from anndata.experimental import read_elem_lazy
|
||||
|
||||
# Lazy read (returns dask array or similar)
|
||||
X_lazy = read_elem_lazy('large_data.h5ad/X')
|
||||
|
||||
# Compute only when needed
|
||||
subset = X_lazy[:100, :100].compute()
|
||||
```
|
||||
|
||||
## Common I/O Patterns
|
||||
|
||||
### Convert between formats
|
||||
```python
|
||||
# MTX to H5AD
|
||||
adata = ad.read_mtx('matrix.mtx').T
|
||||
adata.write_h5ad('data.h5ad')
|
||||
|
||||
# CSV to H5AD
|
||||
adata = ad.read_csv('data.csv')
|
||||
adata.write_h5ad('data.h5ad')
|
||||
|
||||
# H5AD to Zarr
|
||||
adata = ad.read_h5ad('data.h5ad')
|
||||
adata.write_zarr('data.zarr')
|
||||
```
|
||||
|
||||
### Load metadata without data
|
||||
```python
|
||||
# Backed mode allows inspecting metadata without loading X
|
||||
adata = ad.read_h5ad('large_file.h5ad', backed='r')
|
||||
print(f"Dataset contains {adata.n_obs} observations and {adata.n_vars} variables")
|
||||
print(adata.obs.columns)
|
||||
print(adata.var.columns)
|
||||
# X is not loaded into memory
|
||||
```
|
||||
|
||||
### Append to existing file
|
||||
```python
|
||||
# Open in read-write mode
|
||||
adata = ad.read_h5ad('data.h5ad', backed='r+')
|
||||
|
||||
# Modify metadata
|
||||
adata.obs['new_column'] = values
|
||||
|
||||
# Changes are written to disk
|
||||
```
|
||||
|
||||
### Download from URL
|
||||
```python
|
||||
import anndata as ad
|
||||
|
||||
# Read directly from URL (for h5ad files)
|
||||
url = 'https://example.com/data.h5ad'
|
||||
adata = ad.read_h5ad(url, backed='r') # Streaming access
|
||||
|
||||
# For other formats, download first
|
||||
import urllib.request
|
||||
urllib.request.urlretrieve(url, 'local_file.h5ad')
|
||||
adata = ad.read_h5ad('local_file.h5ad')
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
### Reading
|
||||
- Use `backed='r'` for large files you only need to query
|
||||
- Use `backed='r+'` if you need to modify metadata without loading all data
|
||||
- H5AD format is generally fastest for random access
|
||||
- Zarr is better for cloud storage and parallel access
|
||||
- Consider compression for storage, but note it may slow down reading
|
||||
|
||||
### Writing
|
||||
- Use compression for long-term storage: `compression='gzip'` or `compression='lzf'`
|
||||
- LZF compression is faster but compresses less than GZIP
|
||||
- For Zarr, tune chunk sizes based on access patterns:
|
||||
- Larger chunks for sequential reads
|
||||
- Smaller chunks for random access
|
||||
- Convert string columns to categorical before writing (smaller files)
|
||||
|
||||
### Memory management
|
||||
```python
|
||||
# Convert strings to categoricals (reduces file size and memory)
|
||||
adata.strings_to_categoricals()
|
||||
adata.write_h5ad('data.h5ad')
|
||||
|
||||
# Use sparse matrices for sparse data
|
||||
from scipy.sparse import csr_matrix
|
||||
if isinstance(adata.X, np.ndarray):
|
||||
density = np.count_nonzero(adata.X) / adata.X.size
|
||||
if density < 0.5: # If more than 50% zeros
|
||||
adata.X = csr_matrix(adata.X)
|
||||
```
|
||||
|
||||
## Handling Large Datasets
|
||||
|
||||
### Strategy 1: Backed mode
|
||||
```python
|
||||
# Work with dataset larger than RAM
|
||||
adata = ad.read_h5ad('100GB_file.h5ad', backed='r')
|
||||
|
||||
# Filter based on metadata (fast, no data loading)
|
||||
filtered = adata[adata.obs['quality_score'] > 0.8]
|
||||
|
||||
# Load filtered subset into memory
|
||||
adata_memory = filtered.to_memory()
|
||||
```
|
||||
|
||||
### Strategy 2: Chunked processing
|
||||
```python
|
||||
# Process data in chunks
|
||||
adata = ad.read_h5ad('large_file.h5ad', backed='r')
|
||||
|
||||
chunk_size = 1000
|
||||
results = []
|
||||
|
||||
for i in range(0, adata.n_obs, chunk_size):
|
||||
chunk = adata[i:i+chunk_size, :].to_memory()
|
||||
# Process chunk
|
||||
result = process(chunk)
|
||||
results.append(result)
|
||||
```
|
||||
|
||||
### Strategy 3: Use AnnCollection
|
||||
```python
|
||||
from anndata.experimental import AnnCollection
|
||||
|
||||
# Create collection without loading data
|
||||
adatas = [f'dataset_{i}.h5ad' for i in range(10)]
|
||||
collection = AnnCollection(
|
||||
adatas,
|
||||
join_obs='inner',
|
||||
join_vars='inner'
|
||||
)
|
||||
|
||||
# Process collection lazily
|
||||
# Data is loaded only when accessed
|
||||
```
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: Out of memory when reading
|
||||
**Solution**: Use backed mode or read in chunks
|
||||
```python
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r')
|
||||
```
|
||||
|
||||
### Issue: Slow reading from cloud storage
|
||||
**Solution**: Use Zarr format with appropriate chunking
|
||||
```python
|
||||
adata.write_zarr('data.zarr', chunks=(1000, 1000))
|
||||
```
|
||||
|
||||
### Issue: Large file sizes
|
||||
**Solution**: Use compression and convert to sparse/categorical
|
||||
```python
|
||||
adata.strings_to_categoricals()
|
||||
from scipy.sparse import csr_matrix
|
||||
adata.X = csr_matrix(adata.X)
|
||||
adata.write_h5ad('compressed.h5ad', compression='gzip')
|
||||
```
|
||||
|
||||
### Issue: Cannot modify backed object
|
||||
**Solution**: Either load to memory or open in 'r+' mode
|
||||
```python
|
||||
# Option 1: Load to memory
|
||||
adata = adata.to_memory()
|
||||
|
||||
# Option 2: Open in read-write mode
|
||||
adata = ad.read_h5ad('file.h5ad', backed='r+')
|
||||
```
|
||||
516
scientific-skills/anndata/references/manipulation.md
Normal file
516
scientific-skills/anndata/references/manipulation.md
Normal file
@@ -0,0 +1,516 @@
|
||||
# Data Manipulation
|
||||
|
||||
Operations for transforming, subsetting, and manipulating AnnData objects.
|
||||
|
||||
## Subsetting
|
||||
|
||||
### By indices
|
||||
```python
|
||||
import anndata as ad
|
||||
import numpy as np
|
||||
|
||||
adata = ad.AnnData(X=np.random.rand(1000, 2000))
|
||||
|
||||
# Integer indices
|
||||
subset = adata[0:100, 0:500] # First 100 obs, first 500 vars
|
||||
|
||||
# List of indices
|
||||
obs_indices = [0, 10, 20, 30, 40]
|
||||
var_indices = [0, 1, 2, 3, 4]
|
||||
subset = adata[obs_indices, var_indices]
|
||||
|
||||
# Single observation or variable
|
||||
single_obs = adata[0, :]
|
||||
single_var = adata[:, 0]
|
||||
```
|
||||
|
||||
### By names
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Create with named indices
|
||||
obs_names = [f'cell_{i}' for i in range(1000)]
|
||||
var_names = [f'gene_{i}' for i in range(2000)]
|
||||
adata = ad.AnnData(
|
||||
X=np.random.rand(1000, 2000),
|
||||
obs=pd.DataFrame(index=obs_names),
|
||||
var=pd.DataFrame(index=var_names)
|
||||
)
|
||||
|
||||
# Subset by observation names
|
||||
subset = adata[['cell_0', 'cell_1', 'cell_2'], :]
|
||||
|
||||
# Subset by variable names
|
||||
subset = adata[:, ['gene_0', 'gene_10', 'gene_20']]
|
||||
|
||||
# Both axes
|
||||
subset = adata[['cell_0', 'cell_1'], ['gene_0', 'gene_1']]
|
||||
```
|
||||
|
||||
### By boolean masks
|
||||
```python
|
||||
# Create boolean masks
|
||||
high_count_obs = np.random.rand(1000) > 0.5
|
||||
high_var_genes = np.random.rand(2000) > 0.7
|
||||
|
||||
# Subset using masks
|
||||
subset = adata[high_count_obs, :]
|
||||
subset = adata[:, high_var_genes]
|
||||
subset = adata[high_count_obs, high_var_genes]
|
||||
```
|
||||
|
||||
### By metadata conditions
|
||||
```python
|
||||
# Add metadata
|
||||
adata.obs['cell_type'] = np.random.choice(['A', 'B', 'C'], 1000)
|
||||
adata.obs['quality_score'] = np.random.rand(1000)
|
||||
adata.var['highly_variable'] = np.random.rand(2000) > 0.8
|
||||
|
||||
# Filter by cell type
|
||||
t_cells = adata[adata.obs['cell_type'] == 'A']
|
||||
|
||||
# Filter by multiple conditions
|
||||
high_quality_a_cells = adata[
|
||||
(adata.obs['cell_type'] == 'A') &
|
||||
(adata.obs['quality_score'] > 0.7)
|
||||
]
|
||||
|
||||
# Filter by variable metadata
|
||||
hv_genes = adata[:, adata.var['highly_variable']]
|
||||
|
||||
# Complex conditions
|
||||
filtered = adata[
|
||||
(adata.obs['quality_score'] > 0.5) &
|
||||
(adata.obs['cell_type'].isin(['A', 'B'])),
|
||||
adata.var['highly_variable']
|
||||
]
|
||||
```
|
||||
|
||||
## Transposition
|
||||
|
||||
```python
|
||||
# Transpose AnnData object (swap observations and variables)
|
||||
adata_T = adata.T
|
||||
|
||||
# Shape changes
|
||||
print(adata.shape) # (1000, 2000)
|
||||
print(adata_T.shape) # (2000, 1000)
|
||||
|
||||
# obs and var are swapped
|
||||
print(adata.obs.head()) # Observation metadata
|
||||
print(adata_T.var.head()) # Same data, now as variable metadata
|
||||
|
||||
# Useful when data is in opposite orientation
|
||||
# Common with some file formats where genes are rows
|
||||
```
|
||||
|
||||
## Copying
|
||||
|
||||
### Full copy
|
||||
```python
|
||||
# Create independent copy
|
||||
adata_copy = adata.copy()
|
||||
|
||||
# Modifications to copy don't affect original
|
||||
adata_copy.obs['new_column'] = 1
|
||||
print('new_column' in adata.obs.columns) # False
|
||||
```
|
||||
|
||||
### Shallow copy
|
||||
```python
|
||||
# View (doesn't copy data, modifications affect original)
|
||||
adata_view = adata[0:100, :]
|
||||
|
||||
# Check if object is a view
|
||||
print(adata_view.is_view) # True
|
||||
|
||||
# Convert view to independent copy
|
||||
adata_independent = adata_view.copy()
|
||||
print(adata_independent.is_view) # False
|
||||
```
|
||||
|
||||
## Renaming
|
||||
|
||||
### Rename observations and variables
|
||||
```python
|
||||
# Rename all observations
|
||||
adata.obs_names = [f'new_cell_{i}' for i in range(adata.n_obs)]
|
||||
|
||||
# Rename all variables
|
||||
adata.var_names = [f'new_gene_{i}' for i in range(adata.n_vars)]
|
||||
|
||||
# Make names unique (add suffix to duplicates)
|
||||
adata.obs_names_make_unique()
|
||||
adata.var_names_make_unique()
|
||||
```
|
||||
|
||||
### Rename categories
|
||||
```python
|
||||
# Create categorical column
|
||||
adata.obs['cell_type'] = pd.Categorical(['A', 'B', 'C'] * 333 + ['A'])
|
||||
|
||||
# Rename categories
|
||||
adata.rename_categories('cell_type', ['Type_A', 'Type_B', 'Type_C'])
|
||||
|
||||
# Or using dictionary
|
||||
adata.rename_categories('cell_type', {
|
||||
'Type_A': 'T_cell',
|
||||
'Type_B': 'B_cell',
|
||||
'Type_C': 'Monocyte'
|
||||
})
|
||||
```
|
||||
|
||||
## Type Conversions
|
||||
|
||||
### Strings to categoricals
|
||||
```python
|
||||
# Convert string columns to categorical (more memory efficient)
|
||||
adata.obs['cell_type'] = ['TypeA', 'TypeB'] * 500
|
||||
adata.obs['tissue'] = ['brain', 'liver'] * 500
|
||||
|
||||
# Convert all string columns to categorical
|
||||
adata.strings_to_categoricals()
|
||||
|
||||
print(adata.obs['cell_type'].dtype) # category
|
||||
print(adata.obs['tissue'].dtype) # category
|
||||
```
|
||||
|
||||
### Sparse to dense and vice versa
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Dense to sparse
|
||||
if not isinstance(adata.X, csr_matrix):
|
||||
adata.X = csr_matrix(adata.X)
|
||||
|
||||
# Sparse to dense
|
||||
if isinstance(adata.X, csr_matrix):
|
||||
adata.X = adata.X.toarray()
|
||||
|
||||
# Convert layer
|
||||
adata.layers['normalized'] = csr_matrix(adata.layers['normalized'])
|
||||
```
|
||||
|
||||
## Chunked Operations
|
||||
|
||||
Process large datasets in chunks:
|
||||
|
||||
```python
|
||||
# Iterate through data in chunks
|
||||
chunk_size = 100
|
||||
for chunk in adata.chunked_X(chunk_size):
|
||||
# Process chunk
|
||||
result = process_chunk(chunk)
|
||||
```
|
||||
|
||||
## Extracting Vectors
|
||||
|
||||
### Get observation vectors
|
||||
```python
|
||||
# Get observation metadata as array
|
||||
cell_types = adata.obs_vector('cell_type')
|
||||
|
||||
# Get gene expression across observations
|
||||
actb_expression = adata.obs_vector('ACTB') # If ACTB in var_names
|
||||
```
|
||||
|
||||
### Get variable vectors
|
||||
```python
|
||||
# Get variable metadata as array
|
||||
gene_names = adata.var_vector('gene_name')
|
||||
```
|
||||
|
||||
## Adding/Modifying Data
|
||||
|
||||
### Add observations
|
||||
```python
|
||||
# Create new observations
|
||||
new_obs = ad.AnnData(X=np.random.rand(100, adata.n_vars))
|
||||
new_obs.var_names = adata.var_names
|
||||
|
||||
# Concatenate with existing
|
||||
adata_extended = ad.concat([adata, new_obs], axis=0)
|
||||
```
|
||||
|
||||
### Add variables
|
||||
```python
|
||||
# Create new variables
|
||||
new_vars = ad.AnnData(X=np.random.rand(adata.n_obs, 100))
|
||||
new_vars.obs_names = adata.obs_names
|
||||
|
||||
# Concatenate with existing
|
||||
adata_extended = ad.concat([adata, new_vars], axis=1)
|
||||
```
|
||||
|
||||
### Add metadata columns
|
||||
```python
|
||||
# Add observation annotation
|
||||
adata.obs['new_score'] = np.random.rand(adata.n_obs)
|
||||
|
||||
# Add variable annotation
|
||||
adata.var['new_label'] = ['label'] * adata.n_vars
|
||||
|
||||
# Add from external data
|
||||
external_data = pd.read_csv('metadata.csv', index_col=0)
|
||||
adata.obs['external_info'] = external_data.loc[adata.obs_names, 'column']
|
||||
```
|
||||
|
||||
### Add layers
|
||||
```python
|
||||
# Add new layer
|
||||
adata.layers['raw_counts'] = np.random.randint(0, 100, adata.shape)
|
||||
adata.layers['log_transformed'] = np.log1p(adata.X)
|
||||
|
||||
# Replace layer
|
||||
adata.layers['normalized'] = new_normalized_data
|
||||
```
|
||||
|
||||
### Add embeddings
|
||||
```python
|
||||
# Add PCA
|
||||
adata.obsm['X_pca'] = np.random.rand(adata.n_obs, 50)
|
||||
|
||||
# Add UMAP
|
||||
adata.obsm['X_umap'] = np.random.rand(adata.n_obs, 2)
|
||||
|
||||
# Add multiple embeddings
|
||||
adata.obsm['X_tsne'] = np.random.rand(adata.n_obs, 2)
|
||||
adata.obsm['X_diffmap'] = np.random.rand(adata.n_obs, 10)
|
||||
```
|
||||
|
||||
### Add pairwise relationships
|
||||
```python
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
# Add nearest neighbor graph
|
||||
n_obs = adata.n_obs
|
||||
knn_graph = csr_matrix(np.random.rand(n_obs, n_obs) > 0.95)
|
||||
adata.obsp['connectivities'] = knn_graph
|
||||
|
||||
# Add distance matrix
|
||||
adata.obsp['distances'] = csr_matrix(np.random.rand(n_obs, n_obs))
|
||||
```
|
||||
|
||||
### Add unstructured data
|
||||
```python
|
||||
# Add analysis parameters
|
||||
adata.uns['pca'] = {
|
||||
'variance': [0.2, 0.15, 0.1],
|
||||
'variance_ratio': [0.4, 0.3, 0.2],
|
||||
'params': {'n_comps': 50}
|
||||
}
|
||||
|
||||
# Add color schemes
|
||||
adata.uns['cell_type_colors'] = ['#FF0000', '#00FF00', '#0000FF']
|
||||
```
|
||||
|
||||
## Removing Data
|
||||
|
||||
### Remove observations or variables
|
||||
```python
|
||||
# Keep only specific observations
|
||||
keep_obs = adata.obs['quality_score'] > 0.5
|
||||
adata = adata[keep_obs, :]
|
||||
|
||||
# Remove specific variables
|
||||
remove_vars = adata.var['low_count']
|
||||
adata = adata[:, ~remove_vars]
|
||||
```
|
||||
|
||||
### Remove metadata columns
|
||||
```python
|
||||
# Remove observation column
|
||||
adata.obs.drop('unwanted_column', axis=1, inplace=True)
|
||||
|
||||
# Remove variable column
|
||||
adata.var.drop('unwanted_column', axis=1, inplace=True)
|
||||
```
|
||||
|
||||
### Remove layers
|
||||
```python
|
||||
# Remove specific layer
|
||||
del adata.layers['unwanted_layer']
|
||||
|
||||
# Remove all layers
|
||||
adata.layers = {}
|
||||
```
|
||||
|
||||
### Remove embeddings
|
||||
```python
|
||||
# Remove specific embedding
|
||||
del adata.obsm['X_tsne']
|
||||
|
||||
# Remove all embeddings
|
||||
adata.obsm = {}
|
||||
```
|
||||
|
||||
### Remove unstructured data
|
||||
```python
|
||||
# Remove specific key
|
||||
del adata.uns['unwanted_key']
|
||||
|
||||
# Remove all unstructured data
|
||||
adata.uns = {}
|
||||
```
|
||||
|
||||
## Reordering
|
||||
|
||||
### Sort observations
|
||||
```python
|
||||
# Sort by observation metadata
|
||||
adata = adata[adata.obs.sort_values('quality_score').index, :]
|
||||
|
||||
# Sort by observation names
|
||||
adata = adata[sorted(adata.obs_names), :]
|
||||
```
|
||||
|
||||
### Sort variables
|
||||
```python
|
||||
# Sort by variable metadata
|
||||
adata = adata[:, adata.var.sort_values('gene_name').index]
|
||||
|
||||
# Sort by variable names
|
||||
adata = adata[:, sorted(adata.var_names)]
|
||||
```
|
||||
|
||||
### Reorder to match external list
|
||||
```python
|
||||
# Reorder observations to match external list
|
||||
desired_order = ['cell_10', 'cell_5', 'cell_20', ...]
|
||||
adata = adata[desired_order, :]
|
||||
|
||||
# Reorder variables
|
||||
desired_genes = ['TP53', 'ACTB', 'GAPDH', ...]
|
||||
adata = adata[:, desired_genes]
|
||||
```
|
||||
|
||||
## Data Transformations
|
||||
|
||||
### Normalize
|
||||
```python
|
||||
# Total count normalization (CPM/TPM-like)
|
||||
total_counts = adata.X.sum(axis=1)
|
||||
adata.layers['normalized'] = adata.X / total_counts[:, np.newaxis] * 1e6
|
||||
|
||||
# Log transformation
|
||||
adata.layers['log1p'] = np.log1p(adata.X)
|
||||
|
||||
# Z-score normalization
|
||||
mean = adata.X.mean(axis=0)
|
||||
std = adata.X.std(axis=0)
|
||||
adata.layers['scaled'] = (adata.X - mean) / std
|
||||
```
|
||||
|
||||
### Filter
|
||||
```python
|
||||
# Filter cells by total counts
|
||||
total_counts = np.array(adata.X.sum(axis=1)).flatten()
|
||||
adata.obs['total_counts'] = total_counts
|
||||
adata = adata[adata.obs['total_counts'] > 1000, :]
|
||||
|
||||
# Filter genes by detection rate
|
||||
detection_rate = (adata.X > 0).sum(axis=0) / adata.n_obs
|
||||
adata.var['detection_rate'] = np.array(detection_rate).flatten()
|
||||
adata = adata[:, adata.var['detection_rate'] > 0.01]
|
||||
```
|
||||
|
||||
## Working with Views
|
||||
|
||||
Views are lightweight references to subsets of data that don't copy the underlying matrix:
|
||||
|
||||
```python
|
||||
# Create view
|
||||
view = adata[0:100, 0:500]
|
||||
print(view.is_view) # True
|
||||
|
||||
# Views allow read access
|
||||
data = view.X
|
||||
|
||||
# Modifying view data affects original
|
||||
# (Be careful!)
|
||||
|
||||
# Convert view to independent copy
|
||||
independent = view.copy()
|
||||
|
||||
# Force AnnData to be a copy, not a view
|
||||
adata = adata.copy()
|
||||
```
|
||||
|
||||
## Merging Metadata
|
||||
|
||||
```python
|
||||
# Merge external metadata
|
||||
external_metadata = pd.read_csv('additional_metadata.csv', index_col=0)
|
||||
|
||||
# Join metadata (inner join on index)
|
||||
adata.obs = adata.obs.join(external_metadata)
|
||||
|
||||
# Left join (keep all adata observations)
|
||||
adata.obs = adata.obs.merge(
|
||||
external_metadata,
|
||||
left_index=True,
|
||||
right_index=True,
|
||||
how='left'
|
||||
)
|
||||
```
|
||||
|
||||
## Common Manipulation Patterns
|
||||
|
||||
### Quality control filtering
|
||||
```python
|
||||
# Calculate QC metrics
|
||||
adata.obs['n_genes'] = (adata.X > 0).sum(axis=1)
|
||||
adata.obs['total_counts'] = adata.X.sum(axis=1)
|
||||
adata.var['n_cells'] = (adata.X > 0).sum(axis=0)
|
||||
|
||||
# Filter low-quality cells
|
||||
adata = adata[adata.obs['n_genes'] > 200, :]
|
||||
adata = adata[adata.obs['total_counts'] < 50000, :]
|
||||
|
||||
# Filter rarely detected genes
|
||||
adata = adata[:, adata.var['n_cells'] >= 3]
|
||||
```
|
||||
|
||||
### Select highly variable genes
|
||||
```python
|
||||
# Mark highly variable genes
|
||||
gene_variance = np.var(adata.X, axis=0)
|
||||
adata.var['variance'] = np.array(gene_variance).flatten()
|
||||
adata.var['highly_variable'] = adata.var['variance'] > np.percentile(gene_variance, 90)
|
||||
|
||||
# Subset to highly variable genes
|
||||
adata_hvg = adata[:, adata.var['highly_variable']].copy()
|
||||
```
|
||||
|
||||
### Downsample
|
||||
```python
|
||||
# Random sampling of observations
|
||||
np.random.seed(42)
|
||||
n_sample = 500
|
||||
sample_indices = np.random.choice(adata.n_obs, n_sample, replace=False)
|
||||
adata_downsampled = adata[sample_indices, :].copy()
|
||||
|
||||
# Stratified sampling by cell type
|
||||
from sklearn.model_selection import train_test_split
|
||||
train_idx, test_idx = train_test_split(
|
||||
range(adata.n_obs),
|
||||
test_size=0.2,
|
||||
stratify=adata.obs['cell_type']
|
||||
)
|
||||
adata_train = adata[train_idx, :].copy()
|
||||
adata_test = adata[test_idx, :].copy()
|
||||
```
|
||||
|
||||
### Split train/test
|
||||
```python
|
||||
# Random train/test split
|
||||
np.random.seed(42)
|
||||
n_obs = adata.n_obs
|
||||
train_size = int(0.8 * n_obs)
|
||||
indices = np.random.permutation(n_obs)
|
||||
train_indices = indices[:train_size]
|
||||
test_indices = indices[train_size:]
|
||||
|
||||
adata_train = adata[train_indices, :].copy()
|
||||
adata_test = adata[test_indices, :].copy()
|
||||
```
|
||||
240
scientific-skills/arboreto/SKILL.md
Normal file
240
scientific-skills/arboreto/SKILL.md
Normal file
@@ -0,0 +1,240 @@
|
||||
---
|
||||
name: arboreto
|
||||
description: Infer gene regulatory networks (GRNs) from gene expression data using scalable algorithms (GRNBoost2, GENIE3). Use when analyzing transcriptomics data (bulk RNA-seq, single-cell RNA-seq) to identify transcription factor-target gene relationships and regulatory interactions. Supports distributed computation for large-scale datasets.
|
||||
license: BSD-3-Clause license
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# Arboreto
|
||||
|
||||
## Overview
|
||||
|
||||
Arboreto is a computational library for inferring gene regulatory networks (GRNs) from gene expression data using parallelized algorithms that scale from single machines to multi-node clusters.
|
||||
|
||||
**Core capability**: Identify which transcription factors (TFs) regulate which target genes based on expression patterns across observations (cells, samples, conditions).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Install arboreto:
|
||||
```bash
|
||||
uv pip install arboreto
|
||||
```
|
||||
|
||||
Basic GRN inference:
|
||||
```python
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load expression data (genes as columns)
|
||||
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
|
||||
|
||||
# Infer regulatory network
|
||||
network = grnboost2(expression_data=expression_matrix)
|
||||
|
||||
# Save results (TF, target, importance)
|
||||
network.to_csv('network.tsv', sep='\t', index=False, header=False)
|
||||
```
|
||||
|
||||
**Critical**: Always use `if __name__ == '__main__':` guard because Dask spawns new processes.
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Basic GRN Inference
|
||||
|
||||
For standard GRN inference workflows including:
|
||||
- Input data preparation (Pandas DataFrame or NumPy array)
|
||||
- Running inference with GRNBoost2 or GENIE3
|
||||
- Filtering by transcription factors
|
||||
- Output format and interpretation
|
||||
|
||||
**See**: `references/basic_inference.md`
|
||||
|
||||
**Use the ready-to-run script**: `scripts/basic_grn_inference.py` for standard inference tasks:
|
||||
```bash
|
||||
python scripts/basic_grn_inference.py expression_data.tsv output_network.tsv --tf-file tfs.txt --seed 777
|
||||
```
|
||||
|
||||
### 2. Algorithm Selection
|
||||
|
||||
Arboreto provides two algorithms:
|
||||
|
||||
**GRNBoost2 (Recommended)**:
|
||||
- Fast gradient boosting-based inference
|
||||
- Optimized for large datasets (10k+ observations)
|
||||
- Default choice for most analyses
|
||||
|
||||
**GENIE3**:
|
||||
- Random Forest-based inference
|
||||
- Original multiple regression approach
|
||||
- Use for comparison or validation
|
||||
|
||||
Quick comparison:
|
||||
```python
|
||||
from arboreto.algo import grnboost2, genie3
|
||||
|
||||
# Fast, recommended
|
||||
network_grnboost = grnboost2(expression_data=matrix)
|
||||
|
||||
# Classic algorithm
|
||||
network_genie3 = genie3(expression_data=matrix)
|
||||
```
|
||||
|
||||
**For detailed algorithm comparison, parameters, and selection guidance**: `references/algorithms.md`
|
||||
|
||||
### 3. Distributed Computing
|
||||
|
||||
Scale inference from local multi-core to cluster environments:
|
||||
|
||||
**Local (default)** - Uses all available cores automatically:
|
||||
```python
|
||||
network = grnboost2(expression_data=matrix)
|
||||
```
|
||||
|
||||
**Custom local client** - Control resources:
|
||||
```python
|
||||
from distributed import LocalCluster, Client
|
||||
|
||||
local_cluster = LocalCluster(n_workers=10, memory_limit='8GB')
|
||||
client = Client(local_cluster)
|
||||
|
||||
network = grnboost2(expression_data=matrix, client_or_address=client)
|
||||
|
||||
client.close()
|
||||
local_cluster.close()
|
||||
```
|
||||
|
||||
**Cluster computing** - Connect to remote Dask scheduler:
|
||||
```python
|
||||
from distributed import Client
|
||||
|
||||
client = Client('tcp://scheduler:8786')
|
||||
network = grnboost2(expression_data=matrix, client_or_address=client)
|
||||
```
|
||||
|
||||
**For cluster setup, performance optimization, and large-scale workflows**: `references/distributed_computing.md`
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
uv pip install arboreto
|
||||
```
|
||||
|
||||
**Dependencies**: scipy, scikit-learn, numpy, pandas, dask, distributed
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Single-Cell RNA-seq Analysis
|
||||
```python
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load single-cell expression matrix (cells x genes)
|
||||
sc_data = pd.read_csv('scrna_counts.tsv', sep='\t')
|
||||
|
||||
# Infer cell-type-specific regulatory network
|
||||
network = grnboost2(expression_data=sc_data, seed=42)
|
||||
|
||||
# Filter high-confidence links
|
||||
high_confidence = network[network['importance'] > 0.5]
|
||||
high_confidence.to_csv('grn_high_confidence.tsv', sep='\t', index=False)
|
||||
```
|
||||
|
||||
### Bulk RNA-seq with TF Filtering
|
||||
```python
|
||||
from arboreto.utils import load_tf_names
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load data
|
||||
expression_data = pd.read_csv('rnaseq_tpm.tsv', sep='\t')
|
||||
tf_names = load_tf_names('human_tfs.txt')
|
||||
|
||||
# Infer with TF restriction
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=123
|
||||
)
|
||||
|
||||
network.to_csv('tf_target_network.tsv', sep='\t', index=False)
|
||||
```
|
||||
|
||||
### Comparative Analysis (Multiple Conditions)
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Infer networks for different conditions
|
||||
conditions = ['control', 'treatment_24h', 'treatment_48h']
|
||||
|
||||
for condition in conditions:
|
||||
data = pd.read_csv(f'{condition}_expression.tsv', sep='\t')
|
||||
network = grnboost2(expression_data=data, seed=42)
|
||||
network.to_csv(f'{condition}_network.tsv', sep='\t', index=False)
|
||||
```
|
||||
|
||||
## Output Interpretation
|
||||
|
||||
Arboreto returns a DataFrame with regulatory links:
|
||||
|
||||
| Column | Description |
|
||||
|--------|-------------|
|
||||
| `TF` | Transcription factor (regulator) |
|
||||
| `target` | Target gene |
|
||||
| `importance` | Regulatory importance score (higher = stronger) |
|
||||
|
||||
**Filtering strategy**:
|
||||
- Top N links per target gene
|
||||
- Importance threshold (e.g., > 0.5)
|
||||
- Statistical significance testing (permutation tests)
|
||||
|
||||
## Integration with pySCENIC
|
||||
|
||||
Arboreto is a core component of the SCENIC pipeline for single-cell regulatory network analysis:
|
||||
|
||||
```python
|
||||
# Step 1: Use arboreto for GRN inference
|
||||
from arboreto.algo import grnboost2
|
||||
network = grnboost2(expression_data=sc_data, tf_names=tf_list)
|
||||
|
||||
# Step 2: Use pySCENIC for regulon identification and activity scoring
|
||||
# (See pySCENIC documentation for downstream analysis)
|
||||
```
|
||||
|
||||
## Reproducibility
|
||||
|
||||
Always set a seed for reproducible results:
|
||||
```python
|
||||
network = grnboost2(expression_data=matrix, seed=777)
|
||||
```
|
||||
|
||||
Run multiple seeds for robustness analysis:
|
||||
```python
|
||||
from distributed import LocalCluster, Client
|
||||
|
||||
if __name__ == '__main__':
|
||||
client = Client(LocalCluster())
|
||||
|
||||
seeds = [42, 123, 777]
|
||||
networks = []
|
||||
|
||||
for seed in seeds:
|
||||
net = grnboost2(expression_data=matrix, client_or_address=client, seed=seed)
|
||||
networks.append(net)
|
||||
|
||||
# Combine networks and filter consensus links
|
||||
consensus = analyze_consensus(networks)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Memory errors**: Reduce dataset size by filtering low-variance genes or use distributed computing
|
||||
|
||||
**Slow performance**: Use GRNBoost2 instead of GENIE3, enable distributed client, filter TF list
|
||||
|
||||
**Dask errors**: Ensure `if __name__ == '__main__':` guard is present in scripts
|
||||
|
||||
**Empty results**: Check data format (genes as columns), verify TF names match gene names
|
||||
138
scientific-skills/arboreto/references/algorithms.md
Normal file
138
scientific-skills/arboreto/references/algorithms.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# GRN Inference Algorithms
|
||||
|
||||
Arboreto provides two algorithms for gene regulatory network (GRN) inference, both based on the multiple regression approach.
|
||||
|
||||
## Algorithm Overview
|
||||
|
||||
Both algorithms follow the same inference strategy:
|
||||
1. For each target gene in the dataset, train a regression model
|
||||
2. Identify the most important features (potential regulators) from the model
|
||||
3. Emit these features as candidate regulators with importance scores
|
||||
|
||||
The key difference is **computational efficiency** and the underlying regression method.
|
||||
|
||||
## GRNBoost2 (Recommended)
|
||||
|
||||
**Purpose**: Fast GRN inference for large-scale datasets using gradient boosting.
|
||||
|
||||
### When to Use
|
||||
- **Large datasets**: Tens of thousands of observations (e.g., single-cell RNA-seq)
|
||||
- **Time-constrained analysis**: Need faster results than GENIE3
|
||||
- **Default choice**: GRNBoost2 is the flagship algorithm and recommended for most use cases
|
||||
|
||||
### Technical Details
|
||||
- **Method**: Stochastic gradient boosting with early-stopping regularization
|
||||
- **Performance**: Significantly faster than GENIE3 on large datasets
|
||||
- **Output**: Same format as GENIE3 (TF-target-importance triplets)
|
||||
|
||||
### Usage
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
seed=42 # For reproducibility
|
||||
)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
```python
|
||||
grnboost2(
|
||||
expression_data, # Required: pandas DataFrame or numpy array
|
||||
gene_names=None, # Required for numpy arrays
|
||||
tf_names='all', # List of TF names or 'all'
|
||||
verbose=False, # Print progress messages
|
||||
client_or_address='local', # Dask client or scheduler address
|
||||
seed=None # Random seed for reproducibility
|
||||
)
|
||||
```
|
||||
|
||||
## GENIE3
|
||||
|
||||
**Purpose**: Classic Random Forest-based GRN inference, serving as the conceptual blueprint.
|
||||
|
||||
### When to Use
|
||||
- **Smaller datasets**: When dataset size allows for longer computation
|
||||
- **Comparison studies**: When comparing with published GENIE3 results
|
||||
- **Validation**: To validate GRNBoost2 results
|
||||
|
||||
### Technical Details
|
||||
- **Method**: Random Forest or ExtraTrees regression
|
||||
- **Foundation**: Original multiple regression GRN inference strategy
|
||||
- **Trade-off**: More computationally expensive but well-established
|
||||
|
||||
### Usage
|
||||
```python
|
||||
from arboreto.algo import genie3
|
||||
|
||||
network = genie3(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
seed=42
|
||||
)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
```python
|
||||
genie3(
|
||||
expression_data, # Required: pandas DataFrame or numpy array
|
||||
gene_names=None, # Required for numpy arrays
|
||||
tf_names='all', # List of TF names or 'all'
|
||||
verbose=False, # Print progress messages
|
||||
client_or_address='local', # Dask client or scheduler address
|
||||
seed=None # Random seed for reproducibility
|
||||
)
|
||||
```
|
||||
|
||||
## Algorithm Comparison
|
||||
|
||||
| Feature | GRNBoost2 | GENIE3 |
|
||||
|---------|-----------|--------|
|
||||
| **Speed** | Fast (optimized for large data) | Slower |
|
||||
| **Method** | Gradient boosting | Random Forest |
|
||||
| **Best for** | Large-scale data (10k+ observations) | Small-medium datasets |
|
||||
| **Output format** | Same | Same |
|
||||
| **Inference strategy** | Multiple regression | Multiple regression |
|
||||
| **Recommended** | Yes (default choice) | For comparison/validation |
|
||||
|
||||
## Advanced: Custom Regressor Parameters
|
||||
|
||||
For advanced users, pass custom scikit-learn regressor parameters:
|
||||
|
||||
```python
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
|
||||
# Custom GRNBoost2 parameters
|
||||
custom_grnboost2 = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
regressor_type='GBM',
|
||||
regressor_kwargs={
|
||||
'n_estimators': 100,
|
||||
'max_depth': 5,
|
||||
'learning_rate': 0.1
|
||||
}
|
||||
)
|
||||
|
||||
# Custom GENIE3 parameters
|
||||
custom_genie3 = genie3(
|
||||
expression_data=expression_matrix,
|
||||
regressor_type='RF',
|
||||
regressor_kwargs={
|
||||
'n_estimators': 1000,
|
||||
'max_features': 'sqrt'
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Choosing the Right Algorithm
|
||||
|
||||
**Decision guide**:
|
||||
|
||||
1. **Start with GRNBoost2** - It's faster and handles large datasets better
|
||||
2. **Use GENIE3 if**:
|
||||
- Comparing with existing GENIE3 publications
|
||||
- Dataset is small-medium sized
|
||||
- Validating GRNBoost2 results
|
||||
|
||||
Both algorithms produce comparable regulatory networks with the same output format, making them interchangeable for most analyses.
|
||||
151
scientific-skills/arboreto/references/basic_inference.md
Normal file
151
scientific-skills/arboreto/references/basic_inference.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# Basic GRN Inference with Arboreto
|
||||
|
||||
## Input Data Requirements
|
||||
|
||||
Arboreto requires gene expression data in one of two formats:
|
||||
|
||||
### Pandas DataFrame (Recommended)
|
||||
- **Rows**: Observations (cells, samples, conditions)
|
||||
- **Columns**: Genes (with gene names as column headers)
|
||||
- **Format**: Numeric expression values
|
||||
|
||||
Example:
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
# Load expression matrix with genes as columns
|
||||
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
|
||||
# Columns: ['gene1', 'gene2', 'gene3', ...]
|
||||
# Rows: observation data
|
||||
```
|
||||
|
||||
### NumPy Array
|
||||
- **Shape**: (observations, genes)
|
||||
- **Requirement**: Separately provide gene names list matching column order
|
||||
|
||||
Example:
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
expression_matrix = np.genfromtxt('expression_data.tsv', delimiter='\t', skip_header=1)
|
||||
with open('expression_data.tsv') as f:
|
||||
gene_names = [gene.strip() for gene in f.readline().split('\t')]
|
||||
|
||||
assert expression_matrix.shape[1] == len(gene_names)
|
||||
```
|
||||
|
||||
## Transcription Factors (TFs)
|
||||
|
||||
Optionally provide a list of transcription factor names to restrict regulatory inference:
|
||||
|
||||
```python
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
# Load from file (one TF per line)
|
||||
tf_names = load_tf_names('transcription_factors.txt')
|
||||
|
||||
# Or define directly
|
||||
tf_names = ['TF1', 'TF2', 'TF3']
|
||||
```
|
||||
|
||||
If not provided, all genes are considered potential regulators.
|
||||
|
||||
## Basic Inference Workflow
|
||||
|
||||
### Using Pandas DataFrame
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
from arboreto.utils import load_tf_names
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load expression data
|
||||
expression_matrix = pd.read_csv('expression_data.tsv', sep='\t')
|
||||
|
||||
# Load transcription factors (optional)
|
||||
tf_names = load_tf_names('tf_list.txt')
|
||||
|
||||
# Run GRN inference
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names # Optional
|
||||
)
|
||||
|
||||
# Save results
|
||||
network.to_csv('network_output.tsv', sep='\t', index=False, header=False)
|
||||
```
|
||||
|
||||
**Critical**: The `if __name__ == '__main__':` guard is required because Dask spawns new processes internally.
|
||||
|
||||
### Using NumPy Array
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load expression matrix
|
||||
expression_matrix = np.genfromtxt('expression_data.tsv', delimiter='\t', skip_header=1)
|
||||
|
||||
# Extract gene names from header
|
||||
with open('expression_data.tsv') as f:
|
||||
gene_names = [gene.strip() for gene in f.readline().split('\t')]
|
||||
|
||||
# Verify dimensions match
|
||||
assert expression_matrix.shape[1] == len(gene_names)
|
||||
|
||||
# Run inference with explicit gene names
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
gene_names=gene_names,
|
||||
tf_names=tf_names
|
||||
)
|
||||
|
||||
network.to_csv('network_output.tsv', sep='\t', index=False, header=False)
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
Arboreto returns a Pandas DataFrame with three columns:
|
||||
|
||||
| Column | Description |
|
||||
|--------|-------------|
|
||||
| `TF` | Transcription factor (regulator) gene name |
|
||||
| `target` | Target gene name |
|
||||
| `importance` | Regulatory importance score (higher = stronger regulation) |
|
||||
|
||||
Example output:
|
||||
```
|
||||
TF1 gene5 0.856
|
||||
TF2 gene12 0.743
|
||||
TF1 gene8 0.621
|
||||
```
|
||||
|
||||
## Setting Random Seed
|
||||
|
||||
For reproducible results, provide a seed parameter:
|
||||
|
||||
```python
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
seed=777
|
||||
)
|
||||
```
|
||||
|
||||
## Algorithm Selection
|
||||
|
||||
Use `grnboost2()` for most cases (faster, handles large datasets):
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
network = grnboost2(expression_data=expression_matrix)
|
||||
```
|
||||
|
||||
Use `genie3()` for comparison or specific requirements:
|
||||
```python
|
||||
from arboreto.algo import genie3
|
||||
network = genie3(expression_data=expression_matrix)
|
||||
```
|
||||
|
||||
See `references/algorithms.md` for detailed algorithm comparison.
|
||||
242
scientific-skills/arboreto/references/distributed_computing.md
Normal file
242
scientific-skills/arboreto/references/distributed_computing.md
Normal file
@@ -0,0 +1,242 @@
|
||||
# Distributed Computing with Arboreto
|
||||
|
||||
Arboreto leverages Dask for parallelized computation, enabling efficient GRN inference from single-machine multi-core processing to multi-node cluster environments.
|
||||
|
||||
## Computation Architecture
|
||||
|
||||
GRN inference is inherently parallelizable:
|
||||
- Each target gene's regression model can be trained independently
|
||||
- Arboreto represents computation as a Dask task graph
|
||||
- Tasks are distributed across available computational resources
|
||||
|
||||
## Local Multi-Core Processing (Default)
|
||||
|
||||
By default, arboreto uses all available CPU cores on the local machine:
|
||||
|
||||
```python
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
# Automatically uses all local cores
|
||||
network = grnboost2(expression_data=expression_matrix, tf_names=tf_names)
|
||||
```
|
||||
|
||||
This is sufficient for most use cases and requires no additional configuration.
|
||||
|
||||
## Custom Local Dask Client
|
||||
|
||||
For fine-grained control over local resources, create a custom Dask client:
|
||||
|
||||
```python
|
||||
from distributed import LocalCluster, Client
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Configure local cluster
|
||||
local_cluster = LocalCluster(
|
||||
n_workers=10, # Number of worker processes
|
||||
threads_per_worker=1, # Threads per worker
|
||||
memory_limit='8GB' # Memory limit per worker
|
||||
)
|
||||
|
||||
# Create client
|
||||
custom_client = Client(local_cluster)
|
||||
|
||||
# Run inference with custom client
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
client_or_address=custom_client
|
||||
)
|
||||
|
||||
# Clean up
|
||||
custom_client.close()
|
||||
local_cluster.close()
|
||||
```
|
||||
|
||||
### Benefits of Custom Client
|
||||
- **Resource control**: Limit CPU and memory usage
|
||||
- **Multiple runs**: Reuse same client for different parameter sets
|
||||
- **Monitoring**: Access Dask dashboard for performance insights
|
||||
|
||||
## Multiple Inference Runs with Same Client
|
||||
|
||||
Reuse a single Dask client for multiple inference runs with different parameters:
|
||||
|
||||
```python
|
||||
from distributed import LocalCluster, Client
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Initialize client once
|
||||
local_cluster = LocalCluster(n_workers=8, threads_per_worker=1)
|
||||
client = Client(local_cluster)
|
||||
|
||||
# Run multiple inferences
|
||||
network_seed1 = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
client_or_address=client,
|
||||
seed=666
|
||||
)
|
||||
|
||||
network_seed2 = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
client_or_address=client,
|
||||
seed=777
|
||||
)
|
||||
|
||||
# Different algorithms with same client
|
||||
from arboreto.algo import genie3
|
||||
network_genie3 = genie3(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
client_or_address=client
|
||||
)
|
||||
|
||||
# Clean up once
|
||||
client.close()
|
||||
local_cluster.close()
|
||||
```
|
||||
|
||||
## Distributed Cluster Computing
|
||||
|
||||
For very large datasets, connect to a remote Dask distributed scheduler running on a cluster:
|
||||
|
||||
### Step 1: Set Up Dask Scheduler (on cluster head node)
|
||||
```bash
|
||||
dask-scheduler
|
||||
# Output: Scheduler at tcp://10.118.224.134:8786
|
||||
```
|
||||
|
||||
### Step 2: Start Dask Workers (on cluster compute nodes)
|
||||
```bash
|
||||
dask-worker tcp://10.118.224.134:8786
|
||||
```
|
||||
|
||||
### Step 3: Connect from Client
|
||||
```python
|
||||
from distributed import Client
|
||||
from arboreto.algo import grnboost2
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Connect to remote scheduler
|
||||
scheduler_address = 'tcp://10.118.224.134:8786'
|
||||
cluster_client = Client(scheduler_address)
|
||||
|
||||
# Run inference on cluster
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
client_or_address=cluster_client
|
||||
)
|
||||
|
||||
cluster_client.close()
|
||||
```
|
||||
|
||||
### Cluster Configuration Best Practices
|
||||
|
||||
**Worker configuration**:
|
||||
```bash
|
||||
dask-worker tcp://scheduler:8786 \
|
||||
--nprocs 4 \ # Number of processes per node
|
||||
--nthreads 1 \ # Threads per process
|
||||
--memory-limit 16GB # Memory per process
|
||||
```
|
||||
|
||||
**For large-scale inference**:
|
||||
- Use more workers with moderate memory rather than fewer workers with large memory
|
||||
- Set `threads_per_worker=1` to avoid GIL contention in scikit-learn
|
||||
- Monitor memory usage to prevent workers from being killed
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
### Dask Dashboard
|
||||
|
||||
Access the Dask dashboard for real-time monitoring:
|
||||
|
||||
```python
|
||||
from distributed import Client
|
||||
|
||||
client = Client() # Prints dashboard URL
|
||||
# Dashboard available at: http://localhost:8787/status
|
||||
```
|
||||
|
||||
The dashboard shows:
|
||||
- **Task progress**: Number of tasks completed/pending
|
||||
- **Resource usage**: CPU, memory per worker
|
||||
- **Task stream**: Real-time visualization of computation
|
||||
- **Performance**: Bottleneck identification
|
||||
|
||||
### Verbose Output
|
||||
|
||||
Enable verbose logging to track inference progress:
|
||||
|
||||
```python
|
||||
network = grnboost2(
|
||||
expression_data=expression_matrix,
|
||||
tf_names=tf_names,
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Optimization Tips
|
||||
|
||||
### 1. Data Format
|
||||
- **Use Pandas DataFrame when possible**: More efficient than NumPy for Dask operations
|
||||
- **Reduce data size**: Filter low-variance genes before inference
|
||||
|
||||
### 2. Worker Configuration
|
||||
- **CPU-bound tasks**: Set `threads_per_worker=1`, increase `n_workers`
|
||||
- **Memory-bound tasks**: Increase `memory_limit` per worker
|
||||
|
||||
### 3. Cluster Setup
|
||||
- **Network**: Ensure high-bandwidth, low-latency network between nodes
|
||||
- **Storage**: Use shared filesystem or object storage for large datasets
|
||||
- **Scheduling**: Allocate dedicated nodes to avoid resource contention
|
||||
|
||||
### 4. Transcription Factor Filtering
|
||||
- **Limit TF list**: Providing specific TF names reduces computation
|
||||
```python
|
||||
# Full search (slow)
|
||||
network = grnboost2(expression_data=matrix)
|
||||
|
||||
# Filtered search (faster)
|
||||
network = grnboost2(expression_data=matrix, tf_names=known_tfs)
|
||||
```
|
||||
|
||||
## Example: Large-Scale Single-Cell Analysis
|
||||
|
||||
Complete workflow for processing single-cell RNA-seq data on a cluster:
|
||||
|
||||
```python
|
||||
from distributed import Client
|
||||
from arboreto.algo import grnboost2
|
||||
import pandas as pd
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Connect to cluster
|
||||
client = Client('tcp://cluster-scheduler:8786')
|
||||
|
||||
# Load large single-cell dataset (50,000 cells x 20,000 genes)
|
||||
expression_data = pd.read_csv('scrnaseq_data.tsv', sep='\t')
|
||||
|
||||
# Load cell-type-specific TFs
|
||||
tf_names = pd.read_csv('tf_list.txt', header=None)[0].tolist()
|
||||
|
||||
# Run distributed inference
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
client_or_address=client,
|
||||
verbose=True,
|
||||
seed=42
|
||||
)
|
||||
|
||||
# Save results
|
||||
network.to_csv('grn_results.tsv', sep='\t', index=False)
|
||||
|
||||
client.close()
|
||||
```
|
||||
|
||||
This approach enables analysis of datasets that would be impractical on a single machine.
|
||||
97
scientific-skills/arboreto/scripts/basic_grn_inference.py
Normal file
97
scientific-skills/arboreto/scripts/basic_grn_inference.py
Normal file
@@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Basic GRN inference example using Arboreto.
|
||||
|
||||
This script demonstrates the standard workflow for inferring gene regulatory
|
||||
networks from expression data using GRNBoost2.
|
||||
|
||||
Usage:
|
||||
python basic_grn_inference.py <expression_file> <output_file> [--tf-file TF_FILE] [--seed SEED]
|
||||
|
||||
Arguments:
|
||||
expression_file: Path to expression matrix (TSV format, genes as columns)
|
||||
output_file: Path for output network (TSV format)
|
||||
--tf-file: Optional path to transcription factors file (one per line)
|
||||
--seed: Random seed for reproducibility (default: 777)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from arboreto.algo import grnboost2
|
||||
from arboreto.utils import load_tf_names
|
||||
|
||||
|
||||
def run_grn_inference(expression_file, output_file, tf_file=None, seed=777):
|
||||
"""
|
||||
Run GRN inference using GRNBoost2.
|
||||
|
||||
Args:
|
||||
expression_file: Path to expression matrix TSV file
|
||||
output_file: Path for output network file
|
||||
tf_file: Optional path to TF names file
|
||||
seed: Random seed for reproducibility
|
||||
"""
|
||||
print(f"Loading expression data from {expression_file}...")
|
||||
expression_data = pd.read_csv(expression_file, sep='\t')
|
||||
|
||||
print(f"Expression matrix shape: {expression_data.shape}")
|
||||
print(f"Number of genes: {expression_data.shape[1]}")
|
||||
print(f"Number of observations: {expression_data.shape[0]}")
|
||||
|
||||
# Load TF names if provided
|
||||
tf_names = 'all'
|
||||
if tf_file:
|
||||
print(f"Loading transcription factors from {tf_file}...")
|
||||
tf_names = load_tf_names(tf_file)
|
||||
print(f"Number of TFs: {len(tf_names)}")
|
||||
|
||||
# Run GRN inference
|
||||
print(f"Running GRNBoost2 with seed={seed}...")
|
||||
network = grnboost2(
|
||||
expression_data=expression_data,
|
||||
tf_names=tf_names,
|
||||
seed=seed,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Save results
|
||||
print(f"Saving network to {output_file}...")
|
||||
network.to_csv(output_file, sep='\t', index=False, header=False)
|
||||
|
||||
print(f"Done! Network contains {len(network)} regulatory links.")
|
||||
print(f"\nTop 10 regulatory links:")
|
||||
print(network.head(10).to_string(index=False))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Infer gene regulatory network using GRNBoost2'
|
||||
)
|
||||
parser.add_argument(
|
||||
'expression_file',
|
||||
help='Path to expression matrix (TSV format, genes as columns)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'output_file',
|
||||
help='Path for output network (TSV format)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tf-file',
|
||||
help='Path to transcription factors file (one per line)',
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
'--seed',
|
||||
help='Random seed for reproducibility (default: 777)',
|
||||
type=int,
|
||||
default=777
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
run_grn_inference(
|
||||
expression_file=args.expression_file,
|
||||
output_file=args.output_file,
|
||||
tf_file=args.tf_file,
|
||||
seed=args.seed
|
||||
)
|
||||
328
scientific-skills/astropy/SKILL.md
Normal file
328
scientific-skills/astropy/SKILL.md
Normal file
@@ -0,0 +1,328 @@
|
||||
---
|
||||
name: astropy
|
||||
description: Comprehensive Python library for astronomy and astrophysics. This skill should be used when working with astronomical data including celestial coordinates, physical units, FITS files, cosmological calculations, time systems, tables, world coordinate systems (WCS), and astronomical data analysis. Use when tasks involve coordinate transformations, unit conversions, FITS file manipulation, cosmological distance calculations, time scale conversions, or astronomical data processing.
|
||||
license: BSD-3-Clause license
|
||||
metadata:
|
||||
skill-author: K-Dense Inc.
|
||||
---
|
||||
|
||||
# Astropy
|
||||
|
||||
## Overview
|
||||
|
||||
Astropy is the core Python package for astronomy, providing essential functionality for astronomical research and data analysis. Use astropy for coordinate transformations, unit and quantity calculations, FITS file operations, cosmological calculations, precise time handling, tabular data manipulation, and astronomical image processing.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use astropy when tasks involve:
|
||||
- Converting between celestial coordinate systems (ICRS, Galactic, FK5, AltAz, etc.)
|
||||
- Working with physical units and quantities (converting Jy to mJy, parsecs to km, etc.)
|
||||
- Reading, writing, or manipulating FITS files (images or tables)
|
||||
- Cosmological calculations (luminosity distance, lookback time, Hubble parameter)
|
||||
- Precise time handling with different time scales (UTC, TAI, TT, TDB) and formats (JD, MJD, ISO)
|
||||
- Table operations (reading catalogs, cross-matching, filtering, joining)
|
||||
- WCS transformations between pixel and world coordinates
|
||||
- Astronomical constants and calculations
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import astropy.units as u
|
||||
from astropy.coordinates import SkyCoord
|
||||
from astropy.time import Time
|
||||
from astropy.io import fits
|
||||
from astropy.table import Table
|
||||
from astropy.cosmology import Planck18
|
||||
|
||||
# Units and quantities
|
||||
distance = 100 * u.pc
|
||||
distance_km = distance.to(u.km)
|
||||
|
||||
# Coordinates
|
||||
coord = SkyCoord(ra=10.5*u.degree, dec=41.2*u.degree, frame='icrs')
|
||||
coord_galactic = coord.galactic
|
||||
|
||||
# Time
|
||||
t = Time('2023-01-15 12:30:00')
|
||||
jd = t.jd # Julian Date
|
||||
|
||||
# FITS files
|
||||
data = fits.getdata('image.fits')
|
||||
header = fits.getheader('image.fits')
|
||||
|
||||
# Tables
|
||||
table = Table.read('catalog.fits')
|
||||
|
||||
# Cosmology
|
||||
d_L = Planck18.luminosity_distance(z=1.0)
|
||||
```
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Units and Quantities (`astropy.units`)
|
||||
|
||||
Handle physical quantities with units, perform unit conversions, and ensure dimensional consistency in calculations.
|
||||
|
||||
**Key operations:**
|
||||
- Create quantities by multiplying values with units
|
||||
- Convert between units using `.to()` method
|
||||
- Perform arithmetic with automatic unit handling
|
||||
- Use equivalencies for domain-specific conversions (spectral, doppler, parallax)
|
||||
- Work with logarithmic units (magnitudes, decibels)
|
||||
|
||||
**See:** `references/units.md` for comprehensive documentation, unit systems, equivalencies, performance optimization, and unit arithmetic.
|
||||
|
||||
### 2. Coordinate Systems (`astropy.coordinates`)
|
||||
|
||||
Represent celestial positions and transform between different coordinate frames.
|
||||
|
||||
**Key operations:**
|
||||
- Create coordinates with `SkyCoord` in any frame (ICRS, Galactic, FK5, AltAz, etc.)
|
||||
- Transform between coordinate systems
|
||||
- Calculate angular separations and position angles
|
||||
- Match coordinates to catalogs
|
||||
- Include distance for 3D coordinate operations
|
||||
- Handle proper motions and radial velocities
|
||||
- Query named objects from online databases
|
||||
|
||||
**See:** `references/coordinates.md` for detailed coordinate frame descriptions, transformations, observer-dependent frames (AltAz), catalog matching, and performance tips.
|
||||
|
||||
### 3. Cosmological Calculations (`astropy.cosmology`)
|
||||
|
||||
Perform cosmological calculations using standard cosmological models.
|
||||
|
||||
**Key operations:**
|
||||
- Use built-in cosmologies (Planck18, WMAP9, etc.)
|
||||
- Create custom cosmological models
|
||||
- Calculate distances (luminosity, comoving, angular diameter)
|
||||
- Compute ages and lookback times
|
||||
- Determine Hubble parameter at any redshift
|
||||
- Calculate density parameters and volumes
|
||||
- Perform inverse calculations (find z for given distance)
|
||||
|
||||
**See:** `references/cosmology.md` for available models, distance calculations, time calculations, density parameters, and neutrino effects.
|
||||
|
||||
### 4. FITS File Handling (`astropy.io.fits`)
|
||||
|
||||
Read, write, and manipulate FITS (Flexible Image Transport System) files.
|
||||
|
||||
**Key operations:**
|
||||
- Open FITS files with context managers
|
||||
- Access HDUs (Header Data Units) by index or name
|
||||
- Read and modify headers (keywords, comments, history)
|
||||
- Work with image data (NumPy arrays)
|
||||
- Handle table data (binary and ASCII tables)
|
||||
- Create new FITS files (single or multi-extension)
|
||||
- Use memory mapping for large files
|
||||
- Access remote FITS files (S3, HTTP)
|
||||
|
||||
**See:** `references/fits.md` for comprehensive file operations, header manipulation, image and table handling, multi-extension files, and performance considerations.
|
||||
|
||||
### 5. Table Operations (`astropy.table`)
|
||||
|
||||
Work with tabular data with support for units, metadata, and various file formats.
|
||||
|
||||
**Key operations:**
|
||||
- Create tables from arrays, lists, or dictionaries
|
||||
- Read/write tables in multiple formats (FITS, CSV, HDF5, VOTable)
|
||||
- Access and modify columns and rows
|
||||
- Sort, filter, and index tables
|
||||
- Perform database-style operations (join, group, aggregate)
|
||||
- Stack and concatenate tables
|
||||
- Work with unit-aware columns (QTable)
|
||||
- Handle missing data with masking
|
||||
|
||||
**See:** `references/tables.md` for table creation, I/O operations, data manipulation, sorting, filtering, joins, grouping, and performance tips.
|
||||
|
||||
### 6. Time Handling (`astropy.time`)
|
||||
|
||||
Precise time representation and conversion between time scales and formats.
|
||||
|
||||
**Key operations:**
|
||||
- Create Time objects in various formats (ISO, JD, MJD, Unix, etc.)
|
||||
- Convert between time scales (UTC, TAI, TT, TDB, etc.)
|
||||
- Perform time arithmetic with TimeDelta
|
||||
- Calculate sidereal time for observers
|
||||
- Compute light travel time corrections (barycentric, heliocentric)
|
||||
- Work with time arrays efficiently
|
||||
- Handle masked (missing) times
|
||||
|
||||
**See:** `references/time.md` for time formats, time scales, conversions, arithmetic, observing features, and precision handling.
|
||||
|
||||
### 7. World Coordinate System (`astropy.wcs`)
|
||||
|
||||
Transform between pixel coordinates in images and world coordinates.
|
||||
|
||||
**Key operations:**
|
||||
- Read WCS from FITS headers
|
||||
- Convert pixel coordinates to world coordinates (and vice versa)
|
||||
- Calculate image footprints
|
||||
- Access WCS parameters (reference pixel, projection, scale)
|
||||
- Create custom WCS objects
|
||||
|
||||
**See:** `references/wcs_and_other_modules.md` for WCS operations and transformations.
|
||||
|
||||
## Additional Capabilities
|
||||
|
||||
The `references/wcs_and_other_modules.md` file also covers:
|
||||
|
||||
### NDData and CCDData
|
||||
Containers for n-dimensional datasets with metadata, uncertainty, masking, and WCS information.
|
||||
|
||||
### Modeling
|
||||
Framework for creating and fitting mathematical models to astronomical data.
|
||||
|
||||
### Visualization
|
||||
Tools for astronomical image display with appropriate stretching and scaling.
|
||||
|
||||
### Constants
|
||||
Physical and astronomical constants with proper units (speed of light, solar mass, Planck constant, etc.).
|
||||
|
||||
### Convolution
|
||||
Image processing kernels for smoothing and filtering.
|
||||
|
||||
### Statistics
|
||||
Robust statistical functions including sigma clipping and outlier rejection.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install astropy
|
||||
uv pip install astropy
|
||||
|
||||
# With optional dependencies for full functionality
|
||||
uv pip install astropy[all]
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Converting Coordinates Between Systems
|
||||
|
||||
```python
|
||||
from astropy.coordinates import SkyCoord
|
||||
import astropy.units as u
|
||||
|
||||
# Create coordinate
|
||||
c = SkyCoord(ra='05h23m34.5s', dec='-69d45m22s', frame='icrs')
|
||||
|
||||
# Transform to galactic
|
||||
c_gal = c.galactic
|
||||
print(f"l={c_gal.l.deg}, b={c_gal.b.deg}")
|
||||
|
||||
# Transform to alt-az (requires time and location)
|
||||
from astropy.time import Time
|
||||
from astropy.coordinates import EarthLocation, AltAz
|
||||
|
||||
observing_time = Time('2023-06-15 23:00:00')
|
||||
observing_location = EarthLocation(lat=40*u.deg, lon=-120*u.deg)
|
||||
aa_frame = AltAz(obstime=observing_time, location=observing_location)
|
||||
c_altaz = c.transform_to(aa_frame)
|
||||
print(f"Alt={c_altaz.alt.deg}, Az={c_altaz.az.deg}")
|
||||
```
|
||||
|
||||
### Reading and Analyzing FITS Files
|
||||
|
||||
```python
|
||||
from astropy.io import fits
|
||||
import numpy as np
|
||||
|
||||
# Open FITS file
|
||||
with fits.open('observation.fits') as hdul:
|
||||
# Display structure
|
||||
hdul.info()
|
||||
|
||||
# Get image data and header
|
||||
data = hdul[1].data
|
||||
header = hdul[1].header
|
||||
|
||||
# Access header values
|
||||
exptime = header['EXPTIME']
|
||||
filter_name = header['FILTER']
|
||||
|
||||
# Analyze data
|
||||
mean = np.mean(data)
|
||||
median = np.median(data)
|
||||
print(f"Mean: {mean}, Median: {median}")
|
||||
```
|
||||
|
||||
### Cosmological Distance Calculations
|
||||
|
||||
```python
|
||||
from astropy.cosmology import Planck18
|
||||
import astropy.units as u
|
||||
import numpy as np
|
||||
|
||||
# Calculate distances at z=1.5
|
||||
z = 1.5
|
||||
d_L = Planck18.luminosity_distance(z)
|
||||
d_A = Planck18.angular_diameter_distance(z)
|
||||
|
||||
print(f"Luminosity distance: {d_L}")
|
||||
print(f"Angular diameter distance: {d_A}")
|
||||
|
||||
# Age of universe at that redshift
|
||||
age = Planck18.age(z)
|
||||
print(f"Age at z={z}: {age.to(u.Gyr)}")
|
||||
|
||||
# Lookback time
|
||||
t_lookback = Planck18.lookback_time(z)
|
||||
print(f"Lookback time: {t_lookback.to(u.Gyr)}")
|
||||
```
|
||||
|
||||
### Cross-Matching Catalogs
|
||||
|
||||
```python
|
||||
from astropy.table import Table
|
||||
from astropy.coordinates import SkyCoord, match_coordinates_sky
|
||||
import astropy.units as u
|
||||
|
||||
# Read catalogs
|
||||
cat1 = Table.read('catalog1.fits')
|
||||
cat2 = Table.read('catalog2.fits')
|
||||
|
||||
# Create coordinate objects
|
||||
coords1 = SkyCoord(ra=cat1['RA']*u.degree, dec=cat1['DEC']*u.degree)
|
||||
coords2 = SkyCoord(ra=cat2['RA']*u.degree, dec=cat2['DEC']*u.degree)
|
||||
|
||||
# Find matches
|
||||
idx, sep, _ = coords1.match_to_catalog_sky(coords2)
|
||||
|
||||
# Filter by separation threshold
|
||||
max_sep = 1 * u.arcsec
|
||||
matches = sep < max_sep
|
||||
|
||||
# Create matched catalogs
|
||||
cat1_matched = cat1[matches]
|
||||
cat2_matched = cat2[idx[matches]]
|
||||
print(f"Found {len(cat1_matched)} matches")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always use units**: Attach units to quantities to avoid errors and ensure dimensional consistency
|
||||
2. **Use context managers for FITS files**: Ensures proper file closing
|
||||
3. **Prefer arrays over loops**: Process multiple coordinates/times as arrays for better performance
|
||||
4. **Check coordinate frames**: Verify the frame before transformations
|
||||
5. **Use appropriate cosmology**: Choose the right cosmological model for your analysis
|
||||
6. **Handle missing data**: Use masked columns for tables with missing values
|
||||
7. **Specify time scales**: Be explicit about time scales (UTC, TT, TDB) for precise timing
|
||||
8. **Use QTable for unit-aware tables**: When table columns have units
|
||||
9. **Check WCS validity**: Verify WCS before using transformations
|
||||
10. **Cache frequently used values**: Expensive calculations (e.g., cosmological distances) can be cached
|
||||
|
||||
## Documentation and Resources
|
||||
|
||||
- Official Astropy Documentation: https://docs.astropy.org/en/stable/
|
||||
- Tutorials: https://learn.astropy.org/
|
||||
- GitHub: https://github.com/astropy/astropy
|
||||
|
||||
## Reference Files
|
||||
|
||||
For detailed information on specific modules:
|
||||
- `references/units.md` - Units, quantities, conversions, and equivalencies
|
||||
- `references/coordinates.md` - Coordinate systems, transformations, and catalog matching
|
||||
- `references/cosmology.md` - Cosmological models and calculations
|
||||
- `references/fits.md` - FITS file operations and manipulation
|
||||
- `references/tables.md` - Table creation, I/O, and operations
|
||||
- `references/time.md` - Time formats, scales, and calculations
|
||||
- `references/wcs_and_other_modules.md` - WCS, NDData, modeling, visualization, constants, and utilities
|
||||
273
scientific-skills/astropy/references/coordinates.md
Normal file
273
scientific-skills/astropy/references/coordinates.md
Normal file
@@ -0,0 +1,273 @@
|
||||
# Astronomical Coordinates (astropy.coordinates)
|
||||
|
||||
The `astropy.coordinates` package provides tools for representing celestial coordinates and transforming between different coordinate systems.
|
||||
|
||||
## Creating Coordinates with SkyCoord
|
||||
|
||||
The high-level `SkyCoord` class is the recommended interface:
|
||||
|
||||
```python
|
||||
from astropy import units as u
|
||||
from astropy.coordinates import SkyCoord
|
||||
|
||||
# Decimal degrees
|
||||
c = SkyCoord(ra=10.625*u.degree, dec=41.2*u.degree, frame='icrs')
|
||||
|
||||
# Sexagesimal strings
|
||||
c = SkyCoord(ra='00h42m30s', dec='+41d12m00s', frame='icrs')
|
||||
|
||||
# Mixed formats
|
||||
c = SkyCoord('00h42.5m +41d12m', unit=(u.hourangle, u.deg))
|
||||
|
||||
# Galactic coordinates
|
||||
c = SkyCoord(l=120.5*u.degree, b=-23.4*u.degree, frame='galactic')
|
||||
```
|
||||
|
||||
## Array Coordinates
|
||||
|
||||
Process multiple coordinates efficiently using arrays:
|
||||
|
||||
```python
|
||||
# Create array of coordinates
|
||||
coords = SkyCoord(ra=[10, 11, 12]*u.degree,
|
||||
dec=[41, -5, 42]*u.degree)
|
||||
|
||||
# Access individual elements
|
||||
coords[0]
|
||||
coords[1:3]
|
||||
|
||||
# Array operations
|
||||
coords.shape
|
||||
len(coords)
|
||||
```
|
||||
|
||||
## Accessing Components
|
||||
|
||||
```python
|
||||
c = SkyCoord(ra=10.68*u.degree, dec=41.27*u.degree, frame='icrs')
|
||||
|
||||
# Access coordinates
|
||||
c.ra # <Longitude 10.68 deg>
|
||||
c.dec # <Latitude 41.27 deg>
|
||||
c.ra.hour # Convert to hours
|
||||
c.ra.hms # Hours, minutes, seconds tuple
|
||||
c.dec.dms # Degrees, arcminutes, arcseconds tuple
|
||||
```
|
||||
|
||||
## String Formatting
|
||||
|
||||
```python
|
||||
c.to_string('decimal') # '10.68 41.27'
|
||||
c.to_string('dms') # '10d40m48s 41d16m12s'
|
||||
c.to_string('hmsdms') # '00h42m43.2s +41d16m12s'
|
||||
|
||||
# Custom formatting
|
||||
c.ra.to_string(unit=u.hour, sep=':', precision=2)
|
||||
```
|
||||
|
||||
## Coordinate Transformations
|
||||
|
||||
Transform between reference frames:
|
||||
|
||||
```python
|
||||
c_icrs = SkyCoord(ra=10.68*u.degree, dec=41.27*u.degree, frame='icrs')
|
||||
|
||||
# Simple transformations (as attributes)
|
||||
c_galactic = c_icrs.galactic
|
||||
c_fk5 = c_icrs.fk5
|
||||
c_fk4 = c_icrs.fk4
|
||||
|
||||
# Explicit transformations
|
||||
c_icrs.transform_to('galactic')
|
||||
c_icrs.transform_to(FK5(equinox='J1975')) # Custom frame parameters
|
||||
```
|
||||
|
||||
## Common Coordinate Frames
|
||||
|
||||
### Celestial Frames
|
||||
- **ICRS**: International Celestial Reference System (default, most common)
|
||||
- **FK5**: Fifth Fundamental Catalogue (equinox J2000.0 by default)
|
||||
- **FK4**: Fourth Fundamental Catalogue (older, requires equinox specification)
|
||||
- **GCRS**: Geocentric Celestial Reference System
|
||||
- **CIRS**: Celestial Intermediate Reference System
|
||||
|
||||
### Galactic Frames
|
||||
- **Galactic**: IAU 1958 galactic coordinates
|
||||
- **Supergalactic**: De Vaucouleurs supergalactic coordinates
|
||||
- **Galactocentric**: Galactic center-based 3D coordinates
|
||||
|
||||
### Horizontal Frames
|
||||
- **AltAz**: Altitude-azimuth (observer-dependent)
|
||||
- **HADec**: Hour angle-declination
|
||||
|
||||
### Ecliptic Frames
|
||||
- **GeocentricMeanEcliptic**: Geocentric mean ecliptic
|
||||
- **BarycentricMeanEcliptic**: Barycentric mean ecliptic
|
||||
- **HeliocentricMeanEcliptic**: Heliocentric mean ecliptic
|
||||
|
||||
## Observer-Dependent Transformations
|
||||
|
||||
For altitude-azimuth coordinates, specify observation time and location:
|
||||
|
||||
```python
|
||||
from astropy.time import Time
|
||||
from astropy.coordinates import EarthLocation, AltAz
|
||||
|
||||
# Define observer location
|
||||
observing_location = EarthLocation(lat=40.8*u.deg, lon=-121.5*u.deg, height=1060*u.m)
|
||||
# Or use named observatory
|
||||
observing_location = EarthLocation.of_site('Apache Point Observatory')
|
||||
|
||||
# Define observation time
|
||||
observing_time = Time('2023-01-15 23:00:00')
|
||||
|
||||
# Transform to alt-az
|
||||
aa_frame = AltAz(obstime=observing_time, location=observing_location)
|
||||
aa = c_icrs.transform_to(aa_frame)
|
||||
|
||||
print(f"Altitude: {aa.alt}")
|
||||
print(f"Azimuth: {aa.az}")
|
||||
```
|
||||
|
||||
## Working with Distances
|
||||
|
||||
Add distance information for 3D coordinates:
|
||||
|
||||
```python
|
||||
# With distance
|
||||
c = SkyCoord(ra=10*u.degree, dec=9*u.degree, distance=770*u.kpc, frame='icrs')
|
||||
|
||||
# Access 3D Cartesian coordinates
|
||||
c.cartesian.x
|
||||
c.cartesian.y
|
||||
c.cartesian.z
|
||||
|
||||
# Distance from origin
|
||||
c.distance
|
||||
|
||||
# 3D separation
|
||||
c1 = SkyCoord(ra=10*u.degree, dec=9*u.degree, distance=10*u.pc)
|
||||
c2 = SkyCoord(ra=11*u.degree, dec=10*u.degree, distance=11.5*u.pc)
|
||||
sep_3d = c1.separation_3d(c2) # 3D distance
|
||||
```
|
||||
|
||||
## Angular Separation
|
||||
|
||||
Calculate on-sky separations:
|
||||
|
||||
```python
|
||||
c1 = SkyCoord(ra=10*u.degree, dec=9*u.degree, frame='icrs')
|
||||
c2 = SkyCoord(ra=11*u.degree, dec=10*u.degree, frame='fk5')
|
||||
|
||||
# Angular separation (handles frame conversion automatically)
|
||||
sep = c1.separation(c2)
|
||||
print(f"Separation: {sep.arcsec} arcsec")
|
||||
|
||||
# Position angle
|
||||
pa = c1.position_angle(c2)
|
||||
```
|
||||
|
||||
## Catalog Matching
|
||||
|
||||
Match coordinates to catalog sources:
|
||||
|
||||
```python
|
||||
# Single target matching
|
||||
catalog = SkyCoord(ra=ra_array*u.degree, dec=dec_array*u.degree)
|
||||
target = SkyCoord(ra=10.5*u.degree, dec=41.2*u.degree)
|
||||
|
||||
# Find closest match
|
||||
idx, sep2d, dist3d = target.match_to_catalog_sky(catalog)
|
||||
matched_coord = catalog[idx]
|
||||
|
||||
# Match with maximum separation constraint
|
||||
matches = target.separation(catalog) < 1*u.arcsec
|
||||
```
|
||||
|
||||
## Named Objects
|
||||
|
||||
Retrieve coordinates from online catalogs:
|
||||
|
||||
```python
|
||||
# Query by name (requires internet)
|
||||
m31 = SkyCoord.from_name("M31")
|
||||
crab = SkyCoord.from_name("Crab Nebula")
|
||||
psr = SkyCoord.from_name("PSR J1012+5307")
|
||||
```
|
||||
|
||||
## Earth Locations
|
||||
|
||||
Define observer locations:
|
||||
|
||||
```python
|
||||
# By coordinates
|
||||
location = EarthLocation(lat=40*u.deg, lon=-120*u.deg, height=1000*u.m)
|
||||
|
||||
# By named observatory
|
||||
keck = EarthLocation.of_site('Keck Observatory')
|
||||
vlt = EarthLocation.of_site('Paranal Observatory')
|
||||
|
||||
# By address (requires internet)
|
||||
location = EarthLocation.of_address('1002 Holy Grail Court, St. Louis, MO')
|
||||
|
||||
# List available observatories
|
||||
EarthLocation.get_site_names()
|
||||
```
|
||||
|
||||
## Velocity Information
|
||||
|
||||
Include proper motion and radial velocity:
|
||||
|
||||
```python
|
||||
# Proper motion
|
||||
c = SkyCoord(ra=10*u.degree, dec=41*u.degree,
|
||||
pm_ra_cosdec=15*u.mas/u.yr,
|
||||
pm_dec=5*u.mas/u.yr,
|
||||
distance=150*u.pc)
|
||||
|
||||
# Radial velocity
|
||||
c = SkyCoord(ra=10*u.degree, dec=41*u.degree,
|
||||
radial_velocity=20*u.km/u.s)
|
||||
|
||||
# Both
|
||||
c = SkyCoord(ra=10*u.degree, dec=41*u.degree, distance=150*u.pc,
|
||||
pm_ra_cosdec=15*u.mas/u.yr, pm_dec=5*u.mas/u.yr,
|
||||
radial_velocity=20*u.km/u.s)
|
||||
```
|
||||
|
||||
## Representation Types
|
||||
|
||||
Switch between coordinate representations:
|
||||
|
||||
```python
|
||||
# Cartesian representation
|
||||
c = SkyCoord(x=1*u.kpc, y=2*u.kpc, z=3*u.kpc,
|
||||
representation_type='cartesian', frame='icrs')
|
||||
|
||||
# Change representation
|
||||
c.representation_type = 'cylindrical'
|
||||
c.rho # Cylindrical radius
|
||||
c.phi # Azimuthal angle
|
||||
c.z # Height
|
||||
|
||||
# Spherical (default for most frames)
|
||||
c.representation_type = 'spherical'
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Use arrays, not loops**: Process multiple coordinates as single array
|
||||
2. **Pre-compute frames**: Reuse frame objects for multiple transformations
|
||||
3. **Use broadcasting**: Efficiently transform many positions across many times
|
||||
4. **Enable interpolation**: For dense time sampling, use ErfaAstromInterpolator
|
||||
|
||||
```python
|
||||
# Fast approach
|
||||
coords = SkyCoord(ra=ra_array*u.degree, dec=dec_array*u.degree)
|
||||
coords_transformed = coords.transform_to('galactic')
|
||||
|
||||
# Slow approach (avoid)
|
||||
for ra, dec in zip(ra_array, dec_array):
|
||||
c = SkyCoord(ra=ra*u.degree, dec=dec*u.degree)
|
||||
c_transformed = c.transform_to('galactic')
|
||||
```
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user