Add neuropixels-analysis skill for extracellular electrophysiology

Adds comprehensive toolkit for analyzing Neuropixels high-density neural
recordings using SpikeInterface, Allen Institute, and IBL best practices.

Features:
- Data loading from SpikeGLX, Open Ephys, and NWB formats
- Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection)
- Motion/drift estimation and correction
- Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5)
- Quality metrics computation (SNR, ISI violations, presence ratio)
- Automated curation using Allen/IBL criteria
- AI-assisted visual curation for uncertain units
- Export to Phy and NWB formats

Supports Neuropixels 1.0 and 2.0 probes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Robert
2025-12-17 11:06:28 -05:00
parent 4fb9c053f7
commit 312f18ae60
21 changed files with 5358 additions and 1 deletions

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python
"""
Compute quality metrics and curate units.
Usage:
python compute_metrics.py sorting/ preprocessed/ --output metrics/
"""
import argparse
from pathlib import Path
import json
import pandas as pd
import spikeinterface.full as si
# Curation criteria presets
CURATION_CRITERIA = {
'allen': {
'snr': 3.0,
'isi_violations_ratio': 0.1,
'presence_ratio': 0.9,
'amplitude_cutoff': 0.1,
},
'ibl': {
'snr': 4.0,
'isi_violations_ratio': 0.5,
'presence_ratio': 0.5,
'amplitude_cutoff': None,
},
'strict': {
'snr': 5.0,
'isi_violations_ratio': 0.01,
'presence_ratio': 0.95,
'amplitude_cutoff': 0.05,
},
}
def compute_metrics(
sorting_path: str,
recording_path: str,
output_dir: str,
curation_method: str = 'allen',
n_jobs: int = -1,
):
"""Compute quality metrics and apply curation."""
print(f"Loading sorting from: {sorting_path}")
sorting = si.load_extractor(Path(sorting_path) / 'sorting')
print(f"Loading recording from: {recording_path}")
recording = si.load_extractor(Path(recording_path) / 'preprocessed')
print(f"Units: {len(sorting.unit_ids)}")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Create analyzer
print("Creating SortingAnalyzer...")
analyzer = si.create_sorting_analyzer(
sorting,
recording,
format='binary_folder',
folder=output_path / 'analyzer',
sparse=True,
)
# Compute extensions
print("Computing waveforms...")
analyzer.compute('random_spikes', max_spikes_per_unit=500)
analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
analyzer.compute('templates', operators=['average', 'std'])
print("Computing additional extensions...")
analyzer.compute('noise_levels')
analyzer.compute('spike_amplitudes')
analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
analyzer.compute('unit_locations', method='monopolar_triangulation')
# Compute quality metrics
print("Computing quality metrics...")
metrics = si.compute_quality_metrics(
analyzer,
metric_names=[
'snr',
'isi_violations_ratio',
'presence_ratio',
'amplitude_cutoff',
'firing_rate',
'amplitude_cv',
'sliding_rp_violation',
],
n_jobs=n_jobs,
)
# Save metrics
metrics.to_csv(output_path / 'quality_metrics.csv')
print(f"Saved metrics to: {output_path / 'quality_metrics.csv'}")
# Apply curation
criteria = CURATION_CRITERIA.get(curation_method, CURATION_CRITERIA['allen'])
print(f"\nApplying {curation_method} curation criteria: {criteria}")
labels = {}
for unit_id in metrics.index:
row = metrics.loc[unit_id]
# Check each criterion
is_good = True
if criteria.get('snr') and row.get('snr', 0) < criteria['snr']:
is_good = False
if criteria.get('isi_violations_ratio') and row.get('isi_violations_ratio', 1) > criteria['isi_violations_ratio']:
is_good = False
if criteria.get('presence_ratio') and row.get('presence_ratio', 0) < criteria['presence_ratio']:
is_good = False
if criteria.get('amplitude_cutoff') and row.get('amplitude_cutoff', 1) > criteria['amplitude_cutoff']:
is_good = False
# Classify
if is_good:
labels[int(unit_id)] = 'good'
elif row.get('snr', 0) < 2:
labels[int(unit_id)] = 'noise'
else:
labels[int(unit_id)] = 'mua'
# Save labels
with open(output_path / 'curation_labels.json', 'w') as f:
json.dump(labels, f, indent=2)
# Summary
label_counts = {}
for label in labels.values():
label_counts[label] = label_counts.get(label, 0) + 1
print(f"\nCuration summary:")
print(f" Good: {label_counts.get('good', 0)}")
print(f" MUA: {label_counts.get('mua', 0)}")
print(f" Noise: {label_counts.get('noise', 0)}")
print(f" Total: {len(labels)}")
# Metrics summary
print(f"\nMetrics summary:")
for col in ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']:
if col in metrics.columns:
print(f" {col}: {metrics[col].median():.4f} (median)")
return analyzer, metrics, labels
def main():
parser = argparse.ArgumentParser(description='Compute quality metrics')
parser.add_argument('sorting', help='Path to sorting directory')
parser.add_argument('recording', help='Path to preprocessed recording')
parser.add_argument('--output', '-o', default='metrics/', help='Output directory')
parser.add_argument('--curation', '-c', default='allen',
choices=['allen', 'ibl', 'strict'])
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
args = parser.parse_args()
compute_metrics(
args.sorting,
args.recording,
args.output,
curation_method=args.curation,
n_jobs=args.n_jobs,
)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Quick exploration of Neuropixels recording.
Usage:
python explore_recording.py /path/to/spikeglx/data
"""
import argparse
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
def explore_recording(data_path: str, stream_id: str = 'imec0.ap'):
"""Explore a Neuropixels recording."""
print(f"Loading: {data_path}")
recording = si.read_spikeglx(data_path, stream_id=stream_id)
# Basic info
print("\n" + "="*50)
print("RECORDING INFO")
print("="*50)
print(f"Channels: {recording.get_num_channels()}")
print(f"Duration: {recording.get_total_duration():.2f} s ({recording.get_total_duration()/60:.2f} min)")
print(f"Sampling rate: {recording.get_sampling_frequency()} Hz")
print(f"Total samples: {recording.get_num_samples()}")
# Probe info
probe = recording.get_probe()
print(f"\nProbe: {probe.manufacturer} {probe.model_name if hasattr(probe, 'model_name') else ''}")
print(f"Probe shape: {probe.ndim}D")
# Channel groups
if recording.get_channel_groups() is not None:
groups = np.unique(recording.get_channel_groups())
print(f"Channel groups (shanks): {len(groups)}")
# Check for bad channels
print("\n" + "="*50)
print("BAD CHANNEL DETECTION")
print("="*50)
bad_ids, labels = si.detect_bad_channels(recording)
if len(bad_ids) > 0:
print(f"Bad channels found: {len(bad_ids)}")
for ch, label in zip(bad_ids, labels):
print(f" Channel {ch}: {label}")
else:
print("No bad channels detected")
# Sample traces
print("\n" + "="*50)
print("SIGNAL STATISTICS")
print("="*50)
# Get 1 second of data
n_samples = int(recording.get_sampling_frequency())
traces = recording.get_traces(start_frame=0, end_frame=n_samples)
print(f"Sample mean: {np.mean(traces):.2f}")
print(f"Sample std: {np.std(traces):.2f}")
print(f"Sample min: {np.min(traces):.2f}")
print(f"Sample max: {np.max(traces):.2f}")
return recording
def plot_probe(recording, output_path=None):
"""Plot probe layout."""
fig, ax = plt.subplots(figsize=(4, 12))
si.plot_probe_map(recording, ax=ax, with_channel_ids=False)
ax.set_title('Probe Layout')
if output_path:
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"Saved: {output_path}")
else:
plt.show()
def plot_traces(recording, duration=1.0, output_path=None):
"""Plot raw traces."""
n_samples = int(duration * recording.get_sampling_frequency())
traces = recording.get_traces(start_frame=0, end_frame=n_samples)
fig, ax = plt.subplots(figsize=(12, 8))
# Plot subset of channels
n_channels = min(20, recording.get_num_channels())
channel_idx = np.linspace(0, recording.get_num_channels()-1, n_channels, dtype=int)
time = np.arange(n_samples) / recording.get_sampling_frequency()
for i, ch in enumerate(channel_idx):
offset = i * 200 # Offset for visibility
ax.plot(time, traces[:, ch] + offset, 'k', linewidth=0.5)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Channel (offset)')
ax.set_title(f'Raw Traces ({n_channels} channels)')
if output_path:
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"Saved: {output_path}")
else:
plt.show()
def plot_power_spectrum(recording, output_path=None):
"""Plot power spectrum."""
from scipy import signal
# Get data from middle channel
mid_ch = recording.get_num_channels() // 2
n_samples = min(int(10 * recording.get_sampling_frequency()), recording.get_num_samples())
traces = recording.get_traces(
start_frame=0,
end_frame=n_samples,
channel_ids=[recording.channel_ids[mid_ch]]
).flatten()
fs = recording.get_sampling_frequency()
# Compute power spectrum
freqs, psd = signal.welch(traces, fs, nperseg=4096)
fig, ax = plt.subplots(figsize=(10, 5))
ax.semilogy(freqs, psd)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power Spectral Density')
ax.set_title(f'Power Spectrum (Channel {mid_ch})')
ax.set_xlim(0, 5000)
ax.axvline(300, color='r', linestyle='--', alpha=0.5, label='300 Hz')
ax.axvline(6000, color='r', linestyle='--', alpha=0.5, label='6000 Hz')
ax.legend()
ax.grid(True, alpha=0.3)
if output_path:
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"Saved: {output_path}")
else:
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Explore Neuropixels recording')
parser.add_argument('data_path', help='Path to SpikeGLX recording')
parser.add_argument('--stream', default='imec0.ap', help='Stream ID')
parser.add_argument('--plot', action='store_true', help='Generate plots')
parser.add_argument('--output', default=None, help='Output directory for plots')
args = parser.parse_args()
recording = explore_recording(args.data_path, args.stream)
if args.plot:
import os
if args.output:
os.makedirs(args.output, exist_ok=True)
plot_probe(recording, f"{args.output}/probe_map.png")
plot_traces(recording, output_path=f"{args.output}/raw_traces.png")
plot_power_spectrum(recording, f"{args.output}/power_spectrum.png")
else:
plot_probe(recording)
plot_traces(recording)
plot_power_spectrum(recording)

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python
"""
Export sorting results to Phy for manual curation.
Usage:
python export_to_phy.py metrics/analyzer --output phy_export/
"""
import argparse
from pathlib import Path
import spikeinterface.full as si
from spikeinterface.exporters import export_to_phy
def export_phy(
analyzer_path: str,
output_dir: str,
copy_binary: bool = True,
compute_amplitudes: bool = True,
compute_pc_features: bool = True,
n_jobs: int = -1,
):
"""Export to Phy format."""
print(f"Loading analyzer from: {analyzer_path}")
analyzer = si.load_sorting_analyzer(analyzer_path)
print(f"Units: {len(analyzer.sorting.unit_ids)}")
output_path = Path(output_dir)
# Compute required extensions if missing
if compute_amplitudes and analyzer.get_extension('spike_amplitudes') is None:
print("Computing spike amplitudes...")
analyzer.compute('spike_amplitudes')
if compute_pc_features and analyzer.get_extension('principal_components') is None:
print("Computing principal components...")
analyzer.compute('principal_components', n_components=5, mode='by_channel_local')
print(f"Exporting to Phy: {output_path}")
export_to_phy(
analyzer,
output_folder=output_path,
copy_binary=copy_binary,
compute_amplitudes=compute_amplitudes,
compute_pc_features=compute_pc_features,
n_jobs=n_jobs,
)
print("\nExport complete!")
print(f"To open in Phy, run:")
print(f" phy template-gui {output_path / 'params.py'}")
def main():
parser = argparse.ArgumentParser(description='Export to Phy')
parser.add_argument('analyzer', help='Path to sorting analyzer')
parser.add_argument('--output', '-o', default='phy_export/', help='Output directory')
parser.add_argument('--no-binary', action='store_true', help='Skip copying binary file')
parser.add_argument('--no-amplitudes', action='store_true', help='Skip amplitude computation')
parser.add_argument('--no-pc', action='store_true', help='Skip PC feature computation')
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
args = parser.parse_args()
export_phy(
args.analyzer,
args.output,
copy_binary=not args.no_binary,
compute_amplitudes=not args.no_amplitudes,
compute_pc_features=not args.no_pc,
n_jobs=args.n_jobs,
)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,432 @@
#!/usr/bin/env python3
"""
Neuropixels Data Analysis Pipeline (Best Practices Version)
Based on SpikeInterface, Allen Institute, and IBL recommendations.
Usage:
python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output
References:
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
- https://github.com/AllenInstitute/ecephys_spike_sorting
"""
import argparse
from pathlib import Path
import json
import spikeinterface.full as si
import numpy as np
def load_recording(data_path: str, stream_name: str = 'imec0.ap') -> si.BaseRecording:
"""Load a SpikeGLX or Open Ephys recording."""
data_path = Path(data_path)
# Auto-detect format
if any(data_path.rglob('*.ap.bin')) or any(data_path.rglob('*.ap.meta')):
# SpikeGLX format
streams, _ = si.get_neo_streams('spikeglx', data_path)
print(f"Available streams: {streams}")
recording = si.read_spikeglx(data_path, stream_name=stream_name)
elif any(data_path.rglob('*.oebin')):
# Open Ephys format
recording = si.read_openephys(data_path)
else:
raise ValueError(f"Unknown format in {data_path}")
print(f"Loaded recording:")
print(f" Channels: {recording.get_num_channels()}")
print(f" Duration: {recording.get_total_duration():.2f} s")
print(f" Sampling rate: {recording.get_sampling_frequency()} Hz")
return recording
def preprocess(
recording: si.BaseRecording,
apply_phase_shift: bool = True,
freq_min: float = 400.,
) -> tuple:
"""
Apply standard Neuropixels preprocessing.
Following SpikeInterface recommendations:
1. High-pass filter at 400 Hz (not 300)
2. Detect and remove bad channels
3. Phase shift (NP 1.0 only)
4. Common median reference
"""
print("Preprocessing...")
# Step 1: High-pass filter
rec = si.highpass_filter(recording, freq_min=freq_min)
print(f" Applied high-pass filter at {freq_min} Hz")
# Step 2: Detect bad channels
bad_channel_ids, channel_labels = si.detect_bad_channels(rec)
if len(bad_channel_ids) > 0:
print(f" Detected {len(bad_channel_ids)} bad channels: {bad_channel_ids}")
rec = rec.remove_channels(bad_channel_ids)
else:
print(" No bad channels detected")
# Step 3: Phase shift (for Neuropixels 1.0)
if apply_phase_shift:
rec = si.phase_shift(rec)
print(" Applied phase shift correction")
# Step 4: Common median reference
rec = si.common_reference(rec, operator='median', reference='global')
print(" Applied common median reference")
return rec, bad_channel_ids
def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
"""
Detect peaks and check for drift before spike sorting.
"""
print("Checking for drift...")
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
# Get noise levels
noise_levels = si.get_noise_levels(recording, return_in_uV=False)
# Detect peaks
peaks = detect_peaks(
recording,
method='locally_exclusive',
noise_levels=noise_levels,
detect_threshold=5,
radius_um=50.,
**job_kwargs
)
print(f" Detected {len(peaks)} peaks")
# Localize peaks
peak_locations = localize_peaks(
recording, peaks,
method='center_of_mass',
**job_kwargs
)
# Save drift plot
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 6))
# Subsample for plotting
n_plot = min(100000, len(peaks))
idx = np.random.choice(len(peaks), n_plot, replace=False)
ax.scatter(
peaks['sample_index'][idx] / recording.get_sampling_frequency(),
peak_locations['y'][idx],
s=1, alpha=0.1, c='k'
)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Depth (μm)')
ax.set_title('Peak Activity (Check for Drift)')
plt.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
plt.close()
print(f" Saved drift plot to {output_folder}/drift_check.png")
# Estimate drift magnitude
y_positions = peak_locations['y']
drift_estimate = np.percentile(y_positions, 95) - np.percentile(y_positions, 5)
print(f" Estimated drift range: {drift_estimate:.1f} μm")
return {
'peaks': peaks,
'peak_locations': peak_locations,
'drift_estimate': drift_estimate
}
def correct_motion(
recording: si.BaseRecording,
output_folder: str,
preset: str = 'nonrigid_fast_and_accurate'
) -> si.BaseRecording:
"""Apply motion correction if needed."""
print(f"Applying motion correction (preset: {preset})...")
rec_corrected = si.correct_motion(
recording,
preset=preset,
folder=f'{output_folder}/motion',
output_motion_info=True,
n_jobs=8,
chunk_duration='1s',
progress_bar=True
)
print(" Motion correction complete")
return rec_corrected
def run_spike_sorting(
recording: si.BaseRecording,
output_folder: str,
sorter: str = 'kilosort4'
) -> si.BaseSorting:
"""Run spike sorting."""
print(f"Running spike sorting with {sorter}...")
sorter_folder = f'{output_folder}/sorting_{sorter}'
sorting = si.run_sorter(
sorter,
recording,
output_folder=sorter_folder,
verbose=True
)
print(f" Found {len(sorting.unit_ids)} units")
print(f" Total spikes: {sorting.get_total_num_spikes()}")
return sorting
def postprocess(
sorting: si.BaseSorting,
recording: si.BaseRecording,
output_folder: str
) -> tuple:
"""Run post-processing and compute quality metrics."""
print("Post-processing...")
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
# Create analyzer
analyzer = si.create_sorting_analyzer(
sorting, recording,
sparse=True,
format='binary_folder',
folder=f'{output_folder}/analyzer'
)
# Compute extensions (order matters)
print(" Computing waveforms...")
analyzer.compute('random_spikes', method='uniform', max_spikes_per_unit=500)
analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0, **job_kwargs)
analyzer.compute('templates', operators=['average', 'std'])
analyzer.compute('noise_levels')
print(" Computing spike features...")
analyzer.compute('spike_amplitudes', **job_kwargs)
analyzer.compute('correlograms', window_ms=100, bin_ms=1)
analyzer.compute('unit_locations', method='monopolar_triangulation')
analyzer.compute('template_similarity')
print(" Computing quality metrics...")
analyzer.compute('quality_metrics')
qm = analyzer.get_extension('quality_metrics').get_data()
return analyzer, qm
def curate_units(qm, method: str = 'allen') -> dict:
"""
Classify units based on quality metrics.
Methods:
'allen': Allen Institute defaults (more permissive)
'ibl': IBL standards
'strict': Strict single-unit criteria
"""
print(f"Curating units (method: {method})...")
labels = {}
for unit_id in qm.index:
row = qm.loc[unit_id]
# Noise detection (universal)
if row['snr'] < 1.5:
labels[unit_id] = 'noise'
continue
if method == 'allen':
# Allen Institute defaults
if (row['presence_ratio'] > 0.9 and
row['isi_violations_ratio'] < 0.5 and
row['amplitude_cutoff'] < 0.1):
labels[unit_id] = 'good'
elif row['isi_violations_ratio'] > 0.5:
labels[unit_id] = 'mua'
else:
labels[unit_id] = 'unsorted'
elif method == 'ibl':
# IBL standards
if (row['presence_ratio'] > 0.9 and
row['isi_violations_ratio'] < 0.1 and
row['amplitude_cutoff'] < 0.1 and
row['firing_rate'] > 0.1):
labels[unit_id] = 'good'
elif row['isi_violations_ratio'] > 0.1:
labels[unit_id] = 'mua'
else:
labels[unit_id] = 'unsorted'
elif method == 'strict':
# Strict single-unit
if (row['snr'] > 5 and
row['presence_ratio'] > 0.95 and
row['isi_violations_ratio'] < 0.01 and
row['amplitude_cutoff'] < 0.01):
labels[unit_id] = 'good'
elif row['isi_violations_ratio'] > 0.05:
labels[unit_id] = 'mua'
else:
labels[unit_id] = 'unsorted'
# Summary
from collections import Counter
counts = Counter(labels.values())
print(f" Classification: {dict(counts)}")
return labels
def export_results(
analyzer,
sorting,
recording,
labels: dict,
output_folder: str
):
"""Export results to various formats."""
print("Exporting results...")
# Get good units
good_ids = [u for u, l in labels.items() if l == 'good']
sorting_good = sorting.select_units(good_ids)
# Export to Phy
phy_folder = f'{output_folder}/phy_export'
si.export_to_phy(analyzer, phy_folder,
compute_pc_features=True,
compute_amplitudes=True)
print(f" Phy export: {phy_folder}")
# Generate report
report_folder = f'{output_folder}/report'
si.export_report(analyzer, report_folder, format='png')
print(f" Report: {report_folder}")
# Save quality metrics
qm = analyzer.get_extension('quality_metrics').get_data()
qm.to_csv(f'{output_folder}/quality_metrics.csv')
# Save labels
with open(f'{output_folder}/unit_labels.json', 'w') as f:
json.dump({str(k): v for k, v in labels.items()}, f, indent=2)
# Save summary
summary = {
'total_units': len(sorting.unit_ids),
'good_units': len(good_ids),
'total_spikes': int(sorting.get_total_num_spikes()),
'duration_s': float(recording.get_total_duration()),
'n_channels': int(recording.get_num_channels()),
}
with open(f'{output_folder}/summary.json', 'w') as f:
json.dump(summary, f, indent=2)
print(f" Summary: {summary}")
def run_pipeline(
data_path: str,
output_path: str,
sorter: str = 'kilosort4',
stream_name: str = 'imec0.ap',
apply_motion_correction: bool = True,
curation_method: str = 'allen'
):
"""Run complete Neuropixels analysis pipeline."""
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
# 1. Load data
recording = load_recording(data_path, stream_name)
# 2. Preprocess
rec_preprocessed, bad_channels = preprocess(recording)
# Save preprocessed
preproc_folder = output_path / 'preprocessed'
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
rec_preprocessed = rec_preprocessed.save(
folder=str(preproc_folder),
format='binary',
**job_kwargs
)
# 3. Check drift
drift_info = check_drift(rec_preprocessed, str(output_path))
# 4. Motion correction (if needed)
if apply_motion_correction and drift_info['drift_estimate'] > 20:
print(f"Drift > 20 μm detected, applying motion correction...")
rec_final = correct_motion(rec_preprocessed, str(output_path))
else:
print("Skipping motion correction (low drift)")
rec_final = rec_preprocessed
# 5. Spike sorting
sorting = run_spike_sorting(rec_final, str(output_path), sorter)
# 6. Post-processing
analyzer, qm = postprocess(sorting, rec_final, str(output_path))
# 7. Curation
labels = curate_units(qm, method=curation_method)
# 8. Export
export_results(analyzer, sorting, rec_final, labels, str(output_path))
print("\n" + "="*50)
print("Pipeline complete!")
print(f"Output directory: {output_path}")
print("="*50)
return analyzer, sorting, qm, labels
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Neuropixels analysis pipeline (best practices)'
)
parser.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
parser.add_argument('output_path', help='Output directory')
parser.add_argument('--sorter', default='kilosort4',
choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
help='Spike sorter to use')
parser.add_argument('--stream', default='imec0.ap', help='Stream name')
parser.add_argument('--no-motion-correction', action='store_true',
help='Skip motion correction')
parser.add_argument('--curation', default='allen',
choices=['allen', 'ibl', 'strict'],
help='Curation method')
args = parser.parse_args()
run_pipeline(
args.data_path,
args.output_path,
sorter=args.sorter,
stream_name=args.stream,
apply_motion_correction=not args.no_motion_correction,
curation_method=args.curation
)

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python
"""
Preprocess Neuropixels recording.
Usage:
python preprocess_recording.py /path/to/data --output preprocessed/ --format spikeglx
"""
import argparse
from pathlib import Path
import spikeinterface.full as si
def preprocess_recording(
input_path: str,
output_dir: str,
format: str = 'auto',
stream_id: str = None,
freq_min: float = 300,
freq_max: float = 6000,
phase_shift: bool = True,
common_ref: bool = True,
detect_bad: bool = True,
n_jobs: int = -1,
):
"""Preprocess a Neuropixels recording."""
print(f"Loading recording from: {input_path}")
# Load recording
if format == 'spikeglx' or (format == 'auto' and 'imec' in str(input_path).lower()):
recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
elif format == 'openephys':
recording = si.read_openephys(input_path)
elif format == 'nwb':
recording = si.read_nwb(input_path)
else:
# Try auto-detection
try:
recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
except:
recording = si.load_extractor(input_path)
print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")
# Preprocessing chain
rec = recording
# Bandpass filter
print(f"Applying bandpass filter ({freq_min}-{freq_max} Hz)...")
rec = si.bandpass_filter(rec, freq_min=freq_min, freq_max=freq_max)
# Phase shift correction (for Neuropixels ADC)
if phase_shift:
print("Applying phase shift correction...")
rec = si.phase_shift(rec)
# Bad channel detection
if detect_bad:
print("Detecting bad channels...")
bad_channel_ids, bad_labels = si.detect_bad_channels(rec)
if len(bad_channel_ids) > 0:
print(f" Removing {len(bad_channel_ids)} bad channels: {bad_channel_ids[:10]}...")
rec = rec.remove_channels(bad_channel_ids)
# Common median reference
if common_ref:
print("Applying common median reference...")
rec = si.common_reference(rec, operator='median', reference='global')
# Save preprocessed
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
print(f"Saving preprocessed recording to: {output_path}")
rec.save(folder=output_path / 'preprocessed', n_jobs=n_jobs)
# Save probe info
probe = rec.get_probe()
if probe is not None:
from probeinterface import write_probeinterface
write_probeinterface(output_path / 'probe.json', probe)
print("Done!")
print(f" Output channels: {rec.get_num_channels()}")
print(f" Output duration: {rec.get_total_duration():.1f}s")
return rec
def main():
parser = argparse.ArgumentParser(description='Preprocess Neuropixels recording')
parser.add_argument('input', help='Path to input recording')
parser.add_argument('--output', '-o', default='preprocessed/', help='Output directory')
parser.add_argument('--format', '-f', default='auto', choices=['auto', 'spikeglx', 'openephys', 'nwb'])
parser.add_argument('--stream-id', default=None, help='Stream ID for multi-probe recordings')
parser.add_argument('--freq-min', type=float, default=300, help='Highpass cutoff (Hz)')
parser.add_argument('--freq-max', type=float, default=6000, help='Lowpass cutoff (Hz)')
parser.add_argument('--no-phase-shift', action='store_true', help='Skip phase shift correction')
parser.add_argument('--no-cmr', action='store_true', help='Skip common median reference')
parser.add_argument('--no-bad-channel', action='store_true', help='Skip bad channel detection')
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
args = parser.parse_args()
preprocess_recording(
args.input,
args.output,
format=args.format,
stream_id=args.stream_id,
freq_min=args.freq_min,
freq_max=args.freq_max,
phase_shift=not args.no_phase_shift,
common_ref=not args.no_cmr,
detect_bad=not args.no_bad_channel,
n_jobs=args.n_jobs,
)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python
"""
Run spike sorting on preprocessed recording.
Usage:
python run_sorting.py preprocessed/ --sorter kilosort4 --output sorting/
"""
import argparse
from pathlib import Path
import spikeinterface.full as si
# Default parameters for each sorter
SORTER_DEFAULTS = {
'kilosort4': {
'batch_size': 30000,
'nblocks': 1,
'Th_learned': 8,
'Th_universal': 9,
},
'kilosort3': {
'do_CAR': False, # Already done in preprocessing
},
'spykingcircus2': {
'apply_preprocessing': False,
},
'mountainsort5': {
'filter': False,
'whiten': False,
},
}
def run_sorting(
input_path: str,
output_dir: str,
sorter: str = 'kilosort4',
sorter_params: dict = None,
n_jobs: int = -1,
):
"""Run spike sorting."""
print(f"Loading preprocessed recording from: {input_path}")
recording = si.load_extractor(Path(input_path) / 'preprocessed')
print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")
# Get sorter parameters
params = SORTER_DEFAULTS.get(sorter, {}).copy()
if sorter_params:
params.update(sorter_params)
print(f"Running {sorter} with params: {params}")
output_path = Path(output_dir)
# Run sorter (note: parameter is 'folder' not 'output_folder' in newer SpikeInterface)
sorting = si.run_sorter(
sorter,
recording,
folder=output_path / f'{sorter}_output',
verbose=True,
**params,
)
print(f"\nSorting complete!")
print(f" Units found: {len(sorting.unit_ids)}")
print(f" Total spikes: {sum(len(sorting.get_unit_spike_train(uid)) for uid in sorting.unit_ids)}")
# Save sorting
sorting.save(folder=output_path / 'sorting')
print(f" Saved to: {output_path / 'sorting'}")
return sorting
def main():
parser = argparse.ArgumentParser(description='Run spike sorting')
parser.add_argument('input', help='Path to preprocessed recording')
parser.add_argument('--output', '-o', default='sorting/', help='Output directory')
parser.add_argument('--sorter', '-s', default='kilosort4',
choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'])
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
args = parser.parse_args()
run_sorting(
args.input,
args.output,
sorter=args.sorter,
n_jobs=args.n_jobs,
)
if __name__ == '__main__':
main()