mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-27 07:09:27 +08:00
Add neuropixels-analysis skill for extracellular electrophysiology
Adds comprehensive toolkit for analyzing Neuropixels high-density neural recordings using SpikeInterface, Allen Institute, and IBL best practices. Features: - Data loading from SpikeGLX, Open Ephys, and NWB formats - Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection) - Motion/drift estimation and correction - Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5) - Quality metrics computation (SNR, ISI violations, presence ratio) - Automated curation using Allen/IBL criteria - AI-assisted visual curation for uncertain units - Export to Phy and NWB formats Supports Neuropixels 1.0 and 2.0 probes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,432 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neuropixels Data Analysis Pipeline (Best Practices Version)
|
||||
|
||||
Based on SpikeInterface, Allen Institute, and IBL recommendations.
|
||||
|
||||
Usage:
|
||||
python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output
|
||||
|
||||
References:
|
||||
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
|
||||
- https://github.com/AllenInstitute/ecephys_spike_sorting
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
import spikeinterface.full as si
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_recording(data_path: str, stream_name: str = 'imec0.ap') -> si.BaseRecording:
|
||||
"""Load a SpikeGLX or Open Ephys recording."""
|
||||
|
||||
data_path = Path(data_path)
|
||||
|
||||
# Auto-detect format
|
||||
if any(data_path.rglob('*.ap.bin')) or any(data_path.rglob('*.ap.meta')):
|
||||
# SpikeGLX format
|
||||
streams, _ = si.get_neo_streams('spikeglx', data_path)
|
||||
print(f"Available streams: {streams}")
|
||||
recording = si.read_spikeglx(data_path, stream_name=stream_name)
|
||||
elif any(data_path.rglob('*.oebin')):
|
||||
# Open Ephys format
|
||||
recording = si.read_openephys(data_path)
|
||||
else:
|
||||
raise ValueError(f"Unknown format in {data_path}")
|
||||
|
||||
print(f"Loaded recording:")
|
||||
print(f" Channels: {recording.get_num_channels()}")
|
||||
print(f" Duration: {recording.get_total_duration():.2f} s")
|
||||
print(f" Sampling rate: {recording.get_sampling_frequency()} Hz")
|
||||
|
||||
return recording
|
||||
|
||||
|
||||
def preprocess(
|
||||
recording: si.BaseRecording,
|
||||
apply_phase_shift: bool = True,
|
||||
freq_min: float = 400.,
|
||||
) -> tuple:
|
||||
"""
|
||||
Apply standard Neuropixels preprocessing.
|
||||
|
||||
Following SpikeInterface recommendations:
|
||||
1. High-pass filter at 400 Hz (not 300)
|
||||
2. Detect and remove bad channels
|
||||
3. Phase shift (NP 1.0 only)
|
||||
4. Common median reference
|
||||
"""
|
||||
print("Preprocessing...")
|
||||
|
||||
# Step 1: High-pass filter
|
||||
rec = si.highpass_filter(recording, freq_min=freq_min)
|
||||
print(f" Applied high-pass filter at {freq_min} Hz")
|
||||
|
||||
# Step 2: Detect bad channels
|
||||
bad_channel_ids, channel_labels = si.detect_bad_channels(rec)
|
||||
if len(bad_channel_ids) > 0:
|
||||
print(f" Detected {len(bad_channel_ids)} bad channels: {bad_channel_ids}")
|
||||
rec = rec.remove_channels(bad_channel_ids)
|
||||
else:
|
||||
print(" No bad channels detected")
|
||||
|
||||
# Step 3: Phase shift (for Neuropixels 1.0)
|
||||
if apply_phase_shift:
|
||||
rec = si.phase_shift(rec)
|
||||
print(" Applied phase shift correction")
|
||||
|
||||
# Step 4: Common median reference
|
||||
rec = si.common_reference(rec, operator='median', reference='global')
|
||||
print(" Applied common median reference")
|
||||
|
||||
return rec, bad_channel_ids
|
||||
|
||||
|
||||
def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
|
||||
"""
|
||||
Detect peaks and check for drift before spike sorting.
|
||||
"""
|
||||
print("Checking for drift...")
|
||||
|
||||
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
|
||||
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
|
||||
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
|
||||
# Get noise levels
|
||||
noise_levels = si.get_noise_levels(recording, return_in_uV=False)
|
||||
|
||||
# Detect peaks
|
||||
peaks = detect_peaks(
|
||||
recording,
|
||||
method='locally_exclusive',
|
||||
noise_levels=noise_levels,
|
||||
detect_threshold=5,
|
||||
radius_um=50.,
|
||||
**job_kwargs
|
||||
)
|
||||
print(f" Detected {len(peaks)} peaks")
|
||||
|
||||
# Localize peaks
|
||||
peak_locations = localize_peaks(
|
||||
recording, peaks,
|
||||
method='center_of_mass',
|
||||
**job_kwargs
|
||||
)
|
||||
|
||||
# Save drift plot
|
||||
import matplotlib.pyplot as plt
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# Subsample for plotting
|
||||
n_plot = min(100000, len(peaks))
|
||||
idx = np.random.choice(len(peaks), n_plot, replace=False)
|
||||
|
||||
ax.scatter(
|
||||
peaks['sample_index'][idx] / recording.get_sampling_frequency(),
|
||||
peak_locations['y'][idx],
|
||||
s=1, alpha=0.1, c='k'
|
||||
)
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Depth (μm)')
|
||||
ax.set_title('Peak Activity (Check for Drift)')
|
||||
|
||||
plt.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f" Saved drift plot to {output_folder}/drift_check.png")
|
||||
|
||||
# Estimate drift magnitude
|
||||
y_positions = peak_locations['y']
|
||||
drift_estimate = np.percentile(y_positions, 95) - np.percentile(y_positions, 5)
|
||||
print(f" Estimated drift range: {drift_estimate:.1f} μm")
|
||||
|
||||
return {
|
||||
'peaks': peaks,
|
||||
'peak_locations': peak_locations,
|
||||
'drift_estimate': drift_estimate
|
||||
}
|
||||
|
||||
|
||||
def correct_motion(
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str,
|
||||
preset: str = 'nonrigid_fast_and_accurate'
|
||||
) -> si.BaseRecording:
|
||||
"""Apply motion correction if needed."""
|
||||
print(f"Applying motion correction (preset: {preset})...")
|
||||
|
||||
rec_corrected = si.correct_motion(
|
||||
recording,
|
||||
preset=preset,
|
||||
folder=f'{output_folder}/motion',
|
||||
output_motion_info=True,
|
||||
n_jobs=8,
|
||||
chunk_duration='1s',
|
||||
progress_bar=True
|
||||
)
|
||||
|
||||
print(" Motion correction complete")
|
||||
return rec_corrected
|
||||
|
||||
|
||||
def run_spike_sorting(
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str,
|
||||
sorter: str = 'kilosort4'
|
||||
) -> si.BaseSorting:
|
||||
"""Run spike sorting."""
|
||||
print(f"Running spike sorting with {sorter}...")
|
||||
|
||||
sorter_folder = f'{output_folder}/sorting_{sorter}'
|
||||
|
||||
sorting = si.run_sorter(
|
||||
sorter,
|
||||
recording,
|
||||
output_folder=sorter_folder,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
print(f" Found {len(sorting.unit_ids)} units")
|
||||
print(f" Total spikes: {sorting.get_total_num_spikes()}")
|
||||
|
||||
return sorting
|
||||
|
||||
|
||||
def postprocess(
|
||||
sorting: si.BaseSorting,
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str
|
||||
) -> tuple:
|
||||
"""Run post-processing and compute quality metrics."""
|
||||
print("Post-processing...")
|
||||
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
|
||||
# Create analyzer
|
||||
analyzer = si.create_sorting_analyzer(
|
||||
sorting, recording,
|
||||
sparse=True,
|
||||
format='binary_folder',
|
||||
folder=f'{output_folder}/analyzer'
|
||||
)
|
||||
|
||||
# Compute extensions (order matters)
|
||||
print(" Computing waveforms...")
|
||||
analyzer.compute('random_spikes', method='uniform', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0, **job_kwargs)
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
analyzer.compute('noise_levels')
|
||||
|
||||
print(" Computing spike features...")
|
||||
analyzer.compute('spike_amplitudes', **job_kwargs)
|
||||
analyzer.compute('correlograms', window_ms=100, bin_ms=1)
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
analyzer.compute('template_similarity')
|
||||
|
||||
print(" Computing quality metrics...")
|
||||
analyzer.compute('quality_metrics')
|
||||
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
|
||||
return analyzer, qm
|
||||
|
||||
|
||||
def curate_units(qm, method: str = 'allen') -> dict:
|
||||
"""
|
||||
Classify units based on quality metrics.
|
||||
|
||||
Methods:
|
||||
'allen': Allen Institute defaults (more permissive)
|
||||
'ibl': IBL standards
|
||||
'strict': Strict single-unit criteria
|
||||
"""
|
||||
print(f"Curating units (method: {method})...")
|
||||
|
||||
labels = {}
|
||||
|
||||
for unit_id in qm.index:
|
||||
row = qm.loc[unit_id]
|
||||
|
||||
# Noise detection (universal)
|
||||
if row['snr'] < 1.5:
|
||||
labels[unit_id] = 'noise'
|
||||
continue
|
||||
|
||||
if method == 'allen':
|
||||
# Allen Institute defaults
|
||||
if (row['presence_ratio'] > 0.9 and
|
||||
row['isi_violations_ratio'] < 0.5 and
|
||||
row['amplitude_cutoff'] < 0.1):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.5:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
elif method == 'ibl':
|
||||
# IBL standards
|
||||
if (row['presence_ratio'] > 0.9 and
|
||||
row['isi_violations_ratio'] < 0.1 and
|
||||
row['amplitude_cutoff'] < 0.1 and
|
||||
row['firing_rate'] > 0.1):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.1:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
elif method == 'strict':
|
||||
# Strict single-unit
|
||||
if (row['snr'] > 5 and
|
||||
row['presence_ratio'] > 0.95 and
|
||||
row['isi_violations_ratio'] < 0.01 and
|
||||
row['amplitude_cutoff'] < 0.01):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.05:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
# Summary
|
||||
from collections import Counter
|
||||
counts = Counter(labels.values())
|
||||
print(f" Classification: {dict(counts)}")
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def export_results(
|
||||
analyzer,
|
||||
sorting,
|
||||
recording,
|
||||
labels: dict,
|
||||
output_folder: str
|
||||
):
|
||||
"""Export results to various formats."""
|
||||
print("Exporting results...")
|
||||
|
||||
# Get good units
|
||||
good_ids = [u for u, l in labels.items() if l == 'good']
|
||||
sorting_good = sorting.select_units(good_ids)
|
||||
|
||||
# Export to Phy
|
||||
phy_folder = f'{output_folder}/phy_export'
|
||||
si.export_to_phy(analyzer, phy_folder,
|
||||
compute_pc_features=True,
|
||||
compute_amplitudes=True)
|
||||
print(f" Phy export: {phy_folder}")
|
||||
|
||||
# Generate report
|
||||
report_folder = f'{output_folder}/report'
|
||||
si.export_report(analyzer, report_folder, format='png')
|
||||
print(f" Report: {report_folder}")
|
||||
|
||||
# Save quality metrics
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
qm.to_csv(f'{output_folder}/quality_metrics.csv')
|
||||
|
||||
# Save labels
|
||||
with open(f'{output_folder}/unit_labels.json', 'w') as f:
|
||||
json.dump({str(k): v for k, v in labels.items()}, f, indent=2)
|
||||
|
||||
# Save summary
|
||||
summary = {
|
||||
'total_units': len(sorting.unit_ids),
|
||||
'good_units': len(good_ids),
|
||||
'total_spikes': int(sorting.get_total_num_spikes()),
|
||||
'duration_s': float(recording.get_total_duration()),
|
||||
'n_channels': int(recording.get_num_channels()),
|
||||
}
|
||||
with open(f'{output_folder}/summary.json', 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
print(f" Summary: {summary}")
|
||||
|
||||
|
||||
def run_pipeline(
|
||||
data_path: str,
|
||||
output_path: str,
|
||||
sorter: str = 'kilosort4',
|
||||
stream_name: str = 'imec0.ap',
|
||||
apply_motion_correction: bool = True,
|
||||
curation_method: str = 'allen'
|
||||
):
|
||||
"""Run complete Neuropixels analysis pipeline."""
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. Load data
|
||||
recording = load_recording(data_path, stream_name)
|
||||
|
||||
# 2. Preprocess
|
||||
rec_preprocessed, bad_channels = preprocess(recording)
|
||||
|
||||
# Save preprocessed
|
||||
preproc_folder = output_path / 'preprocessed'
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
rec_preprocessed = rec_preprocessed.save(
|
||||
folder=str(preproc_folder),
|
||||
format='binary',
|
||||
**job_kwargs
|
||||
)
|
||||
|
||||
# 3. Check drift
|
||||
drift_info = check_drift(rec_preprocessed, str(output_path))
|
||||
|
||||
# 4. Motion correction (if needed)
|
||||
if apply_motion_correction and drift_info['drift_estimate'] > 20:
|
||||
print(f"Drift > 20 μm detected, applying motion correction...")
|
||||
rec_final = correct_motion(rec_preprocessed, str(output_path))
|
||||
else:
|
||||
print("Skipping motion correction (low drift)")
|
||||
rec_final = rec_preprocessed
|
||||
|
||||
# 5. Spike sorting
|
||||
sorting = run_spike_sorting(rec_final, str(output_path), sorter)
|
||||
|
||||
# 6. Post-processing
|
||||
analyzer, qm = postprocess(sorting, rec_final, str(output_path))
|
||||
|
||||
# 7. Curation
|
||||
labels = curate_units(qm, method=curation_method)
|
||||
|
||||
# 8. Export
|
||||
export_results(analyzer, sorting, rec_final, labels, str(output_path))
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("Pipeline complete!")
|
||||
print(f"Output directory: {output_path}")
|
||||
print("="*50)
|
||||
|
||||
return analyzer, sorting, qm, labels
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Neuropixels analysis pipeline (best practices)'
|
||||
)
|
||||
parser.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
|
||||
parser.add_argument('output_path', help='Output directory')
|
||||
parser.add_argument('--sorter', default='kilosort4',
|
||||
choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
|
||||
help='Spike sorter to use')
|
||||
parser.add_argument('--stream', default='imec0.ap', help='Stream name')
|
||||
parser.add_argument('--no-motion-correction', action='store_true',
|
||||
help='Skip motion correction')
|
||||
parser.add_argument('--curation', default='allen',
|
||||
choices=['allen', 'ibl', 'strict'],
|
||||
help='Curation method')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
run_pipeline(
|
||||
args.data_path,
|
||||
args.output_path,
|
||||
sorter=args.sorter,
|
||||
stream_name=args.stream,
|
||||
apply_motion_correction=not args.no_motion_correction,
|
||||
curation_method=args.curation
|
||||
)
|
||||
Reference in New Issue
Block a user