mirror of
https://github.com/K-Dense-AI/claude-scientific-skills.git
synced 2026-03-27 07:09:27 +08:00
Add neuropixels-analysis skill for extracellular electrophysiology
Adds comprehensive toolkit for analyzing Neuropixels high-density neural recordings using SpikeInterface, Allen Institute, and IBL best practices. Features: - Data loading from SpikeGLX, Open Ephys, and NWB formats - Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection) - Motion/drift estimation and correction - Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5) - Quality metrics computation (SNR, ISI violations, presence ratio) - Automated curation using Allen/IBL criteria - AI-assisted visual curation for uncertain units - Export to Phy and NWB formats Supports Neuropixels 1.0 and 2.0 probes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Compute quality metrics and curate units.
|
||||
|
||||
Usage:
|
||||
python compute_metrics.py sorting/ preprocessed/ --output metrics/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
# Curation criteria presets
|
||||
CURATION_CRITERIA = {
|
||||
'allen': {
|
||||
'snr': 3.0,
|
||||
'isi_violations_ratio': 0.1,
|
||||
'presence_ratio': 0.9,
|
||||
'amplitude_cutoff': 0.1,
|
||||
},
|
||||
'ibl': {
|
||||
'snr': 4.0,
|
||||
'isi_violations_ratio': 0.5,
|
||||
'presence_ratio': 0.5,
|
||||
'amplitude_cutoff': None,
|
||||
},
|
||||
'strict': {
|
||||
'snr': 5.0,
|
||||
'isi_violations_ratio': 0.01,
|
||||
'presence_ratio': 0.95,
|
||||
'amplitude_cutoff': 0.05,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(
    sorting_path: str,
    recording_path: str,
    output_dir: str,
    curation_method: str = 'allen',
    n_jobs: int = -1,
):
    """Compute quality metrics and apply curation.

    Parameters
    ----------
    sorting_path : str
        Directory containing a saved sorting extractor (in a 'sorting' subfolder).
    recording_path : str
        Directory containing the preprocessed recording (in a 'preprocessed' subfolder).
    output_dir : str
        Destination for the analyzer, metrics CSV and curation labels.
    curation_method : str
        One of the CURATION_CRITERIA presets ('allen', 'ibl', 'strict');
        unknown names fall back to 'allen'.
    n_jobs : int
        Parallel jobs for metric computation (-1 = all cores).

    Returns
    -------
    (analyzer, metrics, labels) tuple.
    """
    print(f"Loading sorting from: {sorting_path}")
    sorting = si.load_extractor(Path(sorting_path) / 'sorting')

    print(f"Loading recording from: {recording_path}")
    recording = si.load_extractor(Path(recording_path) / 'preprocessed')

    print(f"Units: {len(sorting.unit_ids)}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Create analyzer
    print("Creating SortingAnalyzer...")
    analyzer = si.create_sorting_analyzer(
        sorting,
        recording,
        format='binary_folder',
        folder=output_path / 'analyzer',
        sparse=True,
    )

    # Compute extensions (random_spikes -> waveforms -> templates order matters)
    print("Computing waveforms...")
    analyzer.compute('random_spikes', max_spikes_per_unit=500)
    analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
    analyzer.compute('templates', operators=['average', 'std'])

    print("Computing additional extensions...")
    analyzer.compute('noise_levels')
    analyzer.compute('spike_amplitudes')
    analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
    analyzer.compute('unit_locations', method='monopolar_triangulation')

    # Compute quality metrics
    print("Computing quality metrics...")
    metrics = si.compute_quality_metrics(
        analyzer,
        metric_names=[
            'snr',
            'isi_violations_ratio',
            'presence_ratio',
            'amplitude_cutoff',
            'firing_rate',
            'amplitude_cv',
            'sliding_rp_violation',
        ],
        n_jobs=n_jobs,
    )

    # Save metrics
    metrics.to_csv(output_path / 'quality_metrics.csv')
    print(f"Saved metrics to: {output_path / 'quality_metrics.csv'}")

    # Apply curation
    criteria = CURATION_CRITERIA.get(curation_method, CURATION_CRITERIA['allen'])
    print(f"\nApplying {curation_method} curation criteria: {criteria}")

    def metric_value(row, name, default):
        # BUGFIX: metrics can be NaN (e.g. too few spikes to estimate them);
        # NaN compares False against any threshold, so units with missing
        # metrics previously slipped through every check and were labeled
        # 'good'. Treat NaN as the pessimistic default instead.
        value = row.get(name, default)
        return default if pd.isna(value) else value

    labels = {}
    for unit_id in metrics.index:
        row = metrics.loc[unit_id]

        # A unit is 'good' only if it passes every enabled criterion
        # (a criterion set to None in the preset is skipped).
        is_good = True

        if criteria.get('snr') and metric_value(row, 'snr', 0) < criteria['snr']:
            is_good = False

        if criteria.get('isi_violations_ratio') and metric_value(row, 'isi_violations_ratio', 1) > criteria['isi_violations_ratio']:
            is_good = False

        if criteria.get('presence_ratio') and metric_value(row, 'presence_ratio', 0) < criteria['presence_ratio']:
            is_good = False

        if criteria.get('amplitude_cutoff') and metric_value(row, 'amplitude_cutoff', 1) > criteria['amplitude_cutoff']:
            is_good = False

        # Classify: failed units with very low SNR are noise, the rest MUA.
        if is_good:
            labels[int(unit_id)] = 'good'
        elif metric_value(row, 'snr', 0) < 2:
            labels[int(unit_id)] = 'noise'
        else:
            labels[int(unit_id)] = 'mua'

    # Save labels
    with open(output_path / 'curation_labels.json', 'w') as f:
        json.dump(labels, f, indent=2)

    # Summary
    label_counts = {}
    for label in labels.values():
        label_counts[label] = label_counts.get(label, 0) + 1

    print(f"\nCuration summary:")
    print(f" Good: {label_counts.get('good', 0)}")
    print(f" MUA: {label_counts.get('mua', 0)}")
    print(f" Noise: {label_counts.get('noise', 0)}")
    print(f" Total: {len(labels)}")

    # Metrics summary
    print(f"\nMetrics summary:")
    for col in ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']:
        if col in metrics.columns:
            print(f" {col}: {metrics[col].median():.4f} (median)")

    return analyzer, metrics, labels
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run metric computation."""
    parser = argparse.ArgumentParser(description='Compute quality metrics')
    parser.add_argument('sorting', help='Path to sorting directory')
    parser.add_argument('recording', help='Path to preprocessed recording')
    parser.add_argument('--output', '-o', default='metrics/', help='Output directory')
    parser.add_argument('--curation', '-c', default='allen', choices=['allen', 'ibl', 'strict'])
    parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')

    ns = parser.parse_args()

    compute_metrics(
        ns.sorting,
        ns.recording,
        ns.output,
        curation_method=ns.curation,
        n_jobs=ns.n_jobs,
    )


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick exploration of Neuropixels recording.
|
||||
|
||||
Usage:
|
||||
python explore_recording.py /path/to/spikeglx/data
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import spikeinterface.full as si
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def explore_recording(data_path: str, stream_id: str = 'imec0.ap'):
    """Explore a Neuropixels recording.

    Prints basic recording info, probe details, bad-channel detection
    results and signal statistics over the first second of data, then
    returns the loaded recording.
    """

    def banner(title):
        # Section header used throughout the report.
        print("\n" + "=" * 50)
        print(title)
        print("=" * 50)

    print(f"Loading: {data_path}")
    recording = si.read_spikeglx(data_path, stream_id=stream_id)

    banner("RECORDING INFO")
    print(f"Channels: {recording.get_num_channels()}")
    print(f"Duration: {recording.get_total_duration():.2f} s ({recording.get_total_duration()/60:.2f} min)")
    print(f"Sampling rate: {recording.get_sampling_frequency()} Hz")
    print(f"Total samples: {recording.get_num_samples()}")

    # Probe info
    probe = recording.get_probe()
    print(f"\nProbe: {probe.manufacturer} {probe.model_name if hasattr(probe, 'model_name') else ''}")
    print(f"Probe shape: {probe.ndim}D")

    # Channel groups (one group per shank on multi-shank probes)
    if recording.get_channel_groups() is not None:
        shank_ids = np.unique(recording.get_channel_groups())
        print(f"Channel groups (shanks): {len(shank_ids)}")

    banner("BAD CHANNEL DETECTION")
    bad_ids, labels = si.detect_bad_channels(recording)
    if len(bad_ids) > 0:
        print(f"Bad channels found: {len(bad_ids)}")
        for ch, label in zip(bad_ids, labels):
            print(f" Channel {ch}: {label}")
    else:
        print("No bad channels detected")

    banner("SIGNAL STATISTICS")
    # Statistics over one second of raw data.
    one_second = int(recording.get_sampling_frequency())
    traces = recording.get_traces(start_frame=0, end_frame=one_second)

    for stat_name, stat_fn in (('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max)):
        print(f"Sample {stat_name}: {stat_fn(traces):.2f}")

    return recording
|
||||
|
||||
|
||||
def plot_probe(recording, output_path=None):
    """Plot the probe channel layout; save to *output_path* or show interactively."""
    figure, axis = plt.subplots(figsize=(4, 12))
    si.plot_probe_map(recording, ax=axis, with_channel_ids=False)
    axis.set_title('Probe Layout')

    if not output_path:
        plt.show()
    else:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_traces(recording, duration=1.0, output_path=None):
    """Plot raw traces from up to 20 evenly spaced channels, vertically offset."""
    fs = recording.get_sampling_frequency()
    num_frames = int(duration * fs)
    traces = recording.get_traces(start_frame=0, end_frame=num_frames)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Pick an evenly spaced subset of channels so the plot stays readable.
    total_channels = recording.get_num_channels()
    n_channels = min(20, total_channels)
    picked = np.linspace(0, total_channels - 1, n_channels, dtype=int)

    t = np.arange(num_frames) / fs

    for row, ch in enumerate(picked):
        # Stack traces vertically so they don't overlap.
        ax.plot(t, traces[:, ch] + row * 200, 'k', linewidth=0.5)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Channel (offset)')
    ax.set_title(f'Raw Traces ({n_channels} channels)')

    if not output_path:
        plt.show()
    else:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_power_spectrum(recording, output_path=None):
    """Plot the Welch power spectrum of the middle channel.

    Uses up to 10 s of data from a single representative channel.
    Dashed reference lines mark 300 Hz and 6000 Hz (the usual AP-band
    filter corners).
    """
    from scipy import signal

    # Get data from middle channel (representative site)
    mid_ch = recording.get_num_channels() // 2
    n_samples = min(int(10 * recording.get_sampling_frequency()), recording.get_num_samples())

    traces = recording.get_traces(
        start_frame=0,
        end_frame=n_samples,
        channel_ids=[recording.channel_ids[mid_ch]]
    ).flatten()

    fs = recording.get_sampling_frequency()

    # Compute power spectrum.
    # nperseg must not exceed the available samples (very short recordings).
    freqs, psd = signal.welch(traces, fs, nperseg=min(4096, len(traces)))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.semilogy(freqs, psd)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power Spectral Density')
    ax.set_title(f'Power Spectrum (Channel {mid_ch})')
    # BUGFIX: the 6000 Hz reference line was drawn outside the previous
    # fixed 0-5000 Hz x-limit and therefore never visible. Extend the
    # limit so both markers show, capped at the Nyquist frequency.
    ax.set_xlim(0, min(fs / 2, 6500))
    ax.axvline(300, color='r', linestyle='--', alpha=0.5, label='300 Hz')
    ax.axvline(6000, color='r', linestyle='--', alpha=0.5, label='6000 Hz')
    ax.legend()
    ax.grid(True, alpha=0.3)

    if output_path:
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {output_path}")
    else:
        plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI: explore a recording and optionally render diagnostic plots.
    parser = argparse.ArgumentParser(description='Explore Neuropixels recording')
    parser.add_argument('data_path', help='Path to SpikeGLX recording')
    parser.add_argument('--stream', default='imec0.ap', help='Stream ID')
    parser.add_argument('--plot', action='store_true', help='Generate plots')
    parser.add_argument('--output', default=None, help='Output directory for plots')

    args = parser.parse_args()

    recording = explore_recording(args.data_path, args.stream)

    if args.plot:
        if args.output:
            # Save all three figures into the output directory.
            import os
            os.makedirs(args.output, exist_ok=True)
            plot_probe(recording, f"{args.output}/probe_map.png")
            plot_traces(recording, output_path=f"{args.output}/raw_traces.png")
            plot_power_spectrum(recording, f"{args.output}/power_spectrum.png")
        else:
            # No output directory: show figures interactively.
            plot_probe(recording)
            plot_traces(recording)
            plot_power_spectrum(recording)
|
||||
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Export sorting results to Phy for manual curation.
|
||||
|
||||
Usage:
|
||||
python export_to_phy.py metrics/analyzer --output phy_export/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
from spikeinterface.exporters import export_to_phy
|
||||
|
||||
|
||||
def export_phy(
    analyzer_path: str,
    output_dir: str,
    copy_binary: bool = True,
    compute_amplitudes: bool = True,
    compute_pc_features: bool = True,
    n_jobs: int = -1,
):
    """Export a SortingAnalyzer to the Phy template-gui format.

    Missing extensions (spike amplitudes, principal components) are
    computed on demand before the export.
    """
    print(f"Loading analyzer from: {analyzer_path}")
    analyzer = si.load_sorting_analyzer(analyzer_path)

    print(f"Units: {len(analyzer.sorting.unit_ids)}")

    destination = Path(output_dir)

    # Ensure the extensions Phy needs are present.
    if compute_amplitudes and analyzer.get_extension('spike_amplitudes') is None:
        print("Computing spike amplitudes...")
        analyzer.compute('spike_amplitudes')

    if compute_pc_features and analyzer.get_extension('principal_components') is None:
        print("Computing principal components...")
        analyzer.compute('principal_components', n_components=5, mode='by_channel_local')

    print(f"Exporting to Phy: {destination}")
    export_to_phy(
        analyzer,
        output_folder=destination,
        copy_binary=copy_binary,
        compute_amplitudes=compute_amplitudes,
        compute_pc_features=compute_pc_features,
        n_jobs=n_jobs,
    )

    print("\nExport complete!")
    print(f"To open in Phy, run:")
    print(f" phy template-gui {destination / 'params.py'}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the Phy export script."""
    parser = argparse.ArgumentParser(description='Export to Phy')
    parser.add_argument('analyzer', help='Path to sorting analyzer')
    parser.add_argument('--output', '-o', default='phy_export/', help='Output directory')
    parser.add_argument('--no-binary', action='store_true', help='Skip copying binary file')
    parser.add_argument('--no-amplitudes', action='store_true', help='Skip amplitude computation')
    parser.add_argument('--no-pc', action='store_true', help='Skip PC feature computation')
    parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')

    ns = parser.parse_args()

    # The --no-* flags invert into positive booleans for export_phy.
    export_phy(
        ns.analyzer,
        ns.output,
        copy_binary=not ns.no_binary,
        compute_amplitudes=not ns.no_amplitudes,
        compute_pc_features=not ns.no_pc,
        n_jobs=ns.n_jobs,
    )


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,432 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neuropixels Data Analysis Pipeline (Best Practices Version)
|
||||
|
||||
Based on SpikeInterface, Allen Institute, and IBL recommendations.
|
||||
|
||||
Usage:
|
||||
python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output
|
||||
|
||||
References:
|
||||
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
|
||||
- https://github.com/AllenInstitute/ecephys_spike_sorting
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
import spikeinterface.full as si
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_recording(data_path: str, stream_name: str = 'imec0.ap') -> si.BaseRecording:
    """Load a SpikeGLX or Open Ephys recording, auto-detecting the format."""
    data_path = Path(data_path)

    # Format detection by marker files anywhere under the directory.
    looks_like_spikeglx = any(data_path.rglob('*.ap.bin')) or any(data_path.rglob('*.ap.meta'))

    if looks_like_spikeglx:
        streams, _ = si.get_neo_streams('spikeglx', data_path)
        print(f"Available streams: {streams}")
        recording = si.read_spikeglx(data_path, stream_name=stream_name)
    elif any(data_path.rglob('*.oebin')):
        recording = si.read_openephys(data_path)
    else:
        raise ValueError(f"Unknown format in {data_path}")

    print(f"Loaded recording:")
    print(f" Channels: {recording.get_num_channels()}")
    print(f" Duration: {recording.get_total_duration():.2f} s")
    print(f" Sampling rate: {recording.get_sampling_frequency()} Hz")

    return recording
|
||||
|
||||
|
||||
def preprocess(
    recording: si.BaseRecording,
    apply_phase_shift: bool = True,
    freq_min: float = 400.,
) -> tuple:
    """
    Apply standard Neuropixels preprocessing.

    Following SpikeInterface recommendations:
    1. High-pass filter at 400 Hz (not 300)
    2. Detect and remove bad channels
    3. Phase shift (NP 1.0 only)
    4. Common median reference

    Returns the (lazy) preprocessed recording and the ids of removed
    bad channels.
    """
    print("Preprocessing...")

    # 1) High-pass filter
    processed = si.highpass_filter(recording, freq_min=freq_min)
    print(f" Applied high-pass filter at {freq_min} Hz")

    # 2) Bad channel detection and removal
    bad_channel_ids, channel_labels = si.detect_bad_channels(processed)
    if len(bad_channel_ids) == 0:
        print(" No bad channels detected")
    else:
        print(f" Detected {len(bad_channel_ids)} bad channels: {bad_channel_ids}")
        processed = processed.remove_channels(bad_channel_ids)

    # 3) ADC phase-shift correction (Neuropixels 1.0)
    if apply_phase_shift:
        processed = si.phase_shift(processed)
        print(" Applied phase shift correction")

    # 4) Common median reference
    processed = si.common_reference(processed, operator='median', reference='global')
    print(" Applied common median reference")

    return processed, bad_channel_ids
|
||||
|
||||
|
||||
def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
    """
    Detect peaks and check for drift before spike sorting.

    Detects and localizes threshold crossings across the probe, saves a
    time-vs-depth scatter plot to ``<output_folder>/drift_check.png``,
    and returns the peaks, their locations and a scalar drift estimate
    (the 5th-95th percentile spread of peak depths, in the probe's
    y-axis units).
    """
    print("Checking for drift...")

    # Imported lazily: sortingcomponents is only needed for this step.
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.peak_localization import localize_peaks

    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    # Get noise levels (per-channel; used to scale the detection threshold)
    noise_levels = si.get_noise_levels(recording, return_in_uV=False)

    # Detect peaks
    peaks = detect_peaks(
        recording,
        method='locally_exclusive',
        noise_levels=noise_levels,
        detect_threshold=5,
        radius_um=50.,
        **job_kwargs
    )
    print(f" Detected {len(peaks)} peaks")

    # Localize peaks (center-of-mass gives each peak an estimated probe position)
    peak_locations = localize_peaks(
        recording, peaks,
        method='center_of_mass',
        **job_kwargs
    )

    # Save drift plot
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(12, 6))

    # Subsample for plotting (scatter of every peak would be too heavy)
    n_plot = min(100000, len(peaks))
    idx = np.random.choice(len(peaks), n_plot, replace=False)

    ax.scatter(
        peaks['sample_index'][idx] / recording.get_sampling_frequency(),
        peak_locations['y'][idx],
        s=1, alpha=0.1, c='k'
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Depth (μm)')
    ax.set_title('Peak Activity (Check for Drift)')

    plt.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved drift plot to {output_folder}/drift_check.png")

    # Estimate drift magnitude: 5th-95th percentile spread of peak depths.
    # NOTE(review): this measures the total spread of activity, which
    # includes the static extent of active sites as well as drift — it is
    # an upper bound on actual motion.
    y_positions = peak_locations['y']
    drift_estimate = np.percentile(y_positions, 95) - np.percentile(y_positions, 5)
    print(f" Estimated drift range: {drift_estimate:.1f} μm")

    return {
        'peaks': peaks,
        'peak_locations': peak_locations,
        'drift_estimate': drift_estimate
    }
|
||||
|
||||
|
||||
def correct_motion(
    recording: si.BaseRecording,
    output_folder: str,
    preset: str = 'nonrigid_fast_and_accurate'
) -> si.BaseRecording:
    """Apply motion correction if needed."""
    print(f"Applying motion correction (preset: {preset})...")

    # Collect keyword arguments once, then dispatch to SpikeInterface.
    motion_kwargs = dict(
        preset=preset,
        folder=f'{output_folder}/motion',
        output_motion_info=True,
        n_jobs=8,
        chunk_duration='1s',
        progress_bar=True,
    )
    corrected = si.correct_motion(recording, **motion_kwargs)

    print(" Motion correction complete")
    return corrected
|
||||
|
||||
|
||||
def run_spike_sorting(
    recording: si.BaseRecording,
    output_folder: str,
    sorter: str = 'kilosort4'
) -> si.BaseSorting:
    """Run spike sorting and report unit/spike counts."""
    print(f"Running spike sorting with {sorter}...")

    destination = f'{output_folder}/sorting_{sorter}'
    sorting = si.run_sorter(
        sorter,
        recording,
        output_folder=destination,
        verbose=True,
    )

    print(f" Found {len(sorting.unit_ids)} units")
    print(f" Total spikes: {sorting.get_total_num_spikes()}")
    return sorting
|
||||
|
||||
|
||||
def postprocess(
    sorting: si.BaseSorting,
    recording: si.BaseRecording,
    output_folder: str
) -> tuple:
    """Run post-processing and compute quality metrics.

    Creates a persistent SortingAnalyzer under ``<output_folder>/analyzer``,
    computes the extensions required for quality metrics, and returns
    ``(analyzer, qm)`` where ``qm`` is the quality-metrics table.
    """
    print("Post-processing...")

    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)

    # Create analyzer (sparse=True restricts waveforms to nearby channels)
    analyzer = si.create_sorting_analyzer(
        sorting, recording,
        sparse=True,
        format='binary_folder',
        folder=f'{output_folder}/analyzer'
    )

    # Compute extensions (order matters: random_spikes -> waveforms -> templates)
    print(" Computing waveforms...")
    analyzer.compute('random_spikes', method='uniform', max_spikes_per_unit=500)
    analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0, **job_kwargs)
    analyzer.compute('templates', operators=['average', 'std'])
    analyzer.compute('noise_levels')

    print(" Computing spike features...")
    analyzer.compute('spike_amplitudes', **job_kwargs)
    analyzer.compute('correlograms', window_ms=100, bin_ms=1)
    analyzer.compute('unit_locations', method='monopolar_triangulation')
    analyzer.compute('template_similarity')

    print(" Computing quality metrics...")
    analyzer.compute('quality_metrics')

    # Quality metrics come back as a table indexed by unit id.
    qm = analyzer.get_extension('quality_metrics').get_data()

    return analyzer, qm
|
||||
|
||||
|
||||
def curate_units(qm, method: str = 'allen') -> dict:
    """
    Classify units based on quality metrics.

    Parameters
    ----------
    qm : pandas.DataFrame
        Quality-metrics table indexed by unit id. Must contain the columns
        'snr', 'presence_ratio', 'isi_violations_ratio', 'amplitude_cutoff'
        (plus 'firing_rate' for the 'ibl' method).
    method : str
        'allen': Allen Institute defaults (more permissive)
        'ibl': IBL standards
        'strict': Strict single-unit criteria

    Returns
    -------
    dict mapping unit id -> 'good' | 'mua' | 'noise' | 'unsorted'.

    Raises
    ------
    ValueError
        If *method* is not one of the supported presets.
    """
    # BUGFIX: an unrecognized method previously fell through silently,
    # leaving every non-noise unit unlabeled. Fail fast instead.
    if method not in ('allen', 'ibl', 'strict'):
        raise ValueError(f"Unknown curation method: {method!r}")

    print(f"Curating units (method: {method})...")

    labels = {}

    for unit_id in qm.index:
        row = qm.loc[unit_id]

        # Noise detection (universal): very low SNR is never a real unit.
        if row['snr'] < 1.5:
            labels[unit_id] = 'noise'
            continue

        if method == 'allen':
            # Allen Institute defaults
            if (row['presence_ratio'] > 0.9 and
                    row['isi_violations_ratio'] < 0.5 and
                    row['amplitude_cutoff'] < 0.1):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.5:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

        elif method == 'ibl':
            # IBL standards
            if (row['presence_ratio'] > 0.9 and
                    row['isi_violations_ratio'] < 0.1 and
                    row['amplitude_cutoff'] < 0.1 and
                    row['firing_rate'] > 0.1):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.1:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

        else:  # 'strict' (validated above)
            # Strict single-unit criteria
            if (row['snr'] > 5 and
                    row['presence_ratio'] > 0.95 and
                    row['isi_violations_ratio'] < 0.01 and
                    row['amplitude_cutoff'] < 0.01):
                labels[unit_id] = 'good'
            elif row['isi_violations_ratio'] > 0.05:
                labels[unit_id] = 'mua'
            else:
                labels[unit_id] = 'unsorted'

    # Summary
    from collections import Counter
    counts = Counter(labels.values())
    print(f" Classification: {dict(counts)}")

    return labels
|
||||
|
||||
|
||||
def export_results(
    analyzer,
    sorting,
    recording,
    labels: dict,
    output_folder: str
):
    """Export results to various formats.

    Writes a Phy export, a PNG report, the quality-metrics CSV, the
    per-unit curation labels (JSON) and a small summary JSON into
    *output_folder*.

    Parameters
    ----------
    analyzer : SortingAnalyzer with 'quality_metrics' already computed.
    sorting : the full (uncurated) sorting.
    recording : the recording the sorting was computed from.
    labels : dict mapping unit id -> curation label ('good', 'mua', ...).
    output_folder : destination directory.
    """
    print("Exporting results...")

    # Units that passed curation (used for the summary below).
    # BUGFIX: removed a dead `sorting_good = sorting.select_units(good_ids)`
    # statement whose result was never used.
    good_ids = [u for u, l in labels.items() if l == 'good']

    # Export to Phy (all units are exported so labels can be reviewed in the GUI)
    phy_folder = f'{output_folder}/phy_export'
    si.export_to_phy(analyzer, phy_folder,
                     compute_pc_features=True,
                     compute_amplitudes=True)
    print(f" Phy export: {phy_folder}")

    # Generate report
    report_folder = f'{output_folder}/report'
    si.export_report(analyzer, report_folder, format='png')
    print(f" Report: {report_folder}")

    # Save quality metrics
    qm = analyzer.get_extension('quality_metrics').get_data()
    qm.to_csv(f'{output_folder}/quality_metrics.csv')

    # Save labels (JSON object keys must be strings)
    with open(f'{output_folder}/unit_labels.json', 'w') as f:
        json.dump({str(k): v for k, v in labels.items()}, f, indent=2)

    # Save summary
    summary = {
        'total_units': len(sorting.unit_ids),
        'good_units': len(good_ids),
        'total_spikes': int(sorting.get_total_num_spikes()),
        'duration_s': float(recording.get_total_duration()),
        'n_channels': int(recording.get_num_channels()),
    }
    with open(f'{output_folder}/summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print(f" Summary: {summary}")
|
||||
|
||||
|
||||
def run_pipeline(
    data_path: str,
    output_path: str,
    sorter: str = 'kilosort4',
    stream_name: str = 'imec0.ap',
    apply_motion_correction: bool = True,
    curation_method: str = 'allen'
):
    """Run complete Neuropixels analysis pipeline.

    Steps: load -> preprocess -> drift check -> (conditional) motion
    correction -> spike sorting -> post-processing -> curation -> export.

    Parameters
    ----------
    data_path : str
        SpikeGLX or Open Ephys recording directory.
    output_path : str
        Directory for all pipeline outputs (created if missing).
    sorter : str
        Sorter name passed to ``run_spike_sorting``.
    stream_name : str
        Stream for multi-probe SpikeGLX recordings.
    apply_motion_correction : bool
        If True, motion correction runs only when the estimated drift
        exceeds 20 μm.
    curation_method : str
        Preset passed to ``curate_units`` ('allen', 'ibl', 'strict').

    Returns
    -------
    (analyzer, sorting, qm, labels)
    """
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # 1. Load data
    recording = load_recording(data_path, stream_name)

    # 2. Preprocess
    rec_preprocessed, bad_channels = preprocess(recording)

    # Save preprocessed (materializes the lazy preprocessing chain to disk)
    preproc_folder = output_path / 'preprocessed'
    job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
    rec_preprocessed = rec_preprocessed.save(
        folder=str(preproc_folder),
        format='binary',
        **job_kwargs
    )

    # 3. Check drift
    drift_info = check_drift(rec_preprocessed, str(output_path))

    # 4. Motion correction (if needed)
    # NOTE(review): the 20 μm threshold appears chosen relative to probe
    # geometry — confirm against the lab's drift-tolerance guidelines.
    if apply_motion_correction and drift_info['drift_estimate'] > 20:
        print(f"Drift > 20 μm detected, applying motion correction...")
        rec_final = correct_motion(rec_preprocessed, str(output_path))
    else:
        print("Skipping motion correction (low drift)")
        rec_final = rec_preprocessed

    # 5. Spike sorting
    sorting = run_spike_sorting(rec_final, str(output_path), sorter)

    # 6. Post-processing
    analyzer, qm = postprocess(sorting, rec_final, str(output_path))

    # 7. Curation
    labels = curate_units(qm, method=curation_method)

    # 8. Export
    export_results(analyzer, sorting, rec_final, labels, str(output_path))

    print("\n" + "="*50)
    print("Pipeline complete!")
    print(f"Output directory: {output_path}")
    print("="*50)

    return analyzer, sorting, qm, labels
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI wrapper around run_pipeline.
    cli = argparse.ArgumentParser(
        description='Neuropixels analysis pipeline (best practices)'
    )
    cli.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
    cli.add_argument('output_path', help='Output directory')
    cli.add_argument(
        '--sorter',
        default='kilosort4',
        choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
        help='Spike sorter to use',
    )
    cli.add_argument('--stream', default='imec0.ap', help='Stream name')
    cli.add_argument(
        '--no-motion-correction',
        action='store_true',
        help='Skip motion correction',
    )
    cli.add_argument(
        '--curation',
        default='allen',
        choices=['allen', 'ibl', 'strict'],
        help='Curation method',
    )

    opts = cli.parse_args()

    run_pipeline(
        opts.data_path,
        opts.output_path,
        sorter=opts.sorter,
        stream_name=opts.stream,
        apply_motion_correction=not opts.no_motion_correction,
        curation_method=opts.curation,
    )
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Preprocess Neuropixels recording.
|
||||
|
||||
Usage:
|
||||
python preprocess_recording.py /path/to/data --output preprocessed/ --format spikeglx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
def preprocess_recording(
    input_path: str,
    output_dir: str,
    format: str = 'auto',
    stream_id: str = None,
    freq_min: float = 300,
    freq_max: float = 6000,
    phase_shift: bool = True,
    common_ref: bool = True,
    detect_bad: bool = True,
    n_jobs: int = -1,
):
    """Preprocess a Neuropixels recording.

    Loads the recording (SpikeGLX / Open Ephys / NWB, or auto-detected),
    applies bandpass filtering, optional ADC phase-shift correction, bad
    channel removal and common median referencing, then saves the result
    and the probe geometry under *output_dir*.

    Parameters
    ----------
    input_path : str
        Path to the raw recording.
    output_dir : str
        Destination folder; processed data goes in a 'preprocessed'
        subfolder and the probe layout in 'probe.json'.
    format : str
        'spikeglx', 'openephys', 'nwb' or 'auto' (default).
        NOTE: shadows the builtin name; kept for backward compatibility.
    stream_id : str, optional
        Stream for multi-probe recordings (defaults to 'imec0.ap').
    freq_min, freq_max : float
        Bandpass corner frequencies in Hz.
    phase_shift, common_ref, detect_bad : bool
        Enable/disable the corresponding preprocessing steps.
    n_jobs : int
        Parallel jobs when saving (-1 = all cores).

    Returns
    -------
    The preprocessed recording object.
    """
    print(f"Loading recording from: {input_path}")

    # Load recording
    if format == 'spikeglx' or (format == 'auto' and 'imec' in str(input_path).lower()):
        recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
    elif format == 'openephys':
        recording = si.read_openephys(input_path)
    elif format == 'nwb':
        recording = si.read_nwb(input_path)
    else:
        # Try auto-detection: SpikeGLX first, then a generic saved extractor.
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; catch Exception only.
        try:
            recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
        except Exception:
            recording = si.load_extractor(input_path)

    print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")

    # Preprocessing chain (lazy: nothing is computed until save)
    rec = recording

    # Bandpass filter
    print(f"Applying bandpass filter ({freq_min}-{freq_max} Hz)...")
    rec = si.bandpass_filter(rec, freq_min=freq_min, freq_max=freq_max)

    # Phase shift correction (for Neuropixels ADC)
    if phase_shift:
        print("Applying phase shift correction...")
        rec = si.phase_shift(rec)

    # Bad channel detection
    if detect_bad:
        print("Detecting bad channels...")
        bad_channel_ids, bad_labels = si.detect_bad_channels(rec)
        if len(bad_channel_ids) > 0:
            print(f" Removing {len(bad_channel_ids)} bad channels: {bad_channel_ids[:10]}...")
            rec = rec.remove_channels(bad_channel_ids)

    # Common median reference
    if common_ref:
        print("Applying common median reference...")
        rec = si.common_reference(rec, operator='median', reference='global')

    # Save preprocessed
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Saving preprocessed recording to: {output_path}")
    rec.save(folder=output_path / 'preprocessed', n_jobs=n_jobs)

    # Save probe info alongside the data
    probe = rec.get_probe()
    if probe is not None:
        from probeinterface import write_probeinterface
        write_probeinterface(output_path / 'probe.json', probe)

    print("Done!")
    print(f" Output channels: {rec.get_num_channels()}")
    print(f" Output duration: {rec.get_total_duration():.1f}s")

    return rec
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run preprocessing."""
    args = _build_parser().parse_args()
    preprocess_recording(
        args.input,
        args.output,
        format=args.format,
        stream_id=args.stream_id,
        freq_min=args.freq_min,
        freq_max=args.freq_max,
        phase_shift=not args.no_phase_shift,
        common_ref=not args.no_cmr,
        detect_bad=not args.no_bad_channel,
        n_jobs=args.n_jobs,
    )


def _build_parser():
    """Construct the argument parser for the preprocessing CLI."""
    parser = argparse.ArgumentParser(description='Preprocess Neuropixels recording')
    parser.add_argument('input', help='Path to input recording')
    parser.add_argument('--output', '-o', default='preprocessed/', help='Output directory')
    parser.add_argument('--format', '-f', default='auto',
                        choices=['auto', 'spikeglx', 'openephys', 'nwb'])
    parser.add_argument('--stream-id', default=None, help='Stream ID for multi-probe recordings')
    parser.add_argument('--freq-min', type=float, default=300, help='Highpass cutoff (Hz)')
    parser.add_argument('--freq-max', type=float, default=6000, help='Lowpass cutoff (Hz)')
    parser.add_argument('--no-phase-shift', action='store_true', help='Skip phase shift correction')
    parser.add_argument('--no-cmr', action='store_true', help='Skip common median reference')
    parser.add_argument('--no-bad-channel', action='store_true', help='Skip bad channel detection')
    parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
    return parser
|
||||
|
||||
|
||||
# Allow this module to be executed directly as a command-line script.
if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Run spike sorting on preprocessed recording.
|
||||
|
||||
Usage:
|
||||
python run_sorting.py preprocessed/ --sorter kilosort4 --output sorting/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
# Default parameters for each sorter.
# These are passed to si.run_sorter() and override the sorter's own defaults.
# The common theme is disabling any built-in preprocessing (CAR, filtering,
# whitening) that this pipeline's preprocessing step has already applied.
SORTER_DEFAULTS = {
    'kilosort4': {
        'batch_size': 30000,  # samples per processing batch
        'nblocks': 1,         # drift-correction blocks -- presumably 1 = rigid; confirm against KS4 docs
        'Th_learned': 8,      # NOTE(review): looks like KS4 spike-detection thresholds -- verify
        'Th_universal': 9,
    },
    'kilosort3': {
        'do_CAR': False,  # Already done in preprocessing
    },
    'spykingcircus2': {
        'apply_preprocessing': False,  # recording is already filtered/referenced upstream
    },
    'mountainsort5': {
        'filter': False,  # bandpass filtering already applied upstream
        'whiten': False,
    },
}
|
||||
|
||||
|
||||
def run_sorting(
    input_path: str,
    output_dir: str,
    sorter: str = 'kilosort4',
    sorter_params: dict = None,
    n_jobs: int = -1,
):
    """Run spike sorting on a preprocessed recording.

    Parameters
    ----------
    input_path : str
        Directory containing a ``preprocessed/`` folder saved by the
        preprocessing step (a SpikeInterface extractor folder).
    output_dir : str
        Directory where the sorter's working output and the saved sorting
        are written.
    sorter : str
        Sorter name understood by ``si.run_sorter``
        (see ``SORTER_DEFAULTS`` for the supported presets).
    sorter_params : dict, optional
        Extra sorter parameters; merged over (and taking precedence over)
        the preset defaults.
    n_jobs : int
        Number of parallel jobs for SpikeInterface operations
        (-1 = all available cores).

    Returns
    -------
    The resulting sorting extractor (also saved to ``output_dir/sorting``).
    """
    # Bug fix: n_jobs was previously accepted (and passed by the CLI) but
    # never used. Propagate it through SpikeInterface's global job kwargs so
    # parallelized steps (e.g. sorting.save below) honor it.
    si.set_global_job_kwargs(n_jobs=n_jobs)

    print(f"Loading preprocessed recording from: {input_path}")
    recording = si.load_extractor(Path(input_path) / 'preprocessed')

    print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")

    # Start from the preset defaults for this sorter; caller-supplied
    # parameters take precedence.
    params = SORTER_DEFAULTS.get(sorter, {}).copy()
    if sorter_params:
        params.update(sorter_params)

    print(f"Running {sorter} with params: {params}")

    output_path = Path(output_dir)

    # Run sorter (note: parameter is 'folder' not 'output_folder' in newer SpikeInterface)
    sorting = si.run_sorter(
        sorter,
        recording,
        folder=output_path / f'{sorter}_output',
        verbose=True,
        **params,
    )

    print(f"\nSorting complete!")
    print(f" Units found: {len(sorting.unit_ids)}")
    print(f" Total spikes: {sum(len(sorting.get_unit_spike_train(uid)) for uid in sorting.unit_ids)}")

    # Persist the sorting in SpikeInterface's own folder format so later
    # stages (metrics, curation, export) can reload it.
    sorting.save(folder=output_path / 'sorting')
    print(f" Saved to: {output_path / 'sorting'}")

    return sorting
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and launch spike sorting."""
    cli = argparse.ArgumentParser(description='Run spike sorting')
    cli.add_argument('input', help='Path to preprocessed recording')
    cli.add_argument('--output', '-o', default='sorting/', help='Output directory')
    cli.add_argument(
        '--sorter', '-s', default='kilosort4',
        choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
    )
    cli.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')

    opts = cli.parse_args()

    run_sorting(opts.input, opts.output, sorter=opts.sorter, n_jobs=opts.n_jobs)
|
||||
|
||||
|
||||
# Allow this module to be executed directly as a command-line script.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user