claude-scientific-skills/scientific-skills/neuropixels-analysis/assets/analysis_template.py
Commit 312f18ae60 by Robert: Add neuropixels-analysis skill for extracellular electrophysiology
Adds a comprehensive toolkit for analyzing Neuropixels high-density neural
recordings with SpikeInterface, following Allen Institute and IBL best practices.

Features:
- Data loading from SpikeGLX, Open Ephys, and NWB formats
- Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection)
- Motion/drift estimation and correction
- Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5)
- Quality metrics computation (SNR, ISI violations, presence ratio)
- Automated curation using Allen/IBL criteria
- AI-assisted visual curation for uncertain units
- Export to Phy and NWB formats

Supports Neuropixels 1.0 and 2.0 probes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-17 11:06:28 -05:00


#!/usr/bin/env python
"""
Neuropixels Analysis Template
Complete analysis workflow from raw data to curated units.
Copy and customize this template for your analysis.
Usage:
1. Copy this file to your analysis directory
2. Update the PARAMETERS section
3. Run: python analysis_template.py
"""
# =============================================================================
# PARAMETERS - Customize these for your analysis
# =============================================================================
# Input/Output paths
DATA_PATH = '/path/to/your/spikeglx/data/'
OUTPUT_DIR = 'analysis_output/'
DATA_FORMAT = 'spikeglx' # 'spikeglx', 'openephys', or 'nwb'
STREAM_ID = 'imec0.ap' # For multi-probe recordings
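# Not sure which stream IDs your recording exposes? One way to list them
# (get_neo_streams is part of spikeinterface.extractors; the 'spikeglx'
# argument here matches DATA_FORMAT above -- treat this as a sketch):
#   import spikeinterface.extractors as se
#   stream_names, stream_ids = se.get_neo_streams('spikeglx', DATA_PATH)
#   print(stream_ids)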
# Preprocessing parameters
FREQ_MIN = 300 # Highpass filter (Hz)
FREQ_MAX = 6000 # Lowpass filter (Hz)
APPLY_PHASE_SHIFT = True
APPLY_CMR = True
DETECT_BAD_CHANNELS = True
# Motion correction
CORRECT_MOTION = True
MOTION_PRESET = 'nonrigid_accurate' # 'kilosort_like', 'nonrigid_fast_and_accurate'
# Spike sorting
SORTER = 'kilosort4' # 'kilosort4', 'spykingcircus2', 'mountainsort5'
SORTER_PARAMS = {
    'batch_size': 30000,
    'nblocks': 1,  # Increase for long recordings with drift
}
# Quality metrics and curation
CURATION_METHOD = 'allen' # 'allen', 'ibl', 'strict'
# Processing
N_JOBS = -1 # -1 = all cores
# =============================================================================
# ANALYSIS PIPELINE - Usually no need to modify below
# =============================================================================
from pathlib import Path
import json
import spikeinterface.full as si
from spikeinterface.exporters import export_to_phy

def main():
    """Run the full analysis pipeline."""
    output_path = Path(OUTPUT_DIR)
    output_path.mkdir(parents=True, exist_ok=True)

    # =========================================================================
    # 1. LOAD DATA
    # =========================================================================
    print("=" * 60)
    print("1. LOADING DATA")
    print("=" * 60)

    if DATA_FORMAT == 'spikeglx':
        recording = si.read_spikeglx(DATA_PATH, stream_id=STREAM_ID)
    elif DATA_FORMAT == 'openephys':
        recording = si.read_openephys(DATA_PATH)
    elif DATA_FORMAT == 'nwb':
        recording = si.read_nwb(DATA_PATH)
    else:
        raise ValueError(f"Unknown format: {DATA_FORMAT}")

    print(f"Recording: {recording.get_num_channels()} channels")
    print(f"Duration: {recording.get_total_duration():.1f} seconds")
    print(f"Sampling rate: {recording.get_sampling_frequency()} Hz")
    # =========================================================================
    # 2. PREPROCESSING
    # =========================================================================
    print("\n" + "=" * 60)
    print("2. PREPROCESSING")
    print("=" * 60)

    rec = recording

    # Bandpass filter
    print(f"Applying bandpass filter ({FREQ_MIN}-{FREQ_MAX} Hz)...")
    rec = si.bandpass_filter(rec, freq_min=FREQ_MIN, freq_max=FREQ_MAX)

    # Phase shift correction
    if APPLY_PHASE_SHIFT:
        print("Applying phase shift correction...")
        rec = si.phase_shift(rec)

    # Bad channel detection
    if DETECT_BAD_CHANNELS:
        print("Detecting bad channels...")
        bad_ids, _ = si.detect_bad_channels(rec)
        if len(bad_ids) > 0:
            print(f"  Removing {len(bad_ids)} bad channels")
            rec = rec.remove_channels(bad_ids)

    # Common median reference
    if APPLY_CMR:
        print("Applying common median reference...")
        rec = si.common_reference(rec, operator='median', reference='global')

    # Save the preprocessed recording and keep the returned cached copy, so
    # later steps read the binary file instead of re-running the lazy chain
    print("Saving preprocessed recording...")
    rec = rec.save(folder=output_path / 'preprocessed', n_jobs=N_JOBS)
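    # Optional sanity check: eyeball a second of preprocessed traces.
    # plot_traces() lives in spikeinterface.widgets (re-exported by
    # spikeinterface.full); the matplotlib backend is an assumption.
    # si.plot_traces(rec, time_range=(0, 1), backend='matplotlib')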
    # =========================================================================
    # 3. MOTION CORRECTION
    # =========================================================================
    if CORRECT_MOTION:
        print("\n" + "=" * 60)
        print("3. MOTION CORRECTION")
        print("=" * 60)
        print(f"Estimating and correcting motion (preset: {MOTION_PRESET})...")
        rec = si.correct_motion(
            rec,
            preset=MOTION_PRESET,
            folder=output_path / 'motion',
        )
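        # Optional: inspect the estimated motion before trusting the
        # correction. load_motion_info()/plot_motion() exist in recent
        # SpikeInterface releases, but exact names vary by version, so
        # treat this as a sketch.
        # motion_info = si.load_motion_info(output_path / 'motion')
        # si.plot_motion(motion_info['motion'])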
    # =========================================================================
    # 4. SPIKE SORTING
    # =========================================================================
    print("\n" + "=" * 60)
    print("4. SPIKE SORTING")
    print("=" * 60)

    print(f"Running {SORTER}...")
    sorting = si.run_sorter(
        SORTER,
        rec,
        output_folder=output_path / f'{SORTER}_output',
        verbose=True,
        **SORTER_PARAMS,
    )
    print(f"Found {len(sorting.unit_ids)} units")
    # =========================================================================
    # 5. POSTPROCESSING
    # =========================================================================
    print("\n" + "=" * 60)
    print("5. POSTPROCESSING")
    print("=" * 60)

    print("Creating SortingAnalyzer...")
    analyzer = si.create_sorting_analyzer(
        sorting,
        rec,
        format='binary_folder',
        folder=output_path / 'analyzer',
        sparse=True,
    )

    print("Computing extensions...")
    analyzer.compute('random_spikes', max_spikes_per_unit=500)
    analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
    analyzer.compute('templates', operators=['average', 'std'])
    analyzer.compute('noise_levels')
    analyzer.compute('spike_amplitudes')
    analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
    analyzer.compute('unit_locations', method='monopolar_triangulation')
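    # Optional: PC-based metrics (isolation_distance, d_prime, ...) need the
    # principal_components extension; it is considerably slower to compute.
    # analyzer.compute('principal_components', n_components=5, mode='by_channel_local')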
    # =========================================================================
    # 6. QUALITY METRICS
    # =========================================================================
    print("\n" + "=" * 60)
    print("6. QUALITY METRICS")
    print("=" * 60)

    print("Computing quality metrics...")
    metrics = si.compute_quality_metrics(
        analyzer,
        metric_names=[
            'snr', 'isi_violations_ratio', 'presence_ratio',
            'amplitude_cutoff', 'firing_rate', 'amplitude_cv',
        ],
        n_jobs=N_JOBS,
    )
    metrics.to_csv(output_path / 'quality_metrics.csv')
    print(f"Saved metrics to: {output_path / 'quality_metrics.csv'}")

    # Print summary
    print("\nMetrics summary:")
    for col in ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']:
        if col in metrics.columns:
            print(f"  {col}: {metrics[col].median():.4f} (median)")
    # =========================================================================
    # 7. CURATION
    # =========================================================================
    print("\n" + "=" * 60)
    print("7. CURATION")
    print("=" * 60)

    # Curation criteria: min SNR, max ISI violations ratio, min presence ratio
    criteria = {
        'allen': {'snr': 3.0, 'isi_violations_ratio': 0.1, 'presence_ratio': 0.9},
        'ibl': {'snr': 4.0, 'isi_violations_ratio': 0.5, 'presence_ratio': 0.5},
        'strict': {'snr': 5.0, 'isi_violations_ratio': 0.01, 'presence_ratio': 0.95},
    }[CURATION_METHOD]
    print(f"Applying {CURATION_METHOD} criteria: {criteria}")

    labels = {}
    for unit_id in metrics.index:
        row = metrics.loc[unit_id]
        is_good = (
            row.get('snr', 0) >= criteria['snr'] and
            row.get('isi_violations_ratio', 1) <= criteria['isi_violations_ratio'] and
            row.get('presence_ratio', 0) >= criteria['presence_ratio']
        )
        if is_good:
            labels[int(unit_id)] = 'good'
        elif row.get('snr', 0) < 2:
            labels[int(unit_id)] = 'noise'
        else:
            labels[int(unit_id)] = 'mua'

    # Save labels
    with open(output_path / 'curation_labels.json', 'w') as f:
        json.dump(labels, f, indent=2)

    # Count units per label
    good_count = sum(1 for v in labels.values() if v == 'good')
    mua_count = sum(1 for v in labels.values() if v == 'mua')
    noise_count = sum(1 for v in labels.values() if v == 'noise')
    print("\nCuration results:")
    print(f"  Good: {good_count}")
    print(f"  MUA: {mua_count}")
    print(f"  Noise: {noise_count}")
    print(f"  Total: {len(labels)}")
    # =========================================================================
    # 8. EXPORT
    # =========================================================================
    print("\n" + "=" * 60)
    print("8. EXPORT")
    print("=" * 60)

    print("Exporting to Phy...")
    export_to_phy(
        analyzer,
        output_folder=output_path / 'phy_export',
        copy_binary=True,
    )

    print("\nAnalysis complete!")
    print(f"Results saved to: {output_path}")
    print("\nTo open in Phy:")
    print(f"  phy template-gui {output_path / 'phy_export' / 'params.py'}")


if __name__ == '__main__':
    main()