#!/usr/bin/env python3
"""
Neuropixels Data Analysis Pipeline (Best Practices Version)

Based on SpikeInterface, Allen Institute, and IBL recommendations.
Requires SpikeInterface >= 0.100 (uses the `stream_id` and `folder`
parameter names introduced in that release).

Usage:
    python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output

References:
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
- https://github.com/AllenInstitute/ecephys_spike_sorting
"""

import argparse
import json
from pathlib import Path

import numpy as np
import spikeinterface.full as si


def load_recording(data_path: str, stream_id: str = 'imec0.ap') -> si.BaseRecording:
    """Load a SpikeGLX or Open Ephys recording."""
    data_path = Path(data_path)

    # Auto-detect format from the files present on disk
    if any(data_path.rglob('*.ap.bin')) or any(data_path.rglob('*.ap.meta')):
        # SpikeGLX format
        streams, _ = si.get_neo_streams('spikeglx', data_path)
        print(f"Available streams: {streams}")
        recording = si.read_spikeglx(data_path, stream_id=stream_id)
    elif any(data_path.rglob('*.oebin')):
        # Open Ephys format
        recording = si.read_openephys(data_path)
    else:
        raise ValueError(f"Unknown format in {data_path}")

    print("Loaded recording:")
    print(f"  Channels: {recording.get_num_channels()}")
    print(f"  Duration: {recording.get_total_duration():.2f} s")
    print(f"  Sampling rate: {recording.get_sampling_frequency()} Hz")

    return recording
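
# Illustrative usage sketch. The session path and folder layout below are
# hypothetical; a typical SpikeGLX session folder contains the .ap.bin/.ap.meta
# pair that the auto-detection above looks for:
#
#   session1_g0/
#   └── session1_g0_imec0/
#       ├── session1_g0_t0.imec0.ap.bin
#       └── session1_g0_t0.imec0.ap.meta
#
# recording = load_recording('/data/session1_g0', stream_id='imec0.ap')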


def preprocess(
    recording: si.BaseRecording,
    apply_phase_shift: bool = True,
    freq_min: float = 400.,
) -> tuple:
    """
    Apply standard Neuropixels preprocessing.

    Following SpikeInterface recommendations:
    1. High-pass filter at 400 Hz (not 300)
    2. Detect and remove bad channels
    3. Phase shift (NP 1.0 only)
    4. Common median reference
    """
    print("Preprocessing...")

    # Step 1: High-pass filter
    rec = si.highpass_filter(recording, freq_min=freq_min)
    print(f"  Applied high-pass filter at {freq_min} Hz")

    # Step 2: Detect bad channels
    bad_channel_ids, channel_labels = si.detect_bad_channels(rec)
    if len(bad_channel_ids) > 0:
        print(f"  Detected {len(bad_channel_ids)} bad channels: {bad_channel_ids}")
        rec = rec.remove_channels(bad_channel_ids)
    else:
        print("  No bad channels detected")

    # Step 3: Phase shift (for Neuropixels 1.0)
    if apply_phase_shift:
        rec = si.phase_shift(rec)
        print("  Applied phase shift correction")

    # Step 4: Common median reference
    rec = si.common_reference(rec, operator='median', reference='global')
    print("  Applied common median reference")

    return rec, bad_channel_ids
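
# The preprocessing chain above is lazy: nothing is computed until traces are
# requested or the recording is saved. A quick sanity check (values are
# illustrative):
#
# rec, bad = preprocess(recording)
# traces = rec.get_traces(start_frame=0, end_frame=30_000)  # ~1 s at 30 kHz
# print(traces.shape)  # (n_samples, n_channels)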


def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
    """
    Detect peaks and check for drift before spike sorting.
    """
    print("Checking for drift...")

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.peak_localization import localize_peaks

    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    # Get noise levels
    noise_levels = si.get_noise_levels(recording, return_in_uV=False)

    # Detect peaks
    peaks = detect_peaks(
        recording,
        method='locally_exclusive',
        noise_levels=noise_levels,
        detect_threshold=5,
        radius_um=50.,
        **job_kwargs
    )
    print(f"  Detected {len(peaks)} peaks")

    # Localize peaks
    peak_locations = localize_peaks(
        recording, peaks,
        method='center_of_mass',
        **job_kwargs
    )

    # Save drift plot
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(12, 6))

    # Subsample for plotting
    n_plot = min(100000, len(peaks))
    idx = np.random.choice(len(peaks), n_plot, replace=False)

    ax.scatter(
        peaks['sample_index'][idx] / recording.get_sampling_frequency(),
        peak_locations['y'][idx],
        s=1, alpha=0.1, c='k'
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Depth (μm)')
    ax.set_title('Peak Activity (Check for Drift)')

    plt.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Saved drift plot to {output_folder}/drift_check.png")

    # Estimate drift magnitude
    y_positions = peak_locations['y']
    drift_estimate = np.percentile(y_positions, 95) - np.percentile(y_positions, 5)
    print(f"  Estimated drift range: {drift_estimate:.1f} μm")

    return {
        'peaks': peaks,
        'peak_locations': peak_locations,
        'drift_estimate': drift_estimate
    }
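
# Note: `drift_estimate` is a crude heuristic. The 5th-95th percentile spread
# of peak depths conflates true drift with the spatial spread of activity along
# the probe, so inspect drift_check.png before trusting it. Illustrative use
# (the 20 µm threshold is a pipeline convention, not a standard):
#
# info = check_drift(rec_preprocessed, '/tmp/out')
# if info['drift_estimate'] > 20:  # µm
#     print('Consider motion correction')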


def correct_motion(
    recording: si.BaseRecording,
    output_folder: str,
    preset: str = 'nonrigid_fast_and_accurate'
) -> si.BaseRecording:
    """Apply motion correction if needed."""
    print(f"Applying motion correction (preset: {preset})...")

    # With output_motion_info=True, si.correct_motion returns a
    # (recording, motion_info) tuple, so both values must be unpacked.
    rec_corrected, motion_info = si.correct_motion(
        recording,
        preset=preset,
        folder=f'{output_folder}/motion',
        output_motion_info=True,
        n_jobs=8,
        chunk_duration='1s',
        progress_bar=True
    )

    print("  Motion correction complete")
    return rec_corrected
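
# Because folder=... is passed above, the motion estimate is persisted and can
# be reloaded later without recomputation (assuming `si.load_motion_info` is
# exposed in your SpikeInterface version):
#
# motion_info = si.load_motion_info(f'{output_folder}/motion')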


def run_spike_sorting(
    recording: si.BaseRecording,
    output_folder: str,
    sorter: str = 'kilosort4'
) -> si.BaseSorting:
    """Run spike sorting."""
    print(f"Running spike sorting with {sorter}...")

    sorter_folder = f'{output_folder}/sorting_{sorter}'

    sorting = si.run_sorter(
        sorter,
        recording,
        folder=sorter_folder,
        verbose=True
    )

    print(f"  Found {len(sorting.unit_ids)} units")
    print(f"  Total spikes: {sorting.get_total_num_spikes()}")

    return sorting
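
# If Kilosort4 is not installed locally, SpikeInterface can run sorters inside
# containers. A sketch, assuming Docker is available on the machine:
#
# sorting = si.run_sorter('kilosort4', recording,
#                         folder=f'{output_folder}/sorting_ks4_docker',
#                         docker_image=True, verbose=True)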


def postprocess(
    sorting: si.BaseSorting,
    recording: si.BaseRecording,
    output_folder: str
) -> tuple:
    """Run post-processing and compute quality metrics."""
    print("Post-processing...")

    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    # Create analyzer
    analyzer = si.create_sorting_analyzer(
        sorting, recording,
        sparse=True,
        format='binary_folder',
        folder=f'{output_folder}/analyzer'
    )

    # Compute extensions (order matters: random_spikes -> waveforms -> templates)
    print("  Computing waveforms...")
    analyzer.compute('random_spikes', method='uniform', max_spikes_per_unit=500)
    analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0, **job_kwargs)
    analyzer.compute('templates', operators=['average', 'std'])
    analyzer.compute('noise_levels')

    print("  Computing spike features...")
    analyzer.compute('spike_amplitudes', **job_kwargs)
    analyzer.compute('correlograms', window_ms=100, bin_ms=1)
    analyzer.compute('unit_locations', method='monopolar_triangulation')
    analyzer.compute('template_similarity')

    print("  Computing quality metrics...")
    analyzer.compute('quality_metrics')

    qm = analyzer.get_extension('quality_metrics').get_data()

    return analyzer, qm
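
# The analyzer is persisted as a binary_folder, so later sessions can reload it
# without recomputation (API name per the SpikeInterface docs):
#
# analyzer = si.load_sorting_analyzer(f'{output_folder}/analyzer')
# qm = analyzer.get_extension('quality_metrics').get_data()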


def curate_units(qm, method: str = 'allen') -> dict:
    """
    Classify units based on quality metrics.

    Methods:
        'allen': Allen Institute defaults (more permissive)
        'ibl': IBL standards
        'strict': Strict single-unit criteria
    """
    print(f"Curating units (method: {method})...")

    if method not in ('allen', 'ibl', 'strict'):
        raise ValueError(f"Unknown curation method: {method}")

    labels = {}

    for unit_id in qm.index:
        row = qm.loc[unit_id]

        # Noise detection (universal)
        if row['snr'] < 1.5:
            labels[unit_id] = 'noise'
            continue

        if method == 'allen':
            # Allen Institute defaults
            if (row['presence_ratio'] > 0.9 and
                    row['isi_violations_ratio'] < 0.5 and
                    row['amplitude_cutoff'] < 0.1):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.5:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

        elif method == 'ibl':
            # IBL standards
            if (row['presence_ratio'] > 0.9 and
                    row['isi_violations_ratio'] < 0.1 and
                    row['amplitude_cutoff'] < 0.1 and
                    row['firing_rate'] > 0.1):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.1:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

        elif method == 'strict':
            # Strict single-unit criteria
            if (row['snr'] > 5 and
                    row['presence_ratio'] > 0.95 and
                    row['isi_violations_ratio'] < 0.01 and
                    row['amplitude_cutoff'] < 0.01):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.05:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

    # Summary
    from collections import Counter
    counts = Counter(labels.values())
    print(f"  Classification: {dict(counts)}")

    return labels
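
# A self-contained sanity check with fabricated metric values (for illustration
# only; real values come from postprocess() above):
#
# import pandas as pd
# qm_demo = pd.DataFrame({
#     'snr': [8.0, 0.9], 'presence_ratio': [0.99, 0.5],
#     'isi_violations_ratio': [0.2, 2.0], 'amplitude_cutoff': [0.05, 0.4],
#     'firing_rate': [5.0, 0.05],
# }, index=[0, 1])
# print(curate_units(qm_demo, method='allen'))  # expected: {0: 'good', 1: 'noise'}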


def export_results(
    analyzer,
    sorting,
    recording,
    labels: dict,
    output_folder: str
):
    """Export results to various formats."""
    print("Exporting results...")

    # Get good units. The Phy export below includes all units so curation can
    # be revisited; `sorting_good` is kept so a good-only export is easy to add.
    good_ids = [u for u, l in labels.items() if l == 'good']
    sorting_good = sorting.select_units(good_ids)

    # Export to Phy
    phy_folder = f'{output_folder}/phy_export'
    si.export_to_phy(analyzer, phy_folder,
                     compute_pc_features=True,
                     compute_amplitudes=True)
    print(f"  Phy export: {phy_folder}")

    # Generate report
    report_folder = f'{output_folder}/report'
    si.export_report(analyzer, report_folder, format='png')
    print(f"  Report: {report_folder}")

    # Save quality metrics
    qm = analyzer.get_extension('quality_metrics').get_data()
    qm.to_csv(f'{output_folder}/quality_metrics.csv')

    # Save labels
    with open(f'{output_folder}/unit_labels.json', 'w') as f:
        json.dump({str(k): v for k, v in labels.items()}, f, indent=2)

    # Save summary
    summary = {
        'total_units': len(sorting.unit_ids),
        'good_units': len(good_ids),
        'total_spikes': int(sorting.get_total_num_spikes()),
        'duration_s': float(recording.get_total_duration()),
        'n_channels': int(recording.get_num_channels()),
    }
    with open(f'{output_folder}/summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"  Summary: {summary}")
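
# After export, results can be inspected manually in Phy (assuming phy is
# installed, e.g. via `pip install phy`):
#
#   phy template-gui /path/to/output/phy_export/params.py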


def run_pipeline(
    data_path: str,
    output_path: str,
    sorter: str = 'kilosort4',
    stream_id: str = 'imec0.ap',
    apply_motion_correction: bool = True,
    curation_method: str = 'allen'
):
    """Run complete Neuropixels analysis pipeline."""
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # 1. Load data
    recording = load_recording(data_path, stream_id)

    # 2. Preprocess
    rec_preprocessed, bad_channels = preprocess(recording)

    # Save preprocessed recording to disk (materializes the lazy chain)
    preproc_folder = output_path / 'preprocessed'
    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
    rec_preprocessed = rec_preprocessed.save(
        folder=str(preproc_folder),
        format='binary',
        **job_kwargs
    )

    # 3. Check drift
    drift_info = check_drift(rec_preprocessed, str(output_path))

    # 4. Motion correction (if needed)
    if apply_motion_correction and drift_info['drift_estimate'] > 20:
        print(f"Estimated drift {drift_info['drift_estimate']:.1f} μm > 20 μm, applying motion correction...")
        rec_final = correct_motion(rec_preprocessed, str(output_path))
    else:
        print("Skipping motion correction (disabled or low drift)")
        rec_final = rec_preprocessed

    # 5. Spike sorting
    sorting = run_spike_sorting(rec_final, str(output_path), sorter)

    # 6. Post-processing
    analyzer, qm = postprocess(sorting, rec_final, str(output_path))

    # 7. Curation
    labels = curate_units(qm, method=curation_method)

    # 8. Export
    export_results(analyzer, sorting, rec_final, labels, str(output_path))

    print("\n" + "=" * 50)
    print("Pipeline complete!")
    print(f"Output directory: {output_path}")
    print("=" * 50)

    return analyzer, sorting, qm, labels
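
# run_pipeline can also be driven from a notebook instead of the CLI below
# (paths are hypothetical):
#
# analyzer, sorting, qm, labels = run_pipeline(
#     '/data/session1_g0', '/data/session1_out',
#     sorter='kilosort4', stream_id='imec0.ap', curation_method='ibl')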


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Neuropixels analysis pipeline (best practices)'
    )
    parser.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
    parser.add_argument('output_path', help='Output directory')
    parser.add_argument('--sorter', default='kilosort4',
                        choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
                        help='Spike sorter to use')
    parser.add_argument('--stream', default='imec0.ap', help='Stream ID (e.g. imec0.ap)')
    parser.add_argument('--no-motion-correction', action='store_true',
                        help='Skip motion correction')
    parser.add_argument('--curation', default='allen',
                        choices=['allen', 'ibl', 'strict'],
                        help='Curation method')

    args = parser.parse_args()

    run_pipeline(
        args.data_path,
        args.output_path,
        sorter=args.sorter,
        stream_id=args.stream,
        apply_motion_correction=not args.no_motion_correction,
        curation_method=args.curation
    )