Files
claude-scientific-skills/scientific-skills/neuropixels-analysis/scripts/neuropixels_pipeline.py
Robert 312f18ae60 Add neuropixels-analysis skill for extracellular electrophysiology
Adds comprehensive toolkit for analyzing Neuropixels high-density neural
recordings using SpikeInterface, Allen Institute, and IBL best practices.

Features:
- Data loading from SpikeGLX, Open Ephys, and NWB formats
- Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection)
- Motion/drift estimation and correction
- Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5)
- Quality metrics computation (SNR, ISI violations, presence ratio)
- Automated curation using Allen/IBL criteria
- AI-assisted visual curation for uncertain units
- Export to Phy and NWB formats

Supports Neuropixels 1.0 and 2.0 probes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-17 11:06:28 -05:00

433 lines
13 KiB
Python

#!/usr/bin/env python3
"""
Neuropixels Data Analysis Pipeline (Best Practices Version)
Based on SpikeInterface, Allen Institute, and IBL recommendations.
Usage:
python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output
References:
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
- https://github.com/AllenInstitute/ecephys_spike_sorting
"""
import argparse
from pathlib import Path
import json
import spikeinterface.full as si
import numpy as np
def load_recording(data_path: str, stream_name: str = 'imec0.ap') -> si.BaseRecording:
    """
    Load a recording, picking the reader from the files found on disk.

    The folder is searched recursively. SpikeGLX is selected when any
    ``*.ap.bin`` or ``*.ap.meta`` file is found; otherwise Open Ephys is
    selected when an ``*.oebin`` file is found.

    Args:
        data_path: Folder that contains the recording.
        stream_name: Stream handed to the SpikeGLX reader. It is not
            forwarded on the Open Ephys path.

    Returns:
        The recording object produced by the selected SpikeInterface reader.

    Raises:
        ValueError: If none of the file patterns match under ``data_path``.
    """
    root = Path(data_path)

    def _has(pattern: str) -> bool:
        # rglob is lazy, so any() stops at the first matching file.
        return any(root.rglob(pattern))

    if _has('*.ap.bin') or _has('*.ap.meta'):
        # SpikeGLX: list the streams first so the user can see valid names.
        stream_names, _stream_ids = si.get_neo_streams('spikeglx', root)
        print(f"Available streams: {stream_names}")
        loaded = si.read_spikeglx(root, stream_name=stream_name)
    elif _has('*.oebin'):
        # Open Ephys: the reader is called without a stream selection.
        loaded = si.read_openephys(root)
    else:
        raise ValueError(f"Unknown format in {root}")

    print("Loaded recording:")
    print(f" Channels: {loaded.get_num_channels()}")
    print(f" Duration: {loaded.get_total_duration():.2f} s")
    print(f" Sampling rate: {loaded.get_sampling_frequency()} Hz")
    return loaded
def preprocess(
    recording: si.BaseRecording,
    apply_phase_shift: bool = True,
    freq_min: float = 400.,
) -> tuple:
    """
    Build the preprocessing chain for a recording.

    Steps, in the order they are applied:
        1. High-pass filter at ``freq_min`` Hz.
        2. Detect bad channels on the filtered signal and remove them.
        3. Phase shift (only when ``apply_phase_shift`` is true).
        4. Global common median reference.

    Args:
        recording: Recording to preprocess.
        apply_phase_shift: Whether to apply the phase-shift step.
        freq_min: High-pass cutoff frequency in Hz.

    Returns:
        ``(preprocessed_recording, bad_channel_ids)``.
    """
    print("Preprocessing...")

    # 1. High-pass filter
    current = si.highpass_filter(recording, freq_min=freq_min)
    print(f" Applied high-pass filter at {freq_min} Hz")

    # 2. Bad channels are detected on the filtered signal, then dropped.
    bad_channel_ids, _channel_labels = si.detect_bad_channels(current)
    n_bad = len(bad_channel_ids)
    if n_bad == 0:
        print(" No bad channels detected")
    else:
        print(f" Detected {n_bad} bad channels: {bad_channel_ids}")
        current = current.remove_channels(bad_channel_ids)

    # 3. Optional phase shift
    if apply_phase_shift:
        current = si.phase_shift(current)
        print(" Applied phase shift correction")

    # 4. Common median reference across all remaining channels
    current = si.common_reference(current, operator='median', reference='global')
    print(" Applied common median reference")

    return current, bad_channel_ids
def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
    """
    Detect and localize peaks, save a depth-vs-time scatter plot, and
    report the spread of peak depths.

    Args:
        recording: Recording to analyze (the pipeline passes the
            preprocessed recording).
        output_folder: Existing folder; ``drift_check.png`` is written there.

    Returns:
        A dict with keys ``'peaks'``, ``'peak_locations'`` and
        ``'drift_estimate'``.

    NOTE(review): ``'drift_estimate'`` is the 5th-to-95th percentile spread
    of the ``y`` coordinate of *all* peaks pooled over the whole recording.
    It is not resolved in time, so it also grows with how far activity
    extends along the probe. Confirm that any threshold applied to it by
    callers is meant for this quantity.
    """
    print("Checking for drift...")
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.peak_localization import localize_peaks

    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    # Noise levels feed the detection threshold below.
    noise_levels = si.get_noise_levels(recording, return_in_uV=False)

    peaks = detect_peaks(
        recording,
        method='locally_exclusive',
        noise_levels=noise_levels,
        detect_threshold=5,
        radius_um=50.,
        **job_kwargs
    )
    n_peaks = len(peaks)
    print(f" Detected {n_peaks} peaks")

    peak_locations = localize_peaks(
        recording, peaks,
        method='center_of_mass',
        **job_kwargs
    )

    # Scatter plot of peak depth against time, subsampled to keep it light.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(12, 6))
    n_plot = min(100000, n_peaks)
    idx = np.random.choice(n_peaks, n_plot, replace=False)
    ax.scatter(
        peaks['sample_index'][idx] / recording.get_sampling_frequency(),
        peak_locations['y'][idx],
        s=1, alpha=0.1, c='k'
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Depth (μm)')
    ax.set_title('Peak Activity (Check for Drift)')
    # Save and close through the figure handle rather than pyplot's
    # implicit "current figure" so other open figures cannot interfere.
    fig.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved drift plot to {output_folder}/drift_check.png")

    # Depth spread of the detected peaks (see NOTE in the docstring).
    y_positions = peak_locations['y']
    if n_peaks > 0:
        low, high = np.percentile(y_positions, [5, 95])
        drift_estimate = float(high - low)
    else:
        # Percentiles of an empty array are undefined; report no spread.
        drift_estimate = 0.0
    print(f" Estimated drift range: {drift_estimate:.1f} μm")

    return {
        'peaks': peaks,
        'peak_locations': peak_locations,
        'drift_estimate': drift_estimate
    }
def correct_motion(
    recording: si.BaseRecording,
    output_folder: str,
    preset: str = 'nonrigid_fast_and_accurate'
) -> si.BaseRecording:
    """
    Apply SpikeInterface motion correction and return the corrected recording.

    Args:
        recording: Recording to correct.
        output_folder: Base output folder; ``<output_folder>/motion`` is
            passed to SpikeInterface as the motion folder.
        preset: Name of the SpikeInterface motion-correction preset.

    Returns:
        The motion-corrected recording.
    """
    print(f"Applying motion correction (preset: {preset})...")
    result = si.correct_motion(
        recording,
        preset=preset,
        folder=f'{output_folder}/motion',
        output_motion_info=True,
        n_jobs=8,
        chunk_duration='1s',
        progress_bar=True
    )
    # With output_motion_info=True SpikeInterface returns a tuple whose
    # first item is the corrected recording; returning the tuple itself
    # would break every downstream step that expects a recording.
    rec_corrected = result[0] if isinstance(result, tuple) else result
    print(" Motion correction complete")
    return rec_corrected
def run_spike_sorting(
    recording: si.BaseRecording,
    output_folder: str,
    sorter: str = 'kilosort4'
) -> si.BaseSorting:
    """
    Run a spike sorter through SpikeInterface.

    Args:
        recording: Recording to sort.
        output_folder: Base output folder; the sorter writes to
            ``<output_folder>/sorting_<sorter>``.
        sorter: SpikeInterface sorter name.

    Returns:
        The sorting object returned by ``si.run_sorter``.
    """
    print(f"Running spike sorting with {sorter}...")
    sorter_folder = f'{output_folder}/sorting_{sorter}'
    sorting = si.run_sorter(
        sorter,
        recording,
        # `folder` is the current keyword; `output_folder` is deprecated.
        folder=sorter_folder,
        verbose=True
    )
    print(f" Found {len(sorting.unit_ids)} units")
    # count_total_num_spikes() gives one integer for the whole sorting;
    # the older get_total_num_spikes() returned per-unit counts instead.
    print(f" Total spikes: {sorting.count_total_num_spikes()}")
    return sorting
def postprocess(
    sorting: si.BaseSorting,
    recording: si.BaseRecording,
    output_folder: str
) -> tuple:
    """
    Create a sorting analyzer, compute its extensions and quality metrics.

    Args:
        sorting: Sorting to analyze.
        recording: Recording the sorting was computed from.
        output_folder: Base output folder; the analyzer is stored in
            ``<output_folder>/analyzer``.

    Returns:
        ``(analyzer, quality_metrics)`` where ``quality_metrics`` is the
        data of the ``'quality_metrics'`` extension.
    """
    print("Post-processing...")
    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    analyzer = si.create_sorting_analyzer(
        sorting, recording,
        sparse=True,
        format='binary_folder',
        folder=f'{output_folder}/analyzer'
    )

    # Each stage is (progress message, ordered extension steps). The order
    # is kept exactly as listed because later extensions build on earlier ones.
    stages = (
        (" Computing waveforms...", (
            ('random_spikes', dict(method='uniform', max_spikes_per_unit=500)),
            ('waveforms', dict(ms_before=1.5, ms_after=2.0, **job_kwargs)),
            ('templates', dict(operators=['average', 'std'])),
            ('noise_levels', {}),
        )),
        (" Computing spike features...", (
            ('spike_amplitudes', dict(**job_kwargs)),
            ('correlograms', dict(window_ms=100, bin_ms=1)),
            ('unit_locations', dict(method='monopolar_triangulation')),
            ('template_similarity', {}),
        )),
        (" Computing quality metrics...", (
            ('quality_metrics', {}),
        )),
    )
    for message, steps in stages:
        print(message)
        for extension_name, params in steps:
            analyzer.compute(extension_name, **params)

    metrics = analyzer.get_extension('quality_metrics').get_data()
    return analyzer, metrics
def curate_units(qm, method: str = 'allen') -> dict:
    """
    Label every unit as 'noise', 'good', 'mua' or 'unsorted'.

    A unit with ``snr < 1.5`` is 'noise' for every method. Otherwise it is
    'good' when all of the method's thresholds are met, 'mua' when its
    ``isi_violations_ratio`` is above the method's MUA threshold, and
    'unsorted' in every other case.

    Thresholds per method (``>`` for minimums, ``<`` for maximums):
        'allen':  presence_ratio > 0.9, isi_violations_ratio < 0.5,
                  amplitude_cutoff < 0.1; mua if isi_violations_ratio > 0.5
        'ibl':    presence_ratio > 0.9, firing_rate > 0.1,
                  isi_violations_ratio < 0.1, amplitude_cutoff < 0.1;
                  mua if isi_violations_ratio > 0.1
        'strict': snr > 5, presence_ratio > 0.95,
                  isi_violations_ratio < 0.01, amplitude_cutoff < 0.01;
                  mua if isi_violations_ratio > 0.05

    Args:
        qm: Quality-metrics table indexed by unit id, with one column per
            metric named above.
        method: One of 'allen', 'ibl' or 'strict'.

    Returns:
        Dict mapping each unit id in ``qm.index`` to its label.

    Raises:
        ValueError: If ``method`` is not a known curation method. Without
            this check an unknown method would silently label only the
            noise units and drop all others from the result.
    """
    from collections import Counter

    # method -> (strict lower bounds, strict upper bounds, MUA ISI threshold)
    criteria = {
        'allen': (
            {'presence_ratio': 0.9},
            {'isi_violations_ratio': 0.5, 'amplitude_cutoff': 0.1},
            0.5,
        ),
        'ibl': (
            {'presence_ratio': 0.9, 'firing_rate': 0.1},
            {'isi_violations_ratio': 0.1, 'amplitude_cutoff': 0.1},
            0.1,
        ),
        'strict': (
            {'snr': 5, 'presence_ratio': 0.95},
            {'isi_violations_ratio': 0.01, 'amplitude_cutoff': 0.01},
            0.05,
        ),
    }
    if method not in criteria:
        raise ValueError(
            f"Unknown curation method {method!r}; expected one of {sorted(criteria)}"
        )
    minimums, maximums, mua_isi_threshold = criteria[method]

    print(f"Curating units (method: {method})...")
    labels = {}
    for unit_id, row in qm.iterrows():
        if row['snr'] < 1.5:
            # Noise detection is shared by every method.
            labels[unit_id] = 'noise'
        elif (all(row[name] > bound for name, bound in minimums.items()) and
              all(row[name] < bound for name, bound in maximums.items())):
            labels[unit_id] = 'good'
        elif row['isi_violations_ratio'] > mua_isi_threshold:
            labels[unit_id] = 'mua'
        else:
            labels[unit_id] = 'unsorted'

    counts = Counter(labels.values())
    print(f" Classification: {dict(counts)}")
    return labels
def export_results(
    analyzer,
    sorting,
    recording,
    labels: dict,
    output_folder: str
):
    """
    Write the Phy export, the report, and the CSV/JSON result files.

    Files and folders created under ``output_folder``:
        phy_export/           Phy export of the analyzer
        report/               PNG report of the analyzer
        quality_metrics.csv   data of the 'quality_metrics' extension
        unit_labels.json      unit id (as string) -> label
        summary.json          unit/spike counts, duration, channel count

    Args:
        analyzer: Sorting analyzer with the 'quality_metrics' extension
            computed.
        sorting: Sorting used for the unit and spike counts.
        recording: Recording used for duration and channel count.
        labels: Mapping of unit id to label; units labelled 'good' are
            counted in the summary.
        output_folder: Existing folder to write into.
    """
    print("Exporting results...")

    good_ids = [unit_id for unit_id, label in labels.items() if label == 'good']

    # Export to Phy
    phy_folder = f'{output_folder}/phy_export'
    si.export_to_phy(analyzer, phy_folder,
                     compute_pc_features=True,
                     compute_amplitudes=True)
    print(f" Phy export: {phy_folder}")

    # Generate report
    report_folder = f'{output_folder}/report'
    si.export_report(analyzer, report_folder, format='png')
    print(f" Report: {report_folder}")

    # Save quality metrics
    qm = analyzer.get_extension('quality_metrics').get_data()
    qm.to_csv(f'{output_folder}/quality_metrics.csv')

    # Save labels; JSON object keys must be strings.
    with open(f'{output_folder}/unit_labels.json', 'w', encoding='utf-8') as f:
        json.dump({str(k): v for k, v in labels.items()}, f, indent=2)

    # Save summary
    summary = {
        'total_units': len(sorting.unit_ids),
        'good_units': len(good_ids),
        # count_total_num_spikes() yields a single integer; the older
        # get_total_num_spikes() returned per-unit counts, which int()
        # cannot convert.
        'total_spikes': int(sorting.count_total_num_spikes()),
        'duration_s': float(recording.get_total_duration()),
        'n_channels': int(recording.get_num_channels()),
    }
    with open(f'{output_folder}/summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    print(f" Summary: {summary}")
def run_pipeline(
    data_path: str,
    output_path: str,
    sorter: str = 'kilosort4',
    stream_name: str = 'imec0.ap',
    apply_motion_correction: bool = True,
    curation_method: str = 'allen',
    drift_threshold_um: float = 20.0
):
    """
    Run the complete analysis pipeline from raw data to exported results.

    Steps: load, preprocess (saved to ``<output_path>/preprocessed``),
    drift check, optional motion correction, spike sorting,
    post-processing, curation and export.

    Args:
        data_path: Folder containing the raw recording.
        output_path: Output folder; created if it does not exist.
        sorter: Sorter name passed to ``run_spike_sorting``.
        stream_name: Stream name passed to ``load_recording``.
        apply_motion_correction: When false, motion correction is never
            applied, whatever the drift check reports.
        curation_method: Method passed to ``curate_units``.
        drift_threshold_um: Motion correction is applied only when the
            ``'drift_estimate'`` reported by ``check_drift`` exceeds this
            value.

    Returns:
        ``(analyzer, sorting, qm, labels)``.
    """
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # 1. Load data
    recording = load_recording(data_path, stream_name)

    # 2. Preprocess, then persist so later steps read from disk
    rec_preprocessed, _bad_channels = preprocess(recording)
    preproc_folder = output_path / 'preprocessed'
    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
    rec_preprocessed = rec_preprocessed.save(
        folder=str(preproc_folder),
        format='binary',
        **job_kwargs
    )

    # 3. Check drift (always run, so the diagnostic plot is always written)
    drift_info = check_drift(rec_preprocessed, str(output_path))

    # 4. Motion correction: report the real reason when it is skipped
    if not apply_motion_correction:
        print("Skipping motion correction (disabled)")
        rec_final = rec_preprocessed
    elif drift_info['drift_estimate'] > drift_threshold_um:
        print(f"Drift > {drift_threshold_um:g} μm detected, applying motion correction...")
        rec_final = correct_motion(rec_preprocessed, str(output_path))
    else:
        print("Skipping motion correction (low drift)")
        rec_final = rec_preprocessed

    # 5. Spike sorting
    sorting = run_spike_sorting(rec_final, str(output_path), sorter)

    # 6. Post-processing
    analyzer, qm = postprocess(sorting, rec_final, str(output_path))

    # 7. Curation
    labels = curate_units(qm, method=curation_method)

    # 8. Export
    export_results(analyzer, sorting, rec_final, labels, str(output_path))

    print("\n" + "="*50)
    print("Pipeline complete!")
    print(f"Output directory: {output_path}")
    print("="*50)
    return analyzer, sorting, qm, labels
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to run_pipeline.
    cli = argparse.ArgumentParser(
        description='Neuropixels analysis pipeline (best practices)'
    )

    # Positional arguments
    cli.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
    cli.add_argument('output_path', help='Output directory')

    # Optional arguments
    supported_sorters = ['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5']
    cli.add_argument('--sorter', default='kilosort4', choices=supported_sorters,
                     help='Spike sorter to use')
    cli.add_argument('--stream', default='imec0.ap', help='Stream name')
    cli.add_argument('--no-motion-correction', action='store_true',
                     help='Skip motion correction')
    curation_methods = ['allen', 'ibl', 'strict']
    cli.add_argument('--curation', default='allen', choices=curation_methods,
                     help='Curation method')

    opts = cli.parse_args()

    # The CLI flag is negative ("--no-..."), the pipeline keyword is positive.
    pipeline_options = dict(
        sorter=opts.sorter,
        stream_name=opts.stream,
        apply_motion_correction=not opts.no_motion_correction,
        curation_method=opts.curation,
    )
    run_pipeline(opts.data_path, opts.output_path, **pipeline_options)