Mirror of https://github.com/K-Dense-AI/claude-scientific-skills.git
Synced 2026-01-26 16:58:56 +08:00
Add neuropixels-analysis skill for extracellular electrophysiology
Adds comprehensive toolkit for analyzing Neuropixels high-density neural recordings using SpikeInterface, Allen Institute, and IBL best practices.

Features:
- Data loading from SpikeGLX, Open Ephys, and NWB formats
- Preprocessing pipelines (filtering, phase shift, CAR, bad channel detection)
- Motion/drift estimation and correction
- Spike sorting integration (Kilosort4, SpykingCircus2, Mountainsort5)
- Quality metrics computation (SNR, ISI violations, presence ratio)
- Automated curation using Allen/IBL criteria
- AI-assisted visual curation for uncertain units
- Export to Phy and NWB formats

Supports Neuropixels 1.0 and 2.0 probes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

scientific-skills/neuropixels-analysis/AI_CURATION.md (new file, 345 lines)

# AI-Assisted Curation Reference

Guide to using AI visual analysis for unit curation, inspired by SpikeAgent's approach.

## Overview

AI-assisted curation uses vision-language models to analyze spike sorting visualizations, providing expert-level quality assessments similar to human curators.

### Workflow

```
Traditional: Metrics → Threshold → Labels
AI-Enhanced: Metrics → AI Visual Analysis → Confidence Score → Labels
```

## Claude Code Integration

When using this skill within Claude Code, Claude can directly analyze waveform plots without requiring API setup. Simply:

1. Generate a unit report or plot
2. Ask Claude to analyze the visualization
3. Claude will provide expert-level curation decisions

Example workflow in Claude Code:

```python
# Generate plots for a unit
npa.plot_unit_summary(analyzer, unit_id=0, output='unit_0_summary.png')

# Then ask Claude: "Please analyze this unit's waveforms and autocorrelogram
# to determine if it's a well-isolated single unit, multi-unit activity, or noise"
```

Claude can assess:

- Waveform consistency and shape
- Refractory period violations from autocorrelograms
- Amplitude stability over time
- Overall unit isolation quality

## Quick Start

### Generate Unit Report

```python
import neuropixels_analysis as npa

# Create visual report for a unit
report = npa.generate_unit_report(analyzer, unit_id=0, output_dir='reports/')

# Report includes:
# - Waveforms, templates, autocorrelogram
# - Amplitudes over time, ISI histogram
# - Quality metrics summary
# - Base64-encoded image for API use
```

### AI Visual Analysis

```python
from anthropic import Anthropic

# Set up API client
client = Anthropic()

# Analyze a single unit
result = npa.analyze_unit_visually(
    analyzer,
    unit_id=0,
    api_client=client,
    model='claude-3-5-sonnet-20241022',
    task='quality_assessment'
)

print(f"Classification: {result['classification']}")
print(f"Reasoning: {result['reasoning']}")
```

### Batch Analysis

```python
# Analyze all units
results = npa.batch_visual_curation(
    analyzer,
    api_client=client,
    output_dir='ai_curation/',
    progress_callback=lambda i, n: print(f"Progress: {i}/{n}")
)

# Get labels
ai_labels = {uid: r['classification'] for uid, r in results.items()}
```

## Interactive Curation Session

For human-in-the-loop curation with AI assistance:

```python
# Create session
session = npa.CurationSession.create(
    analyzer,
    output_dir='curation_session/',
    sort_by_confidence=True  # Show uncertain units first
)

# Process units
while True:
    unit = session.current_unit()
    if unit is None:
        break

    print(f"Unit {unit.unit_id}:")
    print(f"  Auto: {unit.auto_classification} (conf: {unit.confidence:.2f})")

    # Generate report
    report = npa.generate_unit_report(analyzer, unit.unit_id)

    # Get AI opinion
    ai_result = npa.analyze_unit_visually(analyzer, unit.unit_id, api_client=client)
    session.set_ai_classification(unit.unit_id, ai_result['classification'])

    # Human decision
    decision = input("Decision (good/mua/noise/skip): ")
    if decision != 'skip':
        session.set_decision(unit.unit_id, decision)

    session.next_unit()

# Export results
labels = session.get_final_labels()
session.export_decisions('final_curation.csv')
```

## Analysis Tasks

### Quality Assessment (Default)

Analyzes waveform shape, refractory period, and amplitude stability.

```python
result = npa.analyze_unit_visually(analyzer, uid, task='quality_assessment')
# Returns: 'good', 'mua', or 'noise'
```

### Merge Candidate Detection

Determines if two units should be merged.

```python
result = npa.analyze_unit_visually(analyzer, uid, task='merge_candidate')
# Returns: 'merge' or 'keep_separate'
```

### Drift Assessment

Evaluates motion/drift in the recording.

```python
result = npa.analyze_unit_visually(analyzer, uid, task='drift_assessment')
# Returns drift magnitude and a correction recommendation
```

## Custom Prompts

Create custom analysis prompts:

```python
from neuropixels_analysis.ai_curation import create_curation_prompt

# Get base prompt
prompt = create_curation_prompt(
    task='quality_assessment',
    additional_context='Focus on waveform amplitude consistency'
)

# Or fully custom
custom_prompt = """
Analyze this unit and determine if it represents a fast-spiking interneuron.

Look for:
1. Narrow waveform (peak-to-trough < 0.5 ms)
2. High firing rate
3. Regular ISI distribution

Classify as: FSI (fast-spiking interneuron) or OTHER
"""

result = npa.analyze_unit_visually(
    analyzer, uid,
    api_client=client,
    custom_prompt=custom_prompt
)
```

## Combining AI with Metrics

Best practice: use both AI and quantitative metrics:

```python
def hybrid_curation(analyzer, metrics, api_client):
    """Combine metrics and AI for robust curation."""
    labels = {}

    for unit_id in metrics.index:
        row = metrics.loc[unit_id]

        # High confidence from metrics alone
        if row['snr'] > 10 and row['isi_violations_ratio'] < 0.001:
            labels[unit_id] = 'good'
            continue

        if row['snr'] < 1.5:
            labels[unit_id] = 'noise'
            continue

        # Uncertain cases: use AI
        result = npa.analyze_unit_visually(
            analyzer, unit_id, api_client=api_client
        )
        labels[unit_id] = result['classification']

    return labels
```

## Session Management

### Resume Session

```python
# Resume an interrupted session
session = npa.CurationSession.load('curation_session/20250101_120000/')

# Check progress
summary = session.get_summary()
print(f"Progress: {summary['progress_pct']:.1f}%")
print(f"Remaining: {summary['remaining']} units")

# Continue from where we left off
unit = session.current_unit()
```

### Navigate Session

```python
# Go to a specific unit
session.go_to_unit(42)

# Previous/next
session.prev_unit()
session.next_unit()

# Update decision
session.set_decision(42, 'good', notes='Clear refractory period')
```

### Export Results

```python
# Get final labels (priority: human > AI > auto)
labels = session.get_final_labels()

# Export detailed results
df = session.export_decisions('curation_results.csv')

# Summary
summary = session.get_summary()
print(f"Good: {summary['decisions'].get('good', 0)}")
print(f"MUA: {summary['decisions'].get('mua', 0)}")
print(f"Noise: {summary['decisions'].get('noise', 0)}")
```

## Visual Report Components

The generated report includes six panels:

| Panel | Content | What to Look For |
|-------|---------|------------------|
| Waveforms | Individual spike waveforms | Consistency, shape |
| Template | Mean ± std | Clean negative peak, physiological shape |
| Autocorrelogram | Spike timing | Gap at 0 ms (refractory period) |
| Amplitudes | Amplitude over time | Stability, no drift |
| ISI Histogram | Inter-spike intervals | Refractory gap < 1.5 ms |
| Metrics | Quality numbers | SNR, ISI violations, presence |

## API Support

Currently supported APIs:

| Provider | Client | Model Examples |
|----------|--------|----------------|
| Anthropic | `anthropic.Anthropic()` | claude-3-5-sonnet-20241022 |
| OpenAI | `openai.OpenAI()` | gpt-4-vision-preview |
| Google | `google.generativeai` | gemini-pro-vision |

### Anthropic Example

```python
from anthropic import Anthropic

client = Anthropic(api_key="your-api-key")
result = npa.analyze_unit_visually(analyzer, uid, api_client=client)
```
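
If you prefer to call the vision API directly rather than go through `analyze_unit_visually`, the sketch below shows the general shape of an Anthropic messages request with an attached image. It reads the PNG produced by `plot_unit_summary` earlier in this guide; the prompt text is only an illustration, not the prompt the skill actually uses.

```python
import base64
from anthropic import Anthropic

client = Anthropic()

# Read the unit summary figure generated earlier and base64-encode it
with open('unit_0_summary.png', 'rb') as f:
    image_b64 = base64.b64encode(f.read()).decode('utf-8')

# Send the image plus an instruction to the vision model
message = client.messages.create(
    model='claude-3-5-sonnet-20241022',
    max_tokens=512,
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'image',
             'source': {'type': 'base64', 'media_type': 'image/png', 'data': image_b64}},
            {'type': 'text',
             'text': 'Classify this unit as good, mua, or noise and explain your reasoning.'},
        ],
    }],
)
print(message.content[0].text)
```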

### OpenAI Example

```python
from openai import OpenAI

client = OpenAI(api_key="your-api-key")
result = npa.analyze_unit_visually(
    analyzer, uid,
    api_client=client,
    model='gpt-4-vision-preview'
)
```

## Best Practices

1. **Use AI for uncertain cases** - Don't waste API calls on obvious good/noise units
2. **Combine with metrics** - AI should supplement, not replace, quantitative measures
3. **Human oversight** - Review AI decisions, especially for important analyses
4. **Save sessions** - Always use CurationSession to track decisions
5. **Document reasoning** - Use the notes field to record decision rationale

## Cost Optimization

```python
# Only use AI for uncertain units
uncertain_units = metrics.query("""
    snr > 2 and snr < 8 and
    isi_violations_ratio > 0.001 and isi_violations_ratio < 0.1
""").index.tolist()

# Batch process only these
results = npa.batch_visual_curation(
    analyzer,
    unit_ids=uncertain_units,
    api_client=client
)
```

## References

- [SpikeAgent](https://github.com/SpikeAgent/SpikeAgent) - AI-powered spike sorting assistant
- [Anthropic Vision API](https://docs.anthropic.com/en/docs/vision)
- [GPT-4 Vision](https://platform.openai.com/docs/guides/vision)

scientific-skills/neuropixels-analysis/ANALYSIS.md (new file, 392 lines)

# Post-Processing & Analysis Reference

Comprehensive guide to quality metrics, visualization, and analysis of sorted Neuropixels data.

## Sorting Analyzer

The `SortingAnalyzer` is the central object for post-processing.

### Create Analyzer

```python
import spikeinterface.full as si

# Create analyzer
analyzer = si.create_sorting_analyzer(
    sorting,
    recording,
    sparse=True,              # Use sparse representation
    format='binary_folder',   # Storage format
    folder='analyzer_output'  # Save location
)
```

### Compute Extensions

```python
# Compute all standard extensions
analyzer.compute('random_spikes')         # Random spike selection
analyzer.compute('waveforms')             # Extract waveforms
analyzer.compute('templates')             # Compute templates
analyzer.compute('noise_levels')          # Noise estimation
analyzer.compute('principal_components')  # PCA
analyzer.compute('spike_amplitudes')      # Amplitude per spike
analyzer.compute('correlograms')          # Auto/cross correlograms
analyzer.compute('unit_locations')        # Unit locations
analyzer.compute('spike_locations')       # Per-spike locations
analyzer.compute('template_similarity')   # Template similarity matrix
analyzer.compute('quality_metrics')       # Quality metrics

# Or compute multiple at once
analyzer.compute([
    'random_spikes', 'waveforms', 'templates', 'noise_levels',
    'principal_components', 'spike_amplitudes', 'correlograms',
    'unit_locations', 'quality_metrics'
])
```

### Save and Load

```python
# Save
analyzer.save_as(folder='analyzer_saved', format='binary_folder')

# Load
analyzer = si.load_sorting_analyzer('analyzer_saved')
```

## Quality Metrics

### Compute Metrics

```python
analyzer.compute('quality_metrics')
qm = analyzer.get_extension('quality_metrics').get_data()
print(qm)
```

### Available Metrics

| Metric | Description | Good Values |
|--------|-------------|-------------|
| `snr` | Signal-to-noise ratio | > 5 |
| `isi_violations_ratio` | ISI violation ratio | < 0.01 (1%) |
| `isi_violations_count` | ISI violation count | Low |
| `presence_ratio` | Fraction of recording with spikes | > 0.9 |
| `firing_rate` | Spikes per second | 0.1-50 Hz |
| `amplitude_cutoff` | Estimated missed spikes | < 0.1 |
| `amplitude_median` | Median spike amplitude | - |
| `amplitude_cv` | Coefficient of variation | < 0.5 |
| `drift_ptp` | Peak-to-peak drift (μm) | < 40 |
| `drift_std` | Standard deviation of drift | < 10 |
| `drift_mad` | Median absolute deviation | < 10 |
| `sliding_rp_violation` | Sliding refractory period | < 0.05 |
| `sync_spike_2` | Synchrony with other units | < 0.5 |
| `isolation_distance` | Mahalanobis distance | > 20 |
| `l_ratio` | L-ratio (isolation) | < 0.1 |
| `d_prime` | Discriminability | > 5 |
| `nn_hit_rate` | Nearest neighbor hit rate | > 0.9 |
| `nn_miss_rate` | Nearest neighbor miss rate | < 0.1 |
| `silhouette_score` | Cluster silhouette | > 0.5 |

### Compute Specific Metrics

```python
analyzer.compute(
    'quality_metrics',
    metric_names=['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']
)
```

### Custom Quality Thresholds

```python
qm = analyzer.get_extension('quality_metrics').get_data()

# Define quality criteria
quality_criteria = {
    'snr': ('>', 5),
    'isi_violations_ratio': ('<', 0.01),
    'presence_ratio': ('>', 0.9),
    'firing_rate': ('>', 0.1),
    'amplitude_cutoff': ('<', 0.1),
}

# Filter good units
good_units = qm.query(
    "(snr > 5) & (isi_violations_ratio < 0.01) & (presence_ratio > 0.9)"
).index.tolist()

print(f"Good units: {len(good_units)}/{len(qm)}")
```

## Waveforms & Templates

### Extract Waveforms

```python
analyzer.compute('waveforms', ms_before=1.5, ms_after=2.5, max_spikes_per_unit=500)

# Get waveforms for a unit
waveforms = analyzer.get_extension('waveforms').get_waveforms(unit_id=0)
print(f"Shape: {waveforms.shape}")  # (n_spikes, n_samples, n_channels)
```

### Compute Templates

```python
analyzer.compute('templates', operators=['average', 'std', 'median'])

# Get template
templates_ext = analyzer.get_extension('templates')
template = templates_ext.get_unit_template(unit_id=0, operator='average')
```

### Template Similarity

```python
analyzer.compute('template_similarity')
sim = analyzer.get_extension('template_similarity').get_data()
# Matrix of cosine similarities between templates
```

## Unit Locations

### Compute Locations

```python
analyzer.compute('unit_locations', method='monopolar_triangulation')
locations = analyzer.get_extension('unit_locations').get_data()
print(locations)  # x, y coordinates per unit
```

### Spike Locations

```python
analyzer.compute('spike_locations', method='center_of_mass')
spike_locs = analyzer.get_extension('spike_locations').get_data()
```

### Location Methods

- `'center_of_mass'` - Fast, less accurate
- `'monopolar_triangulation'` - More accurate, slower
- `'grid_convolution'` - Good balance (see the comparison sketch below)
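
The sketch below compares the three estimators on the same analyzer; recomputing `unit_locations` simply overwrites the previous result, so this is a cheap sanity check (it assumes the waveform/template extensions above have already been computed).

```python
import numpy as np

# Compare unit-location estimates across methods
estimates = {}
for method in ['center_of_mass', 'grid_convolution', 'monopolar_triangulation']:
    analyzer.compute('unit_locations', method=method)
    estimates[method] = analyzer.get_extension('unit_locations').get_data()

# How far does the fast estimate deviate from the slow, more accurate one?
delta = np.linalg.norm(
    estimates['center_of_mass'][:, :2] - estimates['monopolar_triangulation'][:, :2],
    axis=1,
)
print(f"Median disagreement: {np.median(delta):.1f} um")
```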

## Correlograms

### Auto-correlograms

```python
analyzer.compute('correlograms', window_ms=50, bin_ms=1)
correlograms, bins = analyzer.get_extension('correlograms').get_data()

# correlograms shape: (n_units, n_units, n_bins)
# Auto-correlogram for unit i: correlograms[i, i, :]
# Cross-correlogram units i,j: correlograms[i, j, :]
```

## Visualization

### Probe Map

```python
si.plot_probe_map(recording, with_channel_ids=True)
```

### Unit Templates

```python
# All units
si.plot_unit_templates(analyzer)

# Specific units
si.plot_unit_templates(analyzer, unit_ids=[0, 1, 2])
```

### Waveforms

```python
# Plot waveforms with template
si.plot_unit_waveforms(analyzer, unit_ids=[0])

# Waveform density
si.plot_unit_waveforms_density_map(analyzer, unit_id=0)
```

### Raster Plot

```python
si.plot_rasters(sorting, time_range=(0, 10))  # First 10 seconds
```

### Amplitudes

```python
analyzer.compute('spike_amplitudes')
si.plot_amplitudes(analyzer)

# Distribution
si.plot_all_amplitudes_distributions(analyzer)
```

### Correlograms

```python
# Auto-correlograms
si.plot_autocorrelograms(analyzer, unit_ids=[0, 1, 2])

# Cross-correlograms
si.plot_crosscorrelograms(analyzer, unit_ids=[0, 1])
```

### Quality Metrics

```python
# Summary plot
si.plot_quality_metrics(analyzer)

# Specific metric distribution
import matplotlib.pyplot as plt
qm = analyzer.get_extension('quality_metrics').get_data()
plt.hist(qm['snr'], bins=50)
plt.xlabel('SNR')
plt.ylabel('Count')
```

### Unit Locations on Probe

```python
si.plot_unit_locations(analyzer)
```

### Drift Map

```python
si.plot_drift_raster(sorting, recording)
```

### Summary Plot

```python
# Comprehensive unit summary
si.plot_unit_summary(analyzer, unit_id=0)
```

## LFP Analysis

### Load LFP Data

```python
lfp = si.read_spikeglx('/path/to/data', stream_id='imec0.lf')
print(f"LFP: {lfp.get_sampling_frequency()} Hz")
```

### Basic LFP Processing

```python
# Downsample if needed
lfp_ds = si.resample(lfp, resample_rate=1000)

# Common average reference
lfp_car = si.common_reference(lfp_ds, reference='global', operator='median')
```

### Extract LFP Traces

```python
import numpy as np

# Get traces (channels x samples)
traces = lfp.get_traces(start_frame=0, end_frame=30000)

# Specific channels
traces = lfp.get_traces(channel_ids=[0, 1, 2])
```

### Spectral Analysis

```python
from scipy import signal
import matplotlib.pyplot as plt

# Get single channel
trace = lfp.get_traces(channel_ids=[0]).flatten()
fs = lfp.get_sampling_frequency()

# Power spectrum
freqs, psd = signal.welch(trace, fs, nperseg=4096)
plt.semilogy(freqs, psd)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power')
plt.xlim(0, 100)
```

### Spectrogram

```python
f, t, Sxx = signal.spectrogram(trace, fs, nperseg=2048, noverlap=1024)
plt.pcolormesh(t, f, 10*np.log10(Sxx), shading='gouraud')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.ylim(0, 100)
plt.colorbar(label='Power (dB)')
```

## Export Formats

### Export to Phy

```python
si.export_to_phy(
    analyzer,
    output_folder='phy_export',
    compute_pc_features=True,
    compute_amplitudes=True,
    copy_binary=True
)
# Then: phy template-gui phy_export/params.py
```

### Export to NWB

```python
from spikeinterface.exporters import export_to_nwb

export_to_nwb(
    recording,
    sorting,
    'output.nwb',
    metadata=dict(
        session_description='Neuropixels recording',
        experimenter='Name',
        lab='Lab name',
        institution='Institution'
    )
)
```

### Export Report

```python
si.export_report(
    analyzer,
    output_folder='report',
    remove_if_exists=True,
    format='html'
)
```

## Complete Analysis Pipeline

```python
import spikeinterface.full as si

def analyze_sorting(recording, sorting, output_dir):
    """Complete post-processing pipeline."""

    # Create analyzer
    analyzer = si.create_sorting_analyzer(
        sorting, recording,
        sparse=True,
        folder=f'{output_dir}/analyzer'
    )

    # Compute all extensions
    print("Computing extensions...")
    analyzer.compute(['random_spikes', 'waveforms', 'templates', 'noise_levels'])
    analyzer.compute(['principal_components', 'spike_amplitudes'])
    analyzer.compute(['correlograms', 'unit_locations', 'template_similarity'])
    analyzer.compute('quality_metrics')

    # Get quality metrics
    qm = analyzer.get_extension('quality_metrics').get_data()

    # Filter good units
    good_units = qm.query(
        "(snr > 5) & (isi_violations_ratio < 0.01) & (presence_ratio > 0.9)"
    ).index.tolist()

    print(f"Quality filtering: {len(good_units)}/{len(qm)} units passed")

    # Export
    si.export_to_phy(analyzer, f'{output_dir}/phy')
    si.export_report(analyzer, f'{output_dir}/report')

    # Save metrics
    qm.to_csv(f'{output_dir}/quality_metrics.csv')

    return analyzer, qm, good_units

# Usage
analyzer, qm, good_units = analyze_sorting(recording, sorting, 'output/')
```

scientific-skills/neuropixels-analysis/AUTOMATED_CURATION.md (new file, 358 lines)

# Automated Curation Reference

Guide to automated spike sorting curation using Bombcell, UnitRefine, and other tools.

## Why Automated Curation?

Manual curation is:

- **Slow**: Hours per recording session
- **Subjective**: Inter-rater variability
- **Non-reproducible**: Hard to standardize

Automated tools provide consistent, reproducible quality classification.

## Available Tools

| Tool | Classification | Language | Integration |
|------|---------------|----------|-------------|
| **Bombcell** | 4-class (single/multi/noise/non-somatic) | Python/MATLAB | SpikeInterface, Phy |
| **UnitRefine** | Machine learning-based | Python | SpikeInterface |
| **SpikeInterface QM** | Threshold-based | Python | Native |
| **UnitMatch** | Cross-session tracking | Python/MATLAB | Kilosort, Bombcell |

## Bombcell

### Overview

Bombcell classifies units into 4 categories:

1. **Single somatic units** - Well-isolated single neurons
2. **Multi-unit activity (MUA)** - Mixed neuronal signals
3. **Noise** - Non-neural artifacts
4. **Non-somatic** - Axonal or dendritic signals

### Installation

```bash
# Python
pip install bombcell

# Or development version
git clone https://github.com/Julie-Fabre/bombcell.git
cd bombcell/py_bombcell
pip install -e .
```

### Basic Usage (Python)

```python
import bombcell as bc

# Load sorted data (Kilosort output)
kilosort_folder = '/path/to/kilosort/output'
raw_data_path = '/path/to/recording.ap.bin'

# Run Bombcell
results = bc.run_bombcell(
    kilosort_folder,
    raw_data_path,
    sample_rate=30000,
    n_channels=384
)

# Get classifications
unit_labels = results['unit_labels']
# 'good' = single unit, 'mua' = multi-unit, 'noise' = noise
```

### Integration with SpikeInterface

```python
import spikeinterface.full as si

# After spike sorting
sorting = si.run_sorter('kilosort4', recording, output_folder='ks4/')

# Create analyzer and compute required extensions
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True)
analyzer.compute('waveforms')
analyzer.compute('templates')
analyzer.compute('spike_amplitudes')

# Export to Phy format (Bombcell can read this)
si.export_to_phy(analyzer, output_folder='phy_export/')

# Run Bombcell on the Phy export
import bombcell as bc
results = bc.run_bombcell_phy('phy_export/')
```

### Bombcell Metrics

Bombcell computes specific metrics for classification:

| Metric | Description | Used For |
|--------|-------------|----------|
| `peak_trough_ratio` | Waveform shape | Somatic vs non-somatic |
| `spatial_decay` | Amplitude across channels | Noise detection |
| `refractory_period_violations` | ISI violations | Single vs multi |
| `presence_ratio` | Temporal stability | Unit quality |
| `waveform_duration` | Peak-to-trough time | Cell type |
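
Once Bombcell has run, its labels can be folded back into a SpikeInterface workflow. A minimal sketch, assuming `results['unit_labels']` maps unit IDs to label strings as in the basic-usage example above (adjust the key to whatever the actual output structure contains):

```python
from collections import Counter

# Assumption: results['unit_labels'] is {unit_id: 'good' | 'mua' | 'noise' | ...}
unit_labels = results['unit_labels']
print(Counter(unit_labels.values()))

# Keep only well-isolated units for downstream analysis
good_ids = [uid for uid, label in unit_labels.items() if label == 'good']
sorting_good = sorting.select_units(good_ids)

# Attach the labels to the sorting so they travel with later exports
sorting.set_property('bombcell_label', [unit_labels[uid] for uid in sorting.unit_ids])
```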

### Custom Thresholds

```python
# Customize classification thresholds
custom_params = {
    'isi_threshold': 0.01,           # ISI violation threshold
    'presence_threshold': 0.9,       # Minimum presence ratio
    'amplitude_threshold': 20,       # Minimum amplitude (μV)
    'spatial_decay_threshold': 40,   # Spatial decay (μm)
}

results = bc.run_bombcell(
    kilosort_folder,
    raw_data_path,
    **custom_params
)
```

## SpikeInterface Auto-Curation

### Threshold-Based Curation

```python
# Compute quality metrics
analyzer.compute('quality_metrics')
qm = analyzer.get_extension('quality_metrics').get_data()

# Define curation function
def auto_curate(qm):
    labels = {}
    for unit_id in qm.index:
        row = qm.loc[unit_id]

        # Classification logic
        if row['snr'] < 2 or row['presence_ratio'] < 0.5:
            labels[unit_id] = 'noise'
        elif row['isi_violations_ratio'] > 0.1:
            labels[unit_id] = 'mua'
        elif (row['snr'] > 5 and
              row['isi_violations_ratio'] < 0.01 and
              row['presence_ratio'] > 0.9):
            labels[unit_id] = 'good'
        else:
            labels[unit_id] = 'unsorted'

    return labels

unit_labels = auto_curate(qm)

# Filter by label
good_unit_ids = [u for u, l in unit_labels.items() if l == 'good']
sorting_curated = sorting.select_units(good_unit_ids)
```

### Using SpikeInterface Curation Module

```python
from spikeinterface.curation import (
    CurationSorting,
    MergeUnitsSorting,
    SplitUnitSorting
)

# Wrap sorting for curation
curation = CurationSorting(sorting)

# Remove noise units
noise_units = qm[qm['snr'] < 2].index.tolist()
curation.remove_units(noise_units)

# Merge similar units (based on template similarity)
analyzer.compute('template_similarity')
similarity = analyzer.get_extension('template_similarity').get_data()

# Find highly similar pairs
import numpy as np
threshold = 0.9
similar_pairs = np.argwhere(similarity > threshold)
# Merge pairs (careful - requires manual review)

# Get curated sorting
sorting_curated = curation.to_sorting()
```

## UnitMatch: Cross-Session Tracking

Track the same neurons across recording days.

### Installation

```bash
pip install unitmatch
# Or from source
git clone https://github.com/EnnyvanBeest/UnitMatch.git
```

### Usage

```python
# After running Bombcell on multiple sessions
session_folders = [
    '/path/to/session1/kilosort/',
    '/path/to/session2/kilosort/',
    '/path/to/session3/kilosort/',
]

from unitmatch import UnitMatch

# Run UnitMatch
um = UnitMatch(session_folders)
um.run()

# Get matching results
matches = um.get_matches()
# Returns DataFrame with unit IDs matched across sessions

# Assign unique IDs
unique_ids = um.get_unique_ids()
```

### Integration with Workflow

```python
# Typical workflow:
# 1. Spike sort each session
# 2. Run Bombcell for quality control
# 3. Run UnitMatch for cross-session tracking

# Session 1
sorting1 = si.run_sorter('kilosort4', rec1, output_folder='session1/ks4/')
# Run Bombcell
labels1 = bc.run_bombcell('session1/ks4/', raw1_path)

# Session 2
sorting2 = si.run_sorter('kilosort4', rec2, output_folder='session2/ks4/')
labels2 = bc.run_bombcell('session2/ks4/', raw2_path)

# Track units across sessions
um = UnitMatch(['session1/ks4/', 'session2/ks4/'])
matches = um.get_matches()
```

## Semi-Automated Workflow

Combine automated and manual curation:

```python
# Step 1: Automated classification
analyzer.compute('quality_metrics')
qm = analyzer.get_extension('quality_metrics').get_data()

# Auto-label obvious cases
auto_labels = {}
for unit_id in qm.index:
    row = qm.loc[unit_id]
    if row['snr'] < 1.5:
        auto_labels[unit_id] = 'noise'
    elif row['snr'] > 8 and row['isi_violations_ratio'] < 0.005:
        auto_labels[unit_id] = 'good'
    else:
        auto_labels[unit_id] = 'needs_review'

# Step 2: Export uncertain units for manual review
needs_review = [u for u, l in auto_labels.items() if l == 'needs_review']

# Export only uncertain units to Phy
sorting_review = sorting.select_units(needs_review)
analyzer_review = si.create_sorting_analyzer(sorting_review, recording)
analyzer_review.compute('waveforms')
analyzer_review.compute('templates')
si.export_to_phy(analyzer_review, output_folder='phy_review/')

# Manual review in Phy: phy template-gui phy_review/params.py

# Step 3: Load manual labels and merge
manual_labels = si.read_phy('phy_review/').get_property('quality')
# Combine auto + manual labels for final result
```

## Comparison of Methods

| Method | Pros | Cons |
|--------|------|------|
| **Manual (Phy)** | Gold standard, flexible | Slow, subjective |
| **SpikeInterface QM** | Fast, reproducible | Simple thresholds only |
| **Bombcell** | Multi-class, validated | Requires waveform extraction |
| **UnitRefine** | ML-based, learns from data | Needs training data |

## Best Practices

1. **Always visualize** - Don't blindly trust automated results
2. **Document thresholds** - Record exact parameters used
3. **Validate** - Compare automated vs manual on a subset
4. **Be conservative** - When in doubt, exclude the unit
5. **Report methods** - Include curation criteria in publications

## Pipeline Example

```python
def curate_sorting(sorting, recording, output_dir):
    """Complete curation pipeline."""

    # Create analyzer
    analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True,
                                          folder=f'{output_dir}/analyzer')

    # Compute required extensions
    analyzer.compute('random_spikes', max_spikes_per_unit=500)
    analyzer.compute('waveforms')
    analyzer.compute('templates')
    analyzer.compute('noise_levels')
    analyzer.compute('spike_amplitudes')
    analyzer.compute('quality_metrics')

    qm = analyzer.get_extension('quality_metrics').get_data()

    # Auto-classify
    labels = {}
    for unit_id in qm.index:
        row = qm.loc[unit_id]

        if row['snr'] < 2:
            labels[unit_id] = 'noise'
        elif row['isi_violations_ratio'] > 0.1 or row['presence_ratio'] < 0.8:
            labels[unit_id] = 'mua'
        elif (row['snr'] > 5 and
              row['isi_violations_ratio'] < 0.01 and
              row['presence_ratio'] > 0.9 and
              row['amplitude_cutoff'] < 0.1):
            labels[unit_id] = 'good'
        else:
            labels[unit_id] = 'unsorted'

    # Summary
    from collections import Counter
    print("Classification summary:")
    print(Counter(labels.values()))

    # Save labels
    import json
    with open(f'{output_dir}/unit_labels.json', 'w') as f:
        json.dump(labels, f)

    # Return good units
    good_ids = [u for u, l in labels.items() if l == 'good']
    return sorting.select_units(good_ids), labels

# Usage
sorting_curated, labels = curate_sorting(sorting, recording, 'output/')
```

## References

- [Bombcell GitHub](https://github.com/Julie-Fabre/bombcell)
- [UnitMatch GitHub](https://github.com/EnnyvanBeest/UnitMatch)
- [SpikeInterface Curation](https://spikeinterface.readthedocs.io/en/stable/modules/curation.html)
- Fabre et al. (2023) "Bombcell: automated curation and cell classification"
- van Beest et al. (2024) "UnitMatch: tracking neurons across days with high-density probes"

scientific-skills/neuropixels-analysis/LICENSE.txt (new file, 21 lines)

MIT License

Copyright (c) 2025 Shen Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

scientific-skills/neuropixels-analysis/MOTION_CORRECTION.md (new file, 323 lines)

# Motion/Drift Correction Reference

Mechanical drift during acute probe insertion is a major challenge for Neuropixels recordings. This guide covers detection, estimation, and correction of motion artifacts.

## Why Motion Correction Matters

- Neuropixels probes can drift 10-100+ μm during recording
- Uncorrected drift leads to:
  - Units appearing/disappearing mid-recording
  - Waveform amplitude changes
  - Incorrect spike-unit assignments
  - Reduced unit yield

## Detection: Check Before Sorting

**Always visualize drift before running spike sorting!**

```python
import spikeinterface.full as si
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

# Preprocess first (don't whiten - affects peak localization)
rec = si.highpass_filter(recording, freq_min=400.)
rec = si.common_reference(rec, operator='median', reference='global')

# Detect peaks
noise_levels = si.get_noise_levels(rec, return_in_uV=False)
peaks = detect_peaks(
    rec,
    method='locally_exclusive',
    noise_levels=noise_levels,
    detect_threshold=5,
    radius_um=50.,
    n_jobs=8,
    chunk_duration='1s',
    progress_bar=True
)

# Localize peaks
peak_locations = localize_peaks(
    rec, peaks,
    method='center_of_mass',
    n_jobs=8,
    chunk_duration='1s'
)

# Visualize drift
si.plot_drift_raster_map(
    peaks=peaks,
    peak_locations=peak_locations,
    recording=rec,
    clim=(-200, 0)  # Adjust color limits
)
```

### Interpreting Drift Plots

| Pattern | Interpretation | Action |
|---------|---------------|--------|
| Horizontal bands, stable | No significant drift | Skip correction |
| Diagonal bands (slow) | Gradual settling drift | Use motion correction |
| Rapid jumps | Brain pulsation or movement | Use non-rigid correction |
| Chaotic patterns | Severe instability | Consider discarding segment |
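
Beyond eyeballing the raster map, a rough drift magnitude can be pulled straight from the detected peaks. A minimal sketch, assuming the structured arrays returned above expose `sample_index` (peaks) and `y` (peak locations), which is the usual SpikeInterface layout:

```python
import numpy as np

fs = rec.get_sampling_frequency()
t = peaks['sample_index'] / fs    # peak times in seconds
depth = peak_locations['y']       # depth along the probe (μm)

# Median peak depth in 10 s bins: a crude per-bin estimate of probe position
bins = np.arange(0, t.max(), 10.0)
idx = np.digitize(t, bins)
depth_per_bin = np.array([np.median(depth[idx == i]) for i in np.unique(idx)])

# Peak-to-peak excursion of that estimate over the recording
drift_ptp = np.ptp(depth_per_bin)
print(f"Approximate peak-to-peak drift: {drift_ptp:.1f} um")
# As a rule of thumb, more than a few tens of microns warrants correction.
```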

## Motion Correction Methods

### Quick Correction (Recommended Start)

```python
# Simple one-liner with preset
rec_corrected = si.correct_motion(
    recording=rec,
    preset='nonrigid_fast_and_accurate'
)
```

### Available Presets

| Preset | Speed | Accuracy | Best For |
|--------|-------|----------|----------|
| `rigid_fast` | Fast | Low | Quick check, small drift |
| `kilosort_like` | Medium | Good | Kilosort-compatible results |
| `nonrigid_accurate` | Slow | High | Publication-quality |
| `nonrigid_fast_and_accurate` | Medium | High | **Recommended default** |
| `dredge` | Slow | Highest | Best results, complex drift |
| `dredge_fast` | Medium | High | DREDge with less compute |

### Full Control Pipeline

```python
from spikeinterface.sortingcomponents.motion import (
    estimate_motion,
    interpolate_motion
)

# Step 1: Estimate motion
motion, temporal_bins, spatial_bins = estimate_motion(
    rec,
    peaks,
    peak_locations,
    method='decentralized',
    direction='y',
    rigid=False,        # Non-rigid for Neuropixels
    win_step_um=50,     # Spatial window step
    win_sigma_um=150,   # Spatial smoothing
    bin_s=2.0,          # Temporal bin size
    progress_bar=True
)

# Step 2: Visualize motion estimate
si.plot_motion(
    motion,
    temporal_bins,
    spatial_bins,
    recording=rec
)

# Step 3: Apply correction via interpolation
rec_corrected = interpolate_motion(
    recording=rec,
    motion=motion,
    temporal_bins=temporal_bins,
    spatial_bins=spatial_bins,
    border_mode='force_extrapolate'
)
```

### Save Motion Estimate

```python
# Save for later use
import numpy as np
np.savez('motion_estimate.npz',
         motion=motion,
         temporal_bins=temporal_bins,
         spatial_bins=spatial_bins)

# Load later
data = np.load('motion_estimate.npz')
motion = data['motion']
temporal_bins = data['temporal_bins']
spatial_bins = data['spatial_bins']
```

## DREDge: State-of-the-Art Method

DREDge (Decentralized Registration of Electrophysiology Data) is currently the best-performing motion correction method.

### Using DREDge Preset

```python
# AP-band motion estimation
rec_corrected = si.correct_motion(rec, preset='dredge')

# Or compute explicitly
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)  # parallelization settings
motion, motion_info = si.compute_motion(
    rec,
    preset='dredge',
    output_motion_info=True,
    folder='motion_output/',
    **job_kwargs
)
```

### LFP-Based Motion Estimation

For very fast drift or when AP-band estimation fails:

```python
# Load LFP stream
lfp = si.read_spikeglx('/path/to/data', stream_name='imec0.lf')

# Estimate motion from LFP (faster, handles rapid drift)
motion_lfp, motion_info = si.compute_motion(
    lfp,
    preset='dredge_lfp',
    output_motion_info=True
)

# Apply to AP recording
rec_corrected = interpolate_motion(
    recording=rec,  # AP recording
    motion=motion_lfp,
    temporal_bins=motion_info['temporal_bins'],
    spatial_bins=motion_info['spatial_bins']
)
```

## Integration with Spike Sorting

### Option 1: Pre-correction (Recommended)

```python
# Correct before sorting
rec_corrected = si.correct_motion(rec, preset='nonrigid_fast_and_accurate')

# Save corrected recording
rec_corrected = rec_corrected.save(folder='preprocessed_motion_corrected/',
                                   format='binary', n_jobs=8)

# Run spike sorting on corrected data
sorting = si.run_sorter('kilosort4', rec_corrected, output_folder='ks4/')
```

### Option 2: Let Kilosort Handle It

Kilosort 2.5+ has built-in drift correction:

```python
sorting = si.run_sorter(
    'kilosort4',
    rec,  # Not motion corrected
    output_folder='ks4/',
    nblocks=5,           # Non-rigid blocks for drift correction
    do_correction=True   # Enable Kilosort's drift correction
)
```

### Option 3: Post-hoc Correction

```python
# Sort first
sorting = si.run_sorter('kilosort4', rec, output_folder='ks4/')

# Then estimate motion from sorted spikes
# (More accurate as it uses actual spike times)
from spikeinterface.sortingcomponents.motion import estimate_motion_from_sorting

motion = estimate_motion_from_sorting(sorting, rec)
```

## Parameters Deep Dive

### Peak Detection

```python
peaks = detect_peaks(
    rec,
    method='locally_exclusive',  # Best for dense probes
    noise_levels=noise_levels,
    detect_threshold=5,          # Lower = more peaks (noisier estimate)
    radius_um=50.,               # Exclusion radius
    exclude_sweep_ms=0.1,        # Temporal exclusion
)
```

### Motion Estimation

```python
motion = estimate_motion(
    rec, peaks, peak_locations,
    method='decentralized',  # 'decentralized' or 'iterative_template'
    direction='y',           # Along probe axis
    rigid=False,             # False for non-rigid
    bin_s=2.0,               # Temporal resolution (seconds)
    win_step_um=50,          # Spatial window step
    win_sigma_um=150,        # Spatial smoothing sigma
    margin_um=0,             # Margin at probe edges
    win_scale_um=150,        # Window scale for weights
)
```

## Troubleshooting

### Over-correction (Wavy Patterns)

```python
# Increase temporal smoothing
motion = estimate_motion(..., bin_s=5.0)  # Larger bins

# Or use rigid correction for small drift
motion = estimate_motion(..., rigid=True)
```

### Under-correction (Drift Remains)

```python
# Decrease the spatial window for a finer non-rigid estimate
motion = estimate_motion(..., win_step_um=25, win_sigma_um=75)

# Use more peaks
peaks = detect_peaks(..., detect_threshold=4)  # Lower threshold
```

### Edge Artifacts

```python
rec_corrected = interpolate_motion(
    rec, motion, temporal_bins, spatial_bins,
    border_mode='force_extrapolate',  # or 'remove_channels'
    spatial_interpolation_method='kriging'
)
```

## Validation

After correction, re-visualize to confirm:

```python
import matplotlib.pyplot as plt

# Re-detect peaks on the corrected recording
peaks_corrected = detect_peaks(rec_corrected, ...)
peak_locations_corrected = localize_peaks(rec_corrected, peaks_corrected, ...)

# Plot before/after comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Before
si.plot_drift_raster_map(peaks, peak_locations, rec, ax=axes[0])
axes[0].set_title('Before Correction')

# After
si.plot_drift_raster_map(peaks_corrected, peak_locations_corrected,
                         rec_corrected, ax=axes[1])
axes[1].set_title('After Correction')
```

## References

- [SpikeInterface Motion Correction Docs](https://spikeinterface.readthedocs.io/en/stable/modules/motion_correction.html)
- [Handle Drift Tutorial](https://spikeinterface.readthedocs.io/en/stable/how_to/handle_drift.html)
- [DREDge GitHub](https://github.com/evarol/DREDge)
- Windolf et al. (2023) "DREDge: robust motion correction for high-density extracellular recordings"

scientific-skills/neuropixels-analysis/PREPROCESSING.md (new file, 273 lines)

# Neuropixels Preprocessing Reference

Comprehensive preprocessing techniques for Neuropixels neural recordings.

## Standard Preprocessing Pipeline

```python
import spikeinterface.full as si

# Load raw data
recording = si.read_spikeglx('/path/to/data', stream_id='imec0.ap')

# 1. Phase shift correction (for Neuropixels 1.0)
rec = si.phase_shift(recording)

# 2. Bandpass filter for spike detection
rec = si.bandpass_filter(rec, freq_min=300, freq_max=6000)

# 3. Common median reference (removes correlated noise)
rec = si.common_reference(rec, reference='global', operator='median')

# 4. Remove bad channels (optional)
rec = si.remove_bad_channels(rec, bad_channel_ids=bad_channels)
```

## Filtering Options

### Bandpass Filter

```python
# Standard AP band
rec = si.bandpass_filter(recording, freq_min=300, freq_max=6000)

# Wider band (preserve more waveform shape)
rec = si.bandpass_filter(recording, freq_min=150, freq_max=7500)

# Filter parameters
rec = si.bandpass_filter(
    recording,
    freq_min=300,
    freq_max=6000,
    filter_order=5,
    ftype='butter',  # 'butter', 'bessel', or 'cheby1'
    margin_ms=5.0    # Prevent edge artifacts
)
```

### Highpass Filter Only

```python
rec = si.highpass_filter(recording, freq_min=300)
```

### Notch Filter (Remove Line Noise)

```python
# Remove 60 Hz and harmonics
rec = si.notch_filter(recording, freq=60, q=30)
rec = si.notch_filter(rec, freq=120, q=30)
rec = si.notch_filter(rec, freq=180, q=30)
```

## Reference Schemes

### Common Median Reference (Recommended)

```python
# Global median reference
rec = si.common_reference(recording, reference='global', operator='median')

# Per-shank reference (multi-shank probes)
rec = si.common_reference(recording, reference='global', operator='median',
                          groups=recording.get_channel_groups())
```

### Common Average Reference

```python
rec = si.common_reference(recording, reference='global', operator='average')
```

### Local Reference

```python
# Reference by local groups of channels
rec = si.common_reference(recording, reference='local', local_radius=(30, 100))
```

## Bad Channel Detection & Removal

### Automatic Detection

```python
# Detect bad channels
bad_channel_ids, channel_labels = si.detect_bad_channels(
    recording,
    method='coherence+psd',
    dead_channel_threshold=-0.5,
    noisy_channel_threshold=1.0,
    outside_channel_threshold=-0.3,
    n_neighbors=11
)

print(f"Bad channels: {bad_channel_ids}")
print(f"Labels: {dict(zip(bad_channel_ids, channel_labels))}")
```

### Remove Bad Channels

```python
rec_clean = si.remove_bad_channels(recording, bad_channel_ids=bad_channel_ids)
```

### Interpolate Bad Channels

```python
rec_interp = si.interpolate_bad_channels(recording, bad_channel_ids=bad_channel_ids)
```

## Motion Correction

### Estimate Motion

```python
# Estimate motion (drift)
motion, temporal_bins, spatial_bins = si.estimate_motion(
    recording,
    method='decentralized',
    rigid=False,        # Non-rigid motion estimation
    win_step_um=50,     # Spatial window step
    win_sigma_um=150,   # Spatial window sigma
    progress_bar=True
)
```

### Apply Motion Correction

```python
rec_corrected = si.correct_motion(
    recording,
    motion,
    temporal_bins,
    spatial_bins,
    interpolate_motion_border=True
)
```

### Motion Visualization

```python
si.plot_motion(motion, temporal_bins, spatial_bins)
```

## Probe-Specific Processing

### Neuropixels 1.0

```python
# Phase shift correction (different ADC per channel)
rec = si.phase_shift(recording)

# Then standard pipeline
rec = si.bandpass_filter(rec, freq_min=300, freq_max=6000)
rec = si.common_reference(rec, reference='global', operator='median')
```

### Neuropixels 2.0

```python
# No phase shift needed (single ADC)
rec = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
rec = si.common_reference(rec, reference='global', operator='median')
```

### Multi-Shank (Neuropixels 2.0 4-shank)

```python
# Reference per shank
groups = recording.get_channel_groups()  # Returns shank assignments
rec = si.common_reference(recording, reference='global', operator='median', groups=groups)
```

## Whitening

```python
# Whiten data (decorrelate channels)
rec_whitened = si.whiten(recording, mode='local', local_radius_um=100)

# Global whitening
rec_whitened = si.whiten(recording, mode='global')
```

## Artifact Removal

### Remove Stimulation Artifacts

```python
# Define artifact times (in samples)
triggers = [10000, 20000, 30000]  # Sample indices

rec = si.remove_artifacts(
    recording,
    triggers,
    ms_before=0.5,
    ms_after=3.0,
    mode='cubic'  # 'zeros', 'linear', 'cubic'
)
```

### Blank Saturation Periods

```python
rec = si.blank_staturation(recording, threshold=0.95, fill_value=0)
```

## Saving Preprocessed Data

### Binary Format (Recommended)

```python
rec_preprocessed.save(folder='preprocessed/', format='binary', n_jobs=4)
```

### Zarr Format (Compressed)

```python
rec_preprocessed.save(folder='preprocessed.zarr', format='zarr')
```

### Save as Recording Extractor

```python
# Save for later use
rec_preprocessed.save(folder='preprocessed/', format='binary')

# Load later
rec_loaded = si.load_extractor('preprocessed/')
```

## Complete Pipeline Example

```python
import spikeinterface.full as si

def preprocess_neuropixels(data_path, output_path):
    """Standard Neuropixels preprocessing pipeline."""

    # Load data
    recording = si.read_spikeglx(data_path, stream_id='imec0.ap')
    print(f"Loaded: {recording.get_num_channels()} channels, "
          f"{recording.get_total_duration():.1f}s")

    # Phase shift (NP 1.0 only)
    rec = si.phase_shift(recording)

    # Filter
    rec = si.bandpass_filter(rec, freq_min=300, freq_max=6000)

    # Detect and interpolate bad channels
    bad_ids, _ = si.detect_bad_channels(rec)
    if len(bad_ids) > 0:
        print(f"Interpolating {len(bad_ids)} bad channels: {bad_ids}")
        rec = si.interpolate_bad_channels(rec, bad_ids)

    # Common reference
    rec = si.common_reference(rec, reference='global', operator='median')

    # Save
    rec.save(folder=output_path, format='binary', n_jobs=4)
    print(f"Saved to: {output_path}")

    return rec

# Usage
rec_preprocessed = preprocess_neuropixels(
    '/path/to/spikeglx/data',
    '/path/to/preprocessed'
)
```

## Performance Tips

```python
# Use parallel processing
rec.save(folder='output/', n_jobs=-1)  # Use all cores

# Use job kwargs for memory management
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
rec.save(folder='output/', **job_kwargs)

# Set global job kwargs
si.set_global_job_kwargs(n_jobs=8, chunk_duration='1s')
```

scientific-skills/neuropixels-analysis/QUALITY_METRICS.md (new file, 359 lines)

# Quality Metrics Reference
|
||||
|
||||
Comprehensive guide to unit quality assessment using SpikeInterface metrics and Allen/IBL standards.
|
||||
|
||||
## Overview
|
||||
|
||||
Quality metrics assess three aspects of sorted units:
|
||||
|
||||
| Category | Question | Key Metrics |
|
||||
|----------|----------|-------------|
|
||||
| **Contamination** (Type I) | Are spikes from multiple neurons? | ISI violations, SNR |
|
||||
| **Completeness** (Type II) | Are we missing spikes? | Amplitude cutoff, presence ratio |
|
||||
| **Stability** | Is the unit stable over time? | Drift metrics, amplitude CV |
|
||||
|
||||
## Computing Quality Metrics
|
||||
|
||||
```python
|
||||
import spikeinterface.full as si
|
||||
|
||||
# Create analyzer with computed waveforms
|
||||
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True)
|
||||
analyzer.compute('random_spikes', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0)
|
||||
analyzer.compute('templates')
|
||||
analyzer.compute('noise_levels')
|
||||
analyzer.compute('spike_amplitudes')
|
||||
analyzer.compute('principal_components', n_components=5)
|
||||
|
||||
# Compute all quality metrics
|
||||
analyzer.compute('quality_metrics')
|
||||
|
||||
# Or compute specific metrics
|
||||
analyzer.compute('quality_metrics', metric_names=[
|
||||
'firing_rate', 'snr', 'isi_violations_ratio',
|
||||
'presence_ratio', 'amplitude_cutoff'
|
||||
])
|
||||
|
||||
# Get results
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
print(qm.columns.tolist()) # Available metrics
|
||||
```
|
||||
|
||||
## Metric Definitions & Thresholds
|
||||
|
||||
### Contamination Metrics
|
||||
|
||||
#### ISI Violations Ratio
|
||||
Fraction of spikes that violate the refractory period. Neurons have an absolute refractory period of roughly 1-2 ms, so the default 1.5 ms threshold flags spike pairs too close together to have come from a single well-isolated unit.
|
||||
|
||||
```python
|
||||
# Compute with custom refractory period
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['isi_violations_ratio'],
|
||||
isi_threshold_ms=1.5,
|
||||
min_isi_ms=0.0)
|
||||
```
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| < 0.01 | Excellent (well-isolated single unit) |
|
||||
| 0.01 - 0.1 | Good (minor contamination) |
|
||||
| 0.1 - 0.5 | Moderate (multi-unit activity likely) |
|
||||
| > 0.5 | Poor (likely multi-unit) |
|
||||
|
||||
**Reference:** Hill et al. (2011) J Neurosci 31:8699-8705
|
||||
|
||||
#### Signal-to-Noise Ratio (SNR)
|
||||
Ratio of peak waveform amplitude to background noise.
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics', metric_names=['snr'])
|
||||
```
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| > 10 | Excellent |
|
||||
| 5 - 10 | Good |
|
||||
| 2 - 5 | Acceptable |
|
||||
| < 2 | Poor (may be noise) |
|
||||
|
||||
#### Isolation Distance
|
||||
Mahalanobis distance to nearest cluster in PCA space.
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['isolation_distance'],
|
||||
n_neighbors=4)
|
||||
```
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| > 50 | Well-isolated |
|
||||
| 20 - 50 | Moderately isolated |
|
||||
| < 20 | Poorly isolated |
|
||||
|
||||
#### L-ratio
|
||||
Contamination measure based on Mahalanobis distances.
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| < 0.05 | Well-isolated |
|
||||
| 0.05 - 0.1 | Acceptable |
|
||||
| > 0.1 | Contaminated |
|
||||
|
||||
#### D-prime
|
||||
Discriminability between unit and nearest neighbor.
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| > 8 | Excellent separation |
|
||||
| 5 - 8 | Good separation |
|
||||
| < 5 | Poor separation |
|
||||
|
||||
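Isolation distance, L-ratio, and d-prime are all computed from principal-component projections, so the `principal_components` extension must be computed first. A minimal sketch using the metric names from this guide:

```python
# PCA-based isolation metrics require principal components
analyzer.compute('principal_components', n_components=5)
analyzer.compute('quality_metrics',
                 metric_names=['isolation_distance', 'l_ratio', 'd_prime'])

pca_qm = analyzer.get_extension('quality_metrics').get_data()
print(pca_qm[['isolation_distance', 'l_ratio', 'd_prime']].describe())
```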
### Completeness Metrics
|
||||
|
||||
#### Amplitude Cutoff
|
||||
Estimates fraction of spikes below detection threshold.
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['amplitude_cutoff'],
|
||||
peak_sign='neg') # 'neg', 'pos', or 'both'
|
||||
```
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| < 0.01 | Excellent (nearly complete) |
|
||||
| 0.01 - 0.1 | Good |
|
||||
| 0.1 - 0.2 | Moderate (some missed spikes) |
|
||||
| > 0.2 | Poor (many missed spikes) |
|
||||
|
||||
**For precise timing analyses:** Use < 0.01
|
||||
|
||||
#### Presence Ratio
|
||||
Fraction of recording time with detected spikes.
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['presence_ratio'],
|
||||
bin_duration_s=60) # 1-minute bins
|
||||
```
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| > 0.99 | Excellent |
|
||||
| 0.9 - 0.99 | Good |
|
||||
| 0.8 - 0.9 | Acceptable |
|
||||
| < 0.8 | Unit may have drifted out |
|
||||
|
||||
### Stability Metrics
|
||||
|
||||
#### Drift Metrics
|
||||
Measure unit movement over time.
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['drift_ptp', 'drift_std', 'drift_mad'])
|
||||
```
|
||||
|
||||
| Metric | Description | Good Value |
|
||||
|--------|-------------|------------|
|
||||
| `drift_ptp` | Peak-to-peak drift (μm) | < 40 |
|
||||
| `drift_std` | Standard deviation of drift | < 10 |
|
||||
| `drift_mad` | Median absolute deviation | < 10 |
|
||||
|
||||
#### Amplitude CV
|
||||
Coefficient of variation of spike amplitudes.
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| < 0.25 | Very stable |
|
||||
| 0.25 - 0.5 | Acceptable |
|
||||
| > 0.5 | Unstable (drift or contamination) |
|
||||
|
||||
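Amplitude CV is derived from the `spike_amplitudes` extension and is reported by SpikeInterface as `amplitude_cv_median` and `amplitude_cv_range` rather than a single value. A short sketch (the 0.5 cutoff comes from the table above):

```python
# Amplitude CV needs spike amplitudes computed first
analyzer.compute('spike_amplitudes')
analyzer.compute('quality_metrics',
                 metric_names=['amplitude_cv_median', 'amplitude_cv_range'])

amp_qm = analyzer.get_extension('quality_metrics').get_data()
unstable = amp_qm.query('amplitude_cv_median > 0.5').index.tolist()
print(f"Potentially unstable units: {unstable}")
```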
### Cluster Quality Metrics
|
||||
|
||||
#### Silhouette Score
|
||||
Cluster cohesion vs separation (-1 to 1).
|
||||
|
||||
| Value | Interpretation |
|
||||
|-------|---------------|
|
||||
| > 0.5 | Well-defined cluster |
|
||||
| 0.25 - 0.5 | Moderate |
|
||||
| < 0.25 | Overlapping clusters |
|
||||
|
||||
#### Nearest-Neighbor Metrics
|
||||
|
||||
```python
|
||||
analyzer.compute('quality_metrics',
|
||||
metric_names=['nn_hit_rate', 'nn_miss_rate'],
|
||||
n_neighbors=4)
|
||||
```
|
||||
|
||||
| Metric | Description | Good Value |
|
||||
|--------|-------------|------------|
|
||||
| `nn_hit_rate` | Fraction of spikes with same-unit neighbors | > 0.9 |
|
||||
| `nn_miss_rate` | Fraction of spikes with other-unit neighbors | < 0.1 |
|
||||
|
||||
## Standard Filtering Criteria
|
||||
|
||||
### Allen Institute Defaults
|
||||
|
||||
```python
|
||||
# Allen Visual Coding / Behavior defaults
|
||||
allen_query = """
|
||||
presence_ratio > 0.95 and
|
||||
isi_violations_ratio < 0.5 and
|
||||
amplitude_cutoff < 0.1
|
||||
"""
|
||||
good_units = qm.query(allen_query).index.tolist()
|
||||
```
|
||||
|
||||
### IBL Standards
|
||||
|
||||
```python
|
||||
# IBL reproducible ephys criteria
|
||||
ibl_query = """
|
||||
presence_ratio > 0.9 and
|
||||
isi_violations_ratio < 0.1 and
|
||||
amplitude_cutoff < 0.1 and
|
||||
firing_rate > 0.1
|
||||
"""
|
||||
good_units = qm.query(ibl_query).index.tolist()
|
||||
```
|
||||
|
||||
### Strict Single-Unit Criteria
|
||||
|
||||
```python
|
||||
# For precise timing / spike-timing analyses
|
||||
strict_query = """
|
||||
snr > 5 and
|
||||
presence_ratio > 0.99 and
|
||||
isi_violations_ratio < 0.01 and
|
||||
amplitude_cutoff < 0.01 and
|
||||
isolation_distance > 20 and
|
||||
drift_ptp < 40
|
||||
"""
|
||||
single_units = qm.query(strict_query).index.tolist()
|
||||
```
|
||||
|
||||
### Multi-Unit Activity (MUA)
|
||||
|
||||
```python
|
||||
# Include multi-unit activity
|
||||
mua_query = """
|
||||
snr > 2 and
|
||||
presence_ratio > 0.5 and
|
||||
isi_violations_ratio < 1.0
|
||||
"""
|
||||
all_units = qm.query(mua_query).index.tolist()
|
||||
```
|
||||
|
||||
## Visualization
|
||||
|
||||
### Quality Metric Summary
|
||||
|
||||
```python
|
||||
# Plot all metrics
|
||||
si.plot_quality_metrics(analyzer)
|
||||
```
|
||||
|
||||
### Individual Metric Distributions
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||||
|
||||
metrics = ['snr', 'isi_violations_ratio', 'presence_ratio',
|
||||
'amplitude_cutoff', 'firing_rate', 'drift_ptp']
|
||||
|
||||
for ax, metric in zip(axes.flat, metrics):
|
||||
ax.hist(qm[metric].dropna(), bins=50, edgecolor='black')
|
||||
ax.set_xlabel(metric)
|
||||
ax.set_ylabel('Count')
|
||||
# Add threshold line
|
||||
if metric == 'snr':
|
||||
ax.axvline(5, color='r', linestyle='--', label='threshold')
|
||||
elif metric == 'isi_violations_ratio':
|
||||
ax.axvline(0.01, color='r', linestyle='--')
|
||||
elif metric == 'presence_ratio':
|
||||
ax.axvline(0.9, color='r', linestyle='--')
|
||||
|
||||
plt.tight_layout()
|
||||
```
|
||||
|
||||
### Unit Quality Summary
|
||||
|
||||
```python
|
||||
# Comprehensive unit summary plot
|
||||
si.plot_unit_summary(analyzer, unit_id=0)
|
||||
```
|
||||
|
||||
### Quality vs Firing Rate
|
||||
|
||||
```python
|
||||
fig, ax = plt.subplots()
|
||||
scatter = ax.scatter(qm['firing_rate'], qm['snr'],
|
||||
c=qm['isi_violations_ratio'],
|
||||
cmap='RdYlGn_r', alpha=0.6)
|
||||
ax.set_xlabel('Firing Rate (Hz)')
|
||||
ax.set_ylabel('SNR')
|
||||
plt.colorbar(scatter, label='ISI Violations')
|
||||
ax.set_xscale('log')
|
||||
```
|
||||
|
||||
## Compute All Metrics at Once
|
||||
|
||||
```python
|
||||
# Full quality metrics computation
|
||||
all_metric_names = [
|
||||
# Firing properties
|
||||
'firing_rate', 'presence_ratio',
|
||||
# Waveform
|
||||
'snr', 'amplitude_cutoff', 'amplitude_cv_median', 'amplitude_cv_range',
|
||||
# ISI
|
||||
'isi_violations_ratio', 'isi_violations_count',
|
||||
# Drift
|
||||
'drift_ptp', 'drift_std', 'drift_mad',
|
||||
# Isolation (require PCA)
|
||||
'isolation_distance', 'l_ratio', 'd_prime',
|
||||
# Nearest neighbor (require PCA)
|
||||
'nn_hit_rate', 'nn_miss_rate',
|
||||
# Cluster quality
|
||||
'silhouette_score',
|
||||
# Synchrony
|
||||
'sync_spike_2', 'sync_spike_4', 'sync_spike_8',
|
||||
]
|
||||
|
||||
# Compute PCA first (required for some metrics)
|
||||
analyzer.compute('principal_components', n_components=5)
|
||||
|
||||
# Compute metrics
|
||||
analyzer.compute('quality_metrics', metric_names=all_metric_names)
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
|
||||
# Save to CSV
|
||||
qm.to_csv('quality_metrics.csv')
|
||||
```
|
||||
|
||||
## Custom Metrics
|
||||
|
||||
```python
|
||||
from spikeinterface.qualitymetrics import compute_firing_rates, compute_snrs
|
||||
|
||||
# Compute individual metrics
|
||||
firing_rates = compute_firing_rates(sorting)
|
||||
snrs = compute_snrs(analyzer)
|
||||
|
||||
# Add custom metric to DataFrame
|
||||
qm['custom_score'] = qm['snr'] * qm['presence_ratio'] / (qm['isi_violations_ratio'] + 0.001)
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [SpikeInterface Quality Metrics](https://spikeinterface.readthedocs.io/en/latest/modules/qualitymetrics.html)
|
||||
- [Allen Institute ecephys_quality_metrics](https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html)
|
||||
- Hill et al. (2011) "Quality metrics to accompany spike sorting of extracellular signals"
|
||||
- Siegle et al. (2021) "Survey of spiking in the mouse visual system reveals functional hierarchy"
|
||||
344
scientific-skills/neuropixels-analysis/SKILL.md
Normal file
@@ -0,0 +1,344 @@
|
||||
---
|
||||
name: neuropixels-analysis
|
||||
description: "Neuropixels neural recording analysis. Load SpikeGLX/OpenEphys data, preprocess, motion correction, Kilosort4 spike sorting, quality metrics, Allen/IBL curation, AI-assisted visual analysis, for Neuropixels 1.0/2.0 extracellular electrophysiology. Use when working with neural recordings, spike sorting, extracellular electrophysiology, or when the user mentions Neuropixels, SpikeGLX, Open Ephys, Kilosort, quality metrics, or unit curation."
|
||||
---
|
||||
|
||||
# Neuropixels Data Analysis
|
||||
|
||||
## Overview
|
||||
|
||||
Comprehensive toolkit for analyzing Neuropixels high-density neural recordings using current best practices from SpikeInterface, Allen Institute, and International Brain Laboratory (IBL). Supports the full workflow from raw data to publication-ready curated units.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be used when:
|
||||
- Working with Neuropixels recordings (.ap.bin, .lf.bin, .meta files)
|
||||
- Loading data from SpikeGLX, Open Ephys, or NWB formats
|
||||
- Preprocessing neural recordings (filtering, CAR, bad channel detection)
|
||||
- Detecting and correcting motion/drift in recordings
|
||||
- Running spike sorting (Kilosort4, SpykingCircus2, Mountainsort5)
|
||||
- Computing quality metrics (SNR, ISI violations, presence ratio)
|
||||
- Curating units using Allen/IBL criteria
|
||||
- Creating visualizations of neural data
|
||||
- Exporting results to Phy or NWB
|
||||
|
||||
## Supported Hardware & Formats
|
||||
|
||||
| Probe | Electrodes | Channels | Notes |
|
||||
|-------|-----------|----------|-------|
|
||||
| Neuropixels 1.0 | 960 | 384 | Requires phase_shift correction |
|
||||
| Neuropixels 2.0 (single) | 1280 | 384 | Denser geometry |
|
||||
| Neuropixels 2.0 (4-shank) | 5120 | 384 | Multi-region recording |
|
||||
|
||||
| Format | Extension | Reader |
|
||||
|--------|-----------|--------|
|
||||
| SpikeGLX | `.ap.bin`, `.lf.bin`, `.meta` | `si.read_spikeglx()` |
|
||||
| Open Ephys | `.continuous`, `.oebin` | `si.read_openephys()` |
|
||||
| NWB | `.nwb` | `si.read_nwb()` |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Import and Setup
|
||||
|
||||
```python
|
||||
import spikeinterface.full as si
|
||||
import neuropixels_analysis as npa
|
||||
|
||||
# Configure parallel processing
|
||||
job_kwargs = dict(n_jobs=-1, chunk_duration='1s', progress_bar=True)
|
||||
```
|
||||
|
||||
### Loading Data
|
||||
|
||||
```python
|
||||
# SpikeGLX (most common)
|
||||
recording = si.read_spikeglx('/path/to/data', stream_id='imec0.ap')
|
||||
|
||||
# Open Ephys (common for many labs)
|
||||
recording = si.read_openephys('/path/to/Record_Node_101/')
|
||||
|
||||
# Check available streams
|
||||
streams, ids = si.get_neo_streams('spikeglx', '/path/to/data')
|
||||
print(streams) # ['imec0.ap', 'imec0.lf', 'nidq']
|
||||
|
||||
# For testing with subset of data
|
||||
recording = recording.frame_slice(0, int(60 * recording.get_sampling_frequency()))
|
||||
```
|
||||
|
||||
### Complete Pipeline (One Command)
|
||||
|
||||
```python
|
||||
# Run full analysis pipeline
|
||||
results = npa.run_pipeline(
|
||||
recording,
|
||||
output_dir='output/',
|
||||
sorter='kilosort4',
|
||||
curation_method='allen',
|
||||
)
|
||||
|
||||
# Access results
|
||||
sorting = results['sorting']
|
||||
metrics = results['metrics']
|
||||
labels = results['labels']
|
||||
```
|
||||
|
||||
## Standard Analysis Workflow
|
||||
|
||||
### 1. Preprocessing
|
||||
|
||||
```python
|
||||
# Recommended preprocessing chain
|
||||
rec = si.highpass_filter(recording, freq_min=400)
|
||||
rec = si.phase_shift(rec) # Required for Neuropixels 1.0
|
||||
bad_ids, _ = si.detect_bad_channels(rec)
|
||||
rec = rec.remove_channels(bad_ids)
|
||||
rec = si.common_reference(rec, operator='median')
|
||||
|
||||
# Or use our wrapper
|
||||
rec = npa.preprocess(recording)
|
||||
```
|
||||
|
||||
### 2. Check and Correct Drift
|
||||
|
||||
```python
|
||||
# Check for drift (always do this!)
|
||||
motion_info = npa.estimate_motion(rec, preset='kilosort_like')
|
||||
npa.plot_drift(rec, motion_info, output='drift_map.png')
|
||||
|
||||
# Apply correction if needed
|
||||
if motion_info['motion'].max() > 10: # microns
|
||||
rec = npa.correct_motion(rec, preset='nonrigid_accurate')
|
||||
```
|
||||
|
||||
### 3. Spike Sorting
|
||||
|
||||
```python
|
||||
# Kilosort4 (recommended, requires GPU)
|
||||
sorting = si.run_sorter('kilosort4', rec, folder='ks4_output')
|
||||
|
||||
# CPU alternatives
|
||||
sorting = si.run_sorter('tridesclous2', rec, folder='tdc2_output')
|
||||
sorting = si.run_sorter('spykingcircus2', rec, folder='sc2_output')
|
||||
sorting = si.run_sorter('mountainsort5', rec, folder='ms5_output')
|
||||
|
||||
# Check available sorters
|
||||
print(si.installed_sorters())
|
||||
```
|
||||
|
||||
### 4. Postprocessing
|
||||
|
||||
```python
|
||||
# Create analyzer and compute all extensions
|
||||
analyzer = si.create_sorting_analyzer(sorting, rec, sparse=True)
|
||||
|
||||
analyzer.compute('random_spikes', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
analyzer.compute('spike_amplitudes')
|
||||
analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
analyzer.compute('quality_metrics')
|
||||
|
||||
metrics = analyzer.get_extension('quality_metrics').get_data()
|
||||
```
|
||||
|
||||
### 5. Curation
|
||||
|
||||
```python
|
||||
# Allen Institute criteria (conservative)
|
||||
good_units = metrics.query("""
|
||||
presence_ratio > 0.9 and
|
||||
isi_violations_ratio < 0.5 and
|
||||
amplitude_cutoff < 0.1
|
||||
""").index.tolist()
|
||||
|
||||
# Or use automated curation
|
||||
labels = npa.curate(metrics, method='allen') # 'allen', 'ibl', 'strict'
|
||||
```
|
||||
|
||||
### 6. AI-Assisted Curation (For Uncertain Units)
|
||||
|
||||
When using this skill with Claude Code, Claude can directly analyze waveform plots and provide expert curation decisions. For programmatic API access:
|
||||
|
||||
```python
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Setup API client
|
||||
client = Anthropic()
|
||||
|
||||
# Analyze uncertain units visually
|
||||
uncertain = metrics.query('snr > 3 and snr < 8').index.tolist()
|
||||
|
||||
for unit_id in uncertain:
|
||||
result = npa.analyze_unit_visually(analyzer, unit_id, api_client=client)
|
||||
print(f"Unit {unit_id}: {result['classification']}")
|
||||
print(f" Reasoning: {result['reasoning'][:100]}...")
|
||||
```
|
||||
|
||||
**Claude Code Integration**: When running within Claude Code, ask Claude to examine waveform/correlogram plots directly - no API setup required.
|
||||
|
||||
### 7. Generate Analysis Report
|
||||
|
||||
```python
|
||||
# Generate comprehensive HTML report with visualizations
|
||||
report_dir = npa.generate_analysis_report(results, 'output/')
|
||||
# Opens report.html with summary stats, figures, and unit table
|
||||
|
||||
# Print formatted summary to console
|
||||
npa.print_analysis_summary(results)
|
||||
```
|
||||
|
||||
### 8. Export Results
|
||||
|
||||
```python
|
||||
# Export to Phy for manual review
|
||||
si.export_to_phy(analyzer, output_folder='phy_export/',
|
||||
compute_pc_features=True, compute_amplitudes=True)
|
||||
|
||||
# Export to NWB
|
||||
from spikeinterface.exporters import export_to_nwb
|
||||
export_to_nwb(rec, sorting, 'output.nwb')
|
||||
|
||||
# Save quality metrics
|
||||
metrics.to_csv('quality_metrics.csv')
|
||||
```
|
||||
|
||||
## Common Pitfalls and Best Practices
|
||||
|
||||
1. **Always check drift** before spike sorting - drift > 10μm significantly impacts quality
|
||||
2. **Use phase_shift** for Neuropixels 1.0 probes (not needed for 2.0)
|
||||
3. **Save preprocessed data** to avoid recomputing - use `rec.save(folder='preprocessed/')`
|
||||
4. **Use GPU** for Kilosort4 - it's 10-50x faster than CPU alternatives
|
||||
5. **Review uncertain units manually** - automated curation is a starting point
|
||||
6. **Combine metrics with AI** - use metrics for clear cases, AI for borderline units
|
||||
7. **Document your thresholds** - different analyses may need different criteria
|
||||
8. **Export to Phy** for critical experiments - human oversight is valuable
|
||||
|
||||
## Key Parameters to Adjust
|
||||
|
||||
### Preprocessing
|
||||
- `freq_min`: Highpass cutoff (300-400 Hz typical)
|
||||
- `detect_threshold`: Bad channel detection sensitivity
|
||||
|
||||
### Motion Correction
|
||||
- `preset`: 'kilosort_like' (fast) or 'nonrigid_accurate' (better for severe drift)
|
||||
|
||||
### Spike Sorting (Kilosort4)
|
||||
- `batch_size`: Samples per batch (30000 default)
|
||||
- `nblocks`: Number of drift blocks (increase for long recordings)
|
||||
- `Th_learned`: Detection threshold (lower = more spikes)
|
||||
|
||||
### Quality Metrics
|
||||
- `snr_threshold`: Signal-to-noise cutoff (3-5 typical)
|
||||
- `isi_violations_ratio`: Refractory violations (0.01-0.5)
|
||||
- `presence_ratio`: Recording coverage (0.5-0.95)
|
||||
|
||||
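As a sketch of how these knobs are typically overridden in practice (values are illustrative; the `npa` keyword names follow the bundled API reference, and the Kilosort4 parameters are passed straight through `run_sorter`):

```python
# Tighter preprocessing, accurate motion correction, more sensitive detection
rec = npa.preprocess(recording, freq_min=400)
rec = npa.correct_motion(rec, preset='nonrigid_accurate')
sorting = si.run_sorter('kilosort4', rec, folder='ks4_output',
                        Th_universal=8, Th_learned=7, nblocks=5)
```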
## Bundled Resources
|
||||
|
||||
### scripts/preprocess_recording.py
|
||||
Automated preprocessing script:
|
||||
```bash
|
||||
python scripts/preprocess_recording.py /path/to/data --output preprocessed/
|
||||
```
|
||||
|
||||
### scripts/run_sorting.py
|
||||
Run spike sorting:
|
||||
```bash
|
||||
python scripts/run_sorting.py preprocessed/ --sorter kilosort4 --output sorting/
|
||||
```
|
||||
|
||||
### scripts/compute_metrics.py
|
||||
Compute quality metrics and apply curation:
|
||||
```bash
|
||||
python scripts/compute_metrics.py sorting/ preprocessed/ --output metrics/ --curation allen
|
||||
```
|
||||
|
||||
### scripts/export_to_phy.py
|
||||
Export to Phy for manual curation:
|
||||
```bash
|
||||
python scripts/export_to_phy.py metrics/analyzer --output phy_export/
|
||||
```
|
||||
|
||||
### assets/analysis_template.py
|
||||
Complete analysis template. Copy and customize:
|
||||
```bash
|
||||
cp assets/analysis_template.py my_analysis.py
|
||||
# Edit parameters and run
|
||||
python my_analysis.py
|
||||
```
|
||||
|
||||
### reference/standard_workflow.md
|
||||
Detailed step-by-step workflow with explanations for each stage.
|
||||
|
||||
### reference/api_reference.md
|
||||
Quick function reference organized by module.
|
||||
|
||||
### reference/plotting_guide.md
|
||||
Comprehensive visualization guide for publication-quality figures.
|
||||
|
||||
## Detailed Reference Guides
|
||||
|
||||
| Topic | Reference |
|
||||
|-------|-----------|
|
||||
| Full workflow | [reference/standard_workflow.md](reference/standard_workflow.md) |
|
||||
| API reference | [reference/api_reference.md](reference/api_reference.md) |
|
||||
| Plotting guide | [reference/plotting_guide.md](reference/plotting_guide.md) |
|
||||
| Preprocessing | [PREPROCESSING.md](PREPROCESSING.md) |
|
||||
| Spike sorting | [SPIKE_SORTING.md](SPIKE_SORTING.md) |
|
||||
| Motion correction | [MOTION_CORRECTION.md](MOTION_CORRECTION.md) |
|
||||
| Quality metrics | [QUALITY_METRICS.md](QUALITY_METRICS.md) |
|
||||
| Automated curation | [AUTOMATED_CURATION.md](AUTOMATED_CURATION.md) |
|
||||
| AI-assisted curation | [AI_CURATION.md](AI_CURATION.md) |
|
||||
| Waveform analysis | [ANALYSIS.md](ANALYSIS.md) |
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Core packages
|
||||
pip install spikeinterface[full] probeinterface neo
|
||||
|
||||
# Spike sorters
|
||||
pip install kilosort # Kilosort4 (GPU required)
|
||||
pip install spykingcircus # SpykingCircus2 (CPU)
|
||||
pip install mountainsort5 # Mountainsort5 (CPU)
|
||||
|
||||
# Our toolkit
|
||||
pip install neuropixels-analysis
|
||||
|
||||
# Optional: AI curation
|
||||
pip install anthropic
|
||||
|
||||
# Optional: IBL tools
|
||||
pip install ibl-neuropixel ibllib
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
project/
|
||||
├── raw_data/
|
||||
│ └── recording_g0/
|
||||
│ └── recording_g0_imec0/
|
||||
│ ├── recording_g0_t0.imec0.ap.bin
|
||||
│ └── recording_g0_t0.imec0.ap.meta
|
||||
├── preprocessed/ # Saved preprocessed recording
|
||||
├── motion/ # Motion estimation results
|
||||
├── sorting_output/ # Spike sorter output
|
||||
├── analyzer/ # SortingAnalyzer (waveforms, metrics)
|
||||
├── phy_export/ # For manual curation
|
||||
├── ai_curation/ # AI analysis reports
|
||||
└── results/
|
||||
├── quality_metrics.csv
|
||||
├── curation_labels.json
|
||||
└── output.nwb
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **SpikeInterface Docs**: https://spikeinterface.readthedocs.io/
|
||||
- **Neuropixels Tutorial**: https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
|
||||
- **Kilosort4 GitHub**: https://github.com/MouseLand/Kilosort
|
||||
- **IBL Neuropixel Tools**: https://github.com/int-brain-lab/ibl-neuropixel
|
||||
- **Allen Institute ecephys**: https://github.com/AllenInstitute/ecephys_spike_sorting
|
||||
- **Bombcell (Automated QC)**: https://github.com/Julie-Fabre/bombcell
|
||||
- **SpikeAgent (AI Curation)**: https://github.com/SpikeAgent/SpikeAgent
|
||||
339
scientific-skills/neuropixels-analysis/SPIKE_SORTING.md
Normal file
@@ -0,0 +1,339 @@
|
||||
# Spike Sorting Reference
|
||||
|
||||
Comprehensive guide to spike sorting Neuropixels data.
|
||||
|
||||
## Available Sorters
|
||||
|
||||
| Sorter | GPU Required | Speed | Quality | Best For |
|
||||
|--------|--------------|-------|---------|----------|
|
||||
| **Kilosort4** | Yes (CUDA) | Fast | Excellent | Production use |
|
||||
| **Kilosort3** | Yes (CUDA) | Fast | Very Good | Legacy compatibility |
|
||||
| **Kilosort2.5** | Yes (CUDA) | Fast | Good | Older pipelines |
|
||||
| **SpykingCircus2** | No | Medium | Good | CPU-only systems |
|
||||
| **Mountainsort5** | No | Medium | Good | Small recordings |
|
||||
| **Tridesclous2** | No | Medium | Good | Interactive sorting |
|
||||
|
||||
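To check which of these sorters are usable in the current environment, and to inspect their tunable parameters, SpikeInterface provides:

```python
import spikeinterface.full as si

print(si.available_sorters())    # every sorter SpikeInterface can wrap
print(si.installed_sorters())    # sorters actually installed locally
print(si.get_default_sorter_params('kilosort4'))
```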
## Kilosort4 (Recommended)
|
||||
|
||||
### Installation
|
||||
```bash
|
||||
pip install kilosort
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
```python
|
||||
import spikeinterface.full as si
|
||||
|
||||
# Run Kilosort4
|
||||
sorting = si.run_sorter(
|
||||
'kilosort4',
|
||||
recording,
|
||||
output_folder='ks4_output',
|
||||
verbose=True
|
||||
)
|
||||
|
||||
print(f"Found {len(sorting.unit_ids)} units")
|
||||
```
|
||||
|
||||
### Custom Parameters
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'kilosort4',
|
||||
recording,
|
||||
output_folder='ks4_output',
|
||||
# Detection
|
||||
Th_universal=9, # Spike detection threshold
|
||||
Th_learned=8, # Learned threshold
|
||||
# Templates
|
||||
dmin=15, # Min vertical distance between templates (um)
|
||||
dminx=12, # Min horizontal distance (um)
|
||||
nblocks=5, # Number of non-rigid blocks
|
||||
# Clustering
|
||||
max_channel_distance=None, # Max distance for template channel
|
||||
# Output
|
||||
do_CAR=False, # Skip CAR (done in preprocessing)
|
||||
skip_kilosort_preprocessing=True,
|
||||
save_extra_kwargs=True
|
||||
)
|
||||
```
|
||||
|
||||
### Kilosort4 Full Parameters
|
||||
```python
|
||||
# Get all available parameters
|
||||
params = si.get_default_sorter_params('kilosort4')
|
||||
print(params)
|
||||
|
||||
# Key parameters:
|
||||
ks4_params = {
|
||||
# Detection
|
||||
'Th_universal': 9, # Universal threshold for spike detection
|
||||
'Th_learned': 8, # Threshold for learned templates
|
||||
'spkTh': -6, # Spike threshold during extraction
|
||||
|
||||
# Clustering
|
||||
'dmin': 15, # Min distance between clusters (um)
|
||||
'dminx': 12, # Min horizontal distance (um)
|
||||
'nblocks': 5, # Blocks for non-rigid drift correction
|
||||
|
||||
# Templates
|
||||
'n_templates': 6, # Number of universal templates per group
|
||||
'nt': 61, # Number of time samples in template
|
||||
|
||||
# Performance
|
||||
'batch_size': 60000, # Batch size in samples
|
||||
'nfilt_factor': 8, # Factor for number of filters
|
||||
}
|
||||
```
|
||||
|
||||
## Kilosort3
|
||||
|
||||
### Usage
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'kilosort3',
|
||||
recording,
|
||||
output_folder='ks3_output',
|
||||
# Key parameters
|
||||
detect_threshold=6,
|
||||
projection_threshold=[9, 9],
|
||||
preclust_threshold=8,
|
||||
car=False, # CAR done in preprocessing
|
||||
freq_min=300,
|
||||
)
|
||||
```
|
||||
|
||||
## SpykingCircus2 (CPU-Only)
|
||||
|
||||
### Installation
|
||||
```bash
|
||||
pip install spykingcircus
|
||||
```
|
||||
|
||||
### Usage
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'spykingcircus2',
|
||||
recording,
|
||||
output_folder='sc2_output',
|
||||
# Parameters
|
||||
detect_threshold=5,
|
||||
selection_method='all',
|
||||
)
|
||||
```
|
||||
|
||||
## Mountainsort5 (CPU-Only)
|
||||
|
||||
### Installation
|
||||
```bash
|
||||
pip install mountainsort5
|
||||
```
|
||||
|
||||
### Usage
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'mountainsort5',
|
||||
recording,
|
||||
output_folder='ms5_output',
|
||||
# Parameters
|
||||
detect_threshold=5.0,
|
||||
scheme='2', # '1', '2', or '3'
|
||||
)
|
||||
```
|
||||
|
||||
## Running Multiple Sorters
|
||||
|
||||
### Compare Sorters
|
||||
```python
|
||||
# Run multiple sorters
|
||||
sorting_ks4 = si.run_sorter('kilosort4', recording, output_folder='ks4/')
|
||||
sorting_sc2 = si.run_sorter('spykingcircus2', recording, output_folder='sc2/')
|
||||
sorting_ms5 = si.run_sorter('mountainsort5', recording, output_folder='ms5/')
|
||||
|
||||
# Compare results
|
||||
comparison = si.compare_multiple_sorters(
|
||||
[sorting_ks4, sorting_sc2, sorting_ms5],
|
||||
name_list=['KS4', 'SC2', 'MS5']
|
||||
)
|
||||
|
||||
# Get agreement scores
|
||||
agreement = comparison.get_agreement_sorting()
|
||||
```
|
||||
|
||||
### Ensemble Sorting
|
||||
```python
|
||||
# Create consensus sorting
|
||||
sorting_ensemble = si.create_ensemble_sorting(
|
||||
[sorting_ks4, sorting_sc2, sorting_ms5],
|
||||
voting_method='agreement',
|
||||
min_agreement=2 # Unit must be found by at least 2 sorters
|
||||
)
|
||||
```
|
||||
|
||||
## Sorting in Docker/Singularity
|
||||
|
||||
### Using Docker
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'kilosort3',
|
||||
recording,
|
||||
output_folder='ks3_docker/',
|
||||
docker_image='spikeinterface/kilosort3-compiled-base:latest',
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
### Using Singularity
|
||||
```python
|
||||
sorting = si.run_sorter(
|
||||
'kilosort3',
|
||||
recording,
|
||||
output_folder='ks3_singularity/',
|
||||
singularity_image='/path/to/kilosort3.sif',
|
||||
verbose=True
|
||||
)
|
||||
```
|
||||
|
||||
## Long Recording Strategy
|
||||
|
||||
### Concatenate Recordings
|
||||
```python
|
||||
# Multiple recording files
|
||||
recordings = [
|
||||
si.read_spikeglx(f'/path/to/recording_{i}', stream_id='imec0.ap')
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Concatenate
|
||||
recording_concat = si.concatenate_recordings(recordings)
|
||||
|
||||
# Sort
|
||||
sorting = si.run_sorter('kilosort4', recording_concat, output_folder='ks4/')
|
||||
|
||||
# Split back by original recording
|
||||
sortings_split = si.split_sorting(sorting, recording_concat)
|
||||
```
|
||||
|
||||
### Sort by Segment
|
||||
```python
|
||||
# For very long recordings, sort segments separately
|
||||
from pathlib import Path
|
||||
|
||||
segments_output = Path('sorting_segments')
|
||||
sortings = []
|
||||
|
||||
for i, segment in enumerate(recording.split_by_times([0, 3600, 7200, 10800])):
|
||||
sorting_seg = si.run_sorter(
|
||||
'kilosort4',
|
||||
segment,
|
||||
output_folder=segments_output / f'segment_{i}'
|
||||
)
|
||||
sortings.append(sorting_seg)
|
||||
```
|
||||
|
||||
## Post-Sorting Curation
|
||||
|
||||
### Manual Curation with Phy
|
||||
```python
|
||||
# Export to Phy format
|
||||
analyzer = si.create_sorting_analyzer(sorting, recording)
|
||||
analyzer.compute(['random_spikes', 'waveforms', 'templates'])
|
||||
si.export_to_phy(analyzer, output_folder='phy_export/')
|
||||
|
||||
# Open Phy
|
||||
# Run in terminal: phy template-gui phy_export/params.py
|
||||
```
|
||||
|
||||
### Load Phy Curation
|
||||
```python
|
||||
# After manual curation in Phy
|
||||
sorting_curated = si.read_phy('phy_export/')
|
||||
|
||||
# Or apply Phy labels
|
||||
sorting_curated = si.apply_phy_curation(sorting, 'phy_export/')
|
||||
```
|
||||
|
||||
### Automatic Curation
|
||||
```python
|
||||
# Remove units below quality threshold
|
||||
analyzer = si.create_sorting_analyzer(sorting, recording)
|
||||
analyzer.compute('quality_metrics')
|
||||
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
|
||||
# Define quality criteria
|
||||
query = "(snr > 5) & (isi_violations_ratio < 0.01) & (presence_ratio > 0.9)"
|
||||
good_unit_ids = qm.query(query).index.tolist()
|
||||
|
||||
sorting_clean = sorting.select_units(good_unit_ids)
|
||||
print(f"Kept {len(good_unit_ids)}/{len(sorting.unit_ids)} units")
|
||||
```
|
||||
|
||||
## Sorting Metrics
|
||||
|
||||
### Check Sorter Output
|
||||
```python
|
||||
# Basic stats
|
||||
print(f"Units found: {len(sorting.unit_ids)}")
|
||||
print(f"Total spikes: {sorting.get_total_num_spikes()}")
|
||||
|
||||
# Per-unit spike counts
|
||||
for unit_id in sorting.unit_ids[:10]:
|
||||
n_spikes = len(sorting.get_unit_spike_train(unit_id))
|
||||
print(f"Unit {unit_id}: {n_spikes} spikes")
|
||||
```
|
||||
|
||||
### Firing Rates
|
||||
```python
|
||||
# Compute firing rates
|
||||
duration = recording.get_total_duration()
|
||||
for unit_id in sorting.unit_ids:
|
||||
n_spikes = len(sorting.get_unit_spike_train(unit_id))
|
||||
fr = n_spikes / duration
|
||||
print(f"Unit {unit_id}: {fr:.2f} Hz")
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Out of GPU Memory**
|
||||
```python
|
||||
# Reduce batch size
|
||||
sorting = si.run_sorter(
|
||||
'kilosort4',
|
||||
recording,
|
||||
output_folder='ks4/',
|
||||
batch_size=30000 # Smaller batch
|
||||
)
|
||||
```
|
||||
|
||||
**Too Few Units Found**
|
||||
```python
|
||||
# Lower detection threshold
|
||||
sorting = si.run_sorter(
|
||||
'kilosort4',
|
||||
recording,
|
||||
output_folder='ks4/',
|
||||
Th_universal=7, # Lower from default 9
|
||||
Th_learned=6
|
||||
)
|
||||
```
|
||||
|
||||
**Too Many Units (Over-splitting)**
|
||||
```python
|
||||
# Increase minimum distance between templates
|
||||
sorting = si.run_sorter(
|
||||
'kilosort4',
|
||||
recording,
|
||||
output_folder='ks4/',
|
||||
dmin=20, # Increase from 15
|
||||
dminx=16 # Increase from 12
|
||||
)
|
||||
```
|
||||
|
||||
**Check GPU Availability**
|
||||
```python
|
||||
import torch
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||
```
|
||||
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Neuropixels Analysis Template
|
||||
|
||||
Complete analysis workflow from raw data to curated units.
|
||||
Copy and customize this template for your analysis.
|
||||
|
||||
Usage:
|
||||
1. Copy this file to your analysis directory
|
||||
2. Update the PARAMETERS section
|
||||
3. Run: python analysis_template.py
|
||||
"""
|
||||
|
||||
# =============================================================================
|
||||
# PARAMETERS - Customize these for your analysis
|
||||
# =============================================================================
|
||||
|
||||
# Input/Output paths
|
||||
DATA_PATH = '/path/to/your/spikeglx/data/'
|
||||
OUTPUT_DIR = 'analysis_output/'
|
||||
DATA_FORMAT = 'spikeglx' # 'spikeglx', 'openephys', or 'nwb'
|
||||
STREAM_ID = 'imec0.ap' # For multi-probe recordings
|
||||
|
||||
# Preprocessing parameters
|
||||
FREQ_MIN = 300 # Highpass filter (Hz)
|
||||
FREQ_MAX = 6000 # Lowpass filter (Hz)
|
||||
APPLY_PHASE_SHIFT = True
|
||||
APPLY_CMR = True
|
||||
DETECT_BAD_CHANNELS = True
|
||||
|
||||
# Motion correction
|
||||
CORRECT_MOTION = True
|
||||
MOTION_PRESET = 'nonrigid_accurate' # 'kilosort_like', 'nonrigid_fast_and_accurate'
|
||||
|
||||
# Spike sorting
|
||||
SORTER = 'kilosort4' # 'kilosort4', 'spykingcircus2', 'mountainsort5'
|
||||
SORTER_PARAMS = {
|
||||
'batch_size': 30000,
|
||||
'nblocks': 1, # Increase for long recordings with drift
|
||||
}
|
||||
|
||||
# Quality metrics and curation
|
||||
CURATION_METHOD = 'allen' # 'allen', 'ibl', 'strict'
|
||||
|
||||
# Processing
|
||||
N_JOBS = -1 # -1 = all cores
|
||||
|
||||
# =============================================================================
|
||||
# ANALYSIS PIPELINE - Usually no need to modify below
|
||||
# =============================================================================
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import spikeinterface.full as si
|
||||
from spikeinterface.exporters import export_to_phy
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the full analysis pipeline."""
|
||||
|
||||
output_path = Path(OUTPUT_DIR)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# =========================================================================
|
||||
# 1. LOAD DATA
|
||||
# =========================================================================
|
||||
print("=" * 60)
|
||||
print("1. LOADING DATA")
|
||||
print("=" * 60)
|
||||
|
||||
if DATA_FORMAT == 'spikeglx':
|
||||
recording = si.read_spikeglx(DATA_PATH, stream_id=STREAM_ID)
|
||||
elif DATA_FORMAT == 'openephys':
|
||||
recording = si.read_openephys(DATA_PATH)
|
||||
elif DATA_FORMAT == 'nwb':
|
||||
recording = si.read_nwb(DATA_PATH)
|
||||
else:
|
||||
raise ValueError(f"Unknown format: {DATA_FORMAT}")
|
||||
|
||||
print(f"Recording: {recording.get_num_channels()} channels")
|
||||
print(f"Duration: {recording.get_total_duration():.1f} seconds")
|
||||
print(f"Sampling rate: {recording.get_sampling_frequency()} Hz")
|
||||
|
||||
# =========================================================================
|
||||
# 2. PREPROCESSING
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("2. PREPROCESSING")
|
||||
print("=" * 60)
|
||||
|
||||
rec = recording
|
||||
|
||||
# Bandpass filter
|
||||
print(f"Applying bandpass filter ({FREQ_MIN}-{FREQ_MAX} Hz)...")
|
||||
rec = si.bandpass_filter(rec, freq_min=FREQ_MIN, freq_max=FREQ_MAX)
|
||||
|
||||
# Phase shift correction
|
||||
if APPLY_PHASE_SHIFT:
|
||||
print("Applying phase shift correction...")
|
||||
rec = si.phase_shift(rec)
|
||||
|
||||
# Bad channel detection
|
||||
if DETECT_BAD_CHANNELS:
|
||||
print("Detecting bad channels...")
|
||||
bad_ids, _ = si.detect_bad_channels(rec)
|
||||
if len(bad_ids) > 0:
|
||||
print(f" Removing {len(bad_ids)} bad channels")
|
||||
rec = rec.remove_channels(bad_ids)
|
||||
|
||||
# Common median reference
|
||||
if APPLY_CMR:
|
||||
print("Applying common median reference...")
|
||||
rec = si.common_reference(rec, operator='median', reference='global')
|
||||
|
||||
# Save preprocessed
|
||||
print("Saving preprocessed recording...")
|
||||
rec.save(folder=output_path / 'preprocessed', n_jobs=N_JOBS)
|
||||
|
||||
# =========================================================================
|
||||
# 3. MOTION CORRECTION
|
||||
# =========================================================================
|
||||
if CORRECT_MOTION:
|
||||
print("\n" + "=" * 60)
|
||||
print("3. MOTION CORRECTION")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"Estimating and correcting motion (preset: {MOTION_PRESET})...")
|
||||
rec = si.correct_motion(
|
||||
rec,
|
||||
preset=MOTION_PRESET,
|
||||
folder=output_path / 'motion',
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 4. SPIKE SORTING
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("4. SPIKE SORTING")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"Running {SORTER}...")
|
||||
sorting = si.run_sorter(
|
||||
SORTER,
|
||||
rec,
|
||||
output_folder=output_path / f'{SORTER}_output',
|
||||
verbose=True,
|
||||
**SORTER_PARAMS,
|
||||
)
|
||||
|
||||
print(f"Found {len(sorting.unit_ids)} units")
|
||||
|
||||
# =========================================================================
|
||||
# 5. POSTPROCESSING
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("5. POSTPROCESSING")
|
||||
print("=" * 60)
|
||||
|
||||
print("Creating SortingAnalyzer...")
|
||||
analyzer = si.create_sorting_analyzer(
|
||||
sorting,
|
||||
rec,
|
||||
format='binary_folder',
|
||||
folder=output_path / 'analyzer',
|
||||
sparse=True,
|
||||
)
|
||||
|
||||
print("Computing extensions...")
|
||||
analyzer.compute('random_spikes', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
analyzer.compute('noise_levels')
|
||||
analyzer.compute('spike_amplitudes')
|
||||
analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
|
||||
# =========================================================================
|
||||
# 6. QUALITY METRICS
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("6. QUALITY METRICS")
|
||||
print("=" * 60)
|
||||
|
||||
print("Computing quality metrics...")
|
||||
metrics = si.compute_quality_metrics(
|
||||
analyzer,
|
||||
metric_names=[
|
||||
'snr', 'isi_violations_ratio', 'presence_ratio',
|
||||
'amplitude_cutoff', 'firing_rate', 'amplitude_cv',
|
||||
],
|
||||
n_jobs=N_JOBS,
|
||||
)
|
||||
|
||||
metrics.to_csv(output_path / 'quality_metrics.csv')
|
||||
print(f"Saved metrics to: {output_path / 'quality_metrics.csv'}")
|
||||
|
||||
# Print summary
|
||||
print("\nMetrics summary:")
|
||||
for col in ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']:
|
||||
if col in metrics.columns:
|
||||
print(f" {col}: {metrics[col].median():.4f} (median)")
|
||||
|
||||
# =========================================================================
|
||||
# 7. CURATION
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("7. CURATION")
|
||||
print("=" * 60)
|
||||
|
||||
# Curation criteria
|
||||
criteria = {
|
||||
'allen': {'snr': 3.0, 'isi_violations_ratio': 0.1, 'presence_ratio': 0.9},
|
||||
'ibl': {'snr': 4.0, 'isi_violations_ratio': 0.5, 'presence_ratio': 0.5},
|
||||
'strict': {'snr': 5.0, 'isi_violations_ratio': 0.01, 'presence_ratio': 0.95},
|
||||
}[CURATION_METHOD]
|
||||
|
||||
print(f"Applying {CURATION_METHOD} criteria: {criteria}")
|
||||
|
||||
labels = {}
|
||||
for unit_id in metrics.index:
|
||||
row = metrics.loc[unit_id]
|
||||
is_good = (
|
||||
row.get('snr', 0) >= criteria['snr'] and
|
||||
row.get('isi_violations_ratio', 1) <= criteria['isi_violations_ratio'] and
|
||||
row.get('presence_ratio', 0) >= criteria['presence_ratio']
|
||||
)
|
||||
if is_good:
|
||||
labels[int(unit_id)] = 'good'
|
||||
elif row.get('snr', 0) < 2:
|
||||
labels[int(unit_id)] = 'noise'
|
||||
else:
|
||||
labels[int(unit_id)] = 'mua'
|
||||
|
||||
# Save labels
|
||||
with open(output_path / 'curation_labels.json', 'w') as f:
|
||||
json.dump(labels, f, indent=2)
|
||||
|
||||
# Count
|
||||
good_count = sum(1 for v in labels.values() if v == 'good')
|
||||
mua_count = sum(1 for v in labels.values() if v == 'mua')
|
||||
noise_count = sum(1 for v in labels.values() if v == 'noise')
|
||||
|
||||
print(f"\nCuration results:")
|
||||
print(f" Good: {good_count}")
|
||||
print(f" MUA: {mua_count}")
|
||||
print(f" Noise: {noise_count}")
|
||||
print(f" Total: {len(labels)}")
|
||||
|
||||
# =========================================================================
|
||||
# 8. EXPORT
|
||||
# =========================================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("8. EXPORT")
|
||||
print("=" * 60)
|
||||
|
||||
print("Exporting to Phy...")
|
||||
export_to_phy(
|
||||
analyzer,
|
||||
output_folder=output_path / 'phy_export',
|
||||
copy_binary=True,
|
||||
)
|
||||
|
||||
print(f"\nAnalysis complete!")
|
||||
print(f"Results saved to: {output_path}")
|
||||
print(f"\nTo open in Phy:")
|
||||
print(f" phy template-gui {output_path / 'phy_export' / 'params.py'}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,415 @@
|
||||
# API Reference
|
||||
|
||||
Quick reference for neuropixels_analysis functions organized by module.
|
||||
|
||||
## Core Module
|
||||
|
||||
### load_recording
|
||||
|
||||
```python
|
||||
npa.load_recording(
|
||||
path: str,
|
||||
format: str = 'auto', # 'spikeglx', 'openephys', 'nwb'
|
||||
stream_id: str = None, # e.g., 'imec0.ap'
|
||||
) -> Recording
|
||||
```
|
||||
|
||||
Load Neuropixels recording from various formats.
|
||||
|
||||
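A usage sketch (paths are placeholders):

```python
recording = npa.load_recording('/path/to/run_g0', format='spikeglx',
                               stream_id='imec0.ap')
```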
### run_pipeline
|
||||
|
||||
```python
|
||||
npa.run_pipeline(
|
||||
recording: Recording,
|
||||
output_dir: str,
|
||||
sorter: str = 'kilosort4',
|
||||
preprocess: bool = True,
|
||||
correct_motion: bool = True,
|
||||
postprocess: bool = True,
|
||||
curate: bool = True,
|
||||
curation_method: str = 'allen',
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Run complete analysis pipeline. Returns dictionary with all results.
|
||||
|
||||
## Preprocessing Module
|
||||
|
||||
### preprocess
|
||||
|
||||
```python
|
||||
npa.preprocess(
|
||||
recording: Recording,
|
||||
freq_min: float = 300,
|
||||
freq_max: float = 6000,
|
||||
phase_shift: bool = True,
|
||||
common_ref: bool = True,
|
||||
bad_channel_detection: bool = True,
|
||||
) -> Recording
|
||||
```
|
||||
|
||||
Apply standard preprocessing chain.
|
||||
|
||||
### detect_bad_channels
|
||||
|
||||
```python
|
||||
npa.detect_bad_channels(
|
||||
recording: Recording,
|
||||
method: str = 'coherence+psd',
|
||||
**kwargs,
|
||||
) -> list
|
||||
```
|
||||
|
||||
Detect and return list of bad channel IDs.
|
||||
|
||||
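A sketch combining the two functions above (keyword names as listed in the signatures; the built-in detection is disabled because bad channels are removed explicitly):

```python
bad_ids = npa.detect_bad_channels(recording, method='coherence+psd')
rec = npa.preprocess(recording.remove_channels(bad_ids),
                     freq_min=300, freq_max=6000,
                     bad_channel_detection=False)
```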
### apply_filters
|
||||
|
||||
```python
|
||||
npa.apply_filters(
|
||||
recording: Recording,
|
||||
freq_min: float = 300,
|
||||
freq_max: float = 6000,
|
||||
filter_type: str = 'bandpass',
|
||||
) -> Recording
|
||||
```
|
||||
|
||||
Apply frequency filters.
|
||||
|
||||
### common_reference
|
||||
|
||||
```python
|
||||
npa.common_reference(
|
||||
recording: Recording,
|
||||
operator: str = 'median',
|
||||
reference: str = 'global',
|
||||
) -> Recording
|
||||
```
|
||||
|
||||
Apply common reference (CMR/CAR).
|
||||
|
||||
## Motion Module
|
||||
|
||||
### check_drift
|
||||
|
||||
```python
|
||||
npa.check_drift(
|
||||
recording: Recording,
|
||||
plot: bool = True,
|
||||
output: str = None,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Check recording for drift. Returns drift statistics.
|
||||
|
||||
### estimate_motion
|
||||
|
||||
```python
|
||||
npa.estimate_motion(
|
||||
recording: Recording,
|
||||
preset: str = 'kilosort_like',
|
||||
**kwargs,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Estimate motion without applying correction.
|
||||
|
||||
### correct_motion
|
||||
|
||||
```python
|
||||
npa.correct_motion(
|
||||
recording: Recording,
|
||||
preset: str = 'nonrigid_accurate',
|
||||
folder: str = None,
|
||||
**kwargs,
|
||||
) -> Recording
|
||||
```
|
||||
|
||||
Apply motion correction.
|
||||
|
||||
**Presets:**
|
||||
- `'kilosort_like'`: Fast, rigid correction
|
||||
- `'nonrigid_accurate'`: Slower, better for severe drift
|
||||
- `'nonrigid_fast_and_accurate'`: Balanced option
|
||||
|
||||
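A common pattern, mirroring the main workflow, is to estimate first and only correct when the drift is meaningful:

```python
motion_info = npa.estimate_motion(recording, preset='kilosort_like')
if motion_info['motion'].max() > 10:  # microns
    recording = npa.correct_motion(recording, preset='nonrigid_accurate',
                                   folder='motion/')
```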
## Sorting Module
|
||||
|
||||
### run_sorting
|
||||
|
||||
```python
|
||||
npa.run_sorting(
|
||||
recording: Recording,
|
||||
sorter: str = 'kilosort4',
|
||||
output_folder: str = None,
|
||||
sorter_params: dict = None,
|
||||
**kwargs,
|
||||
) -> Sorting
|
||||
```
|
||||
|
||||
Run spike sorter.
|
||||
|
||||
**Supported sorters:**
|
||||
- `'kilosort4'`: GPU-based, recommended
|
||||
- `'kilosort3'`: Legacy, requires MATLAB
|
||||
- `'spykingcircus2'`: CPU-based alternative
|
||||
- `'mountainsort5'`: Fast, good for short recordings
|
||||
|
||||
### compare_sorters
|
||||
|
||||
```python
|
||||
npa.compare_sorters(
|
||||
sortings: list,
|
||||
delta_time: float = 0.4, # ms
|
||||
match_score: float = 0.5,
|
||||
) -> Comparison
|
||||
```
|
||||
|
||||
Compare results from multiple sorters.
|
||||
|
||||
## Postprocessing Module
|
||||
|
||||
### create_analyzer
|
||||
|
||||
```python
|
||||
npa.create_analyzer(
|
||||
sorting: Sorting,
|
||||
recording: Recording,
|
||||
output_folder: str = None,
|
||||
sparse: bool = True,
|
||||
) -> SortingAnalyzer
|
||||
```
|
||||
|
||||
Create SortingAnalyzer for postprocessing.
|
||||
|
||||
### postprocess
|
||||
|
||||
```python
|
||||
npa.postprocess(
|
||||
sorting: Sorting,
|
||||
recording: Recording,
|
||||
output_folder: str = None,
|
||||
compute_all: bool = True,
|
||||
n_jobs: int = -1,
|
||||
) -> tuple[SortingAnalyzer, DataFrame]
|
||||
```
|
||||
|
||||
Full postprocessing. Returns (analyzer, metrics).
|
||||
|
||||
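For example:

```python
analyzer, metrics = npa.postprocess(sorting, rec,
                                    output_folder='analyzer/', n_jobs=-1)
metrics.to_csv('quality_metrics.csv')
```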
### compute_quality_metrics
|
||||
|
||||
```python
|
||||
npa.compute_quality_metrics(
|
||||
analyzer: SortingAnalyzer,
|
||||
metric_names: list = None, # None = all
|
||||
**kwargs,
|
||||
) -> DataFrame
|
||||
```
|
||||
|
||||
Compute quality metrics for all units.
|
||||
|
||||
**Available metrics:**
|
||||
- `snr`: Signal-to-noise ratio
|
||||
- `isi_violations_ratio`: ISI violations
|
||||
- `presence_ratio`: Recording presence
|
||||
- `amplitude_cutoff`: Amplitude distribution cutoff
|
||||
- `firing_rate`: Average firing rate
|
||||
- `amplitude_cv`: Amplitude coefficient of variation
|
||||
- `sliding_rp_violation`: Sliding window refractory violations
|
||||
- `d_prime`: Isolation quality
|
||||
- `nearest_neighbor`: Nearest-neighbor overlap
|
||||
|
||||
## Curation Module
|
||||
|
||||
### curate
|
||||
|
||||
```python
|
||||
npa.curate(
|
||||
metrics: DataFrame,
|
||||
method: str = 'allen', # 'allen', 'ibl', 'strict', 'custom'
|
||||
**thresholds,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Apply automated curation. Returns {unit_id: label}.
|
||||
|
||||
### auto_classify
|
||||
|
||||
```python
|
||||
npa.auto_classify(
|
||||
metrics: DataFrame,
|
||||
snr_threshold: float = 5.0,
|
||||
isi_threshold: float = 0.01,
|
||||
presence_threshold: float = 0.9,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Classify units based on custom thresholds.
|
||||
|
||||
### filter_units
|
||||
|
||||
```python
|
||||
npa.filter_units(
|
||||
sorting: Sorting,
|
||||
labels: dict,
|
||||
keep: list = ['good'],
|
||||
) -> Sorting
|
||||
```
|
||||
|
||||
Filter sorting to keep only specified labels.
|
||||
|
||||
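A sketch chaining `curate` and `filter_units`:

```python
labels = npa.curate(metrics, method='ibl')
sorting_good = npa.filter_units(sorting, labels, keep=['good'])
print(f"Kept {len(sorting_good.unit_ids)} of {len(sorting.unit_ids)} units")
```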
## AI Curation Module
|
||||
|
||||
### generate_unit_report
|
||||
|
||||
```python
|
||||
npa.generate_unit_report(
|
||||
analyzer: SortingAnalyzer,
|
||||
unit_id: int,
|
||||
output_dir: str = None,
|
||||
figsize: tuple = (16, 12),
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Generate visual report for AI analysis.
|
||||
|
||||
Returns:
|
||||
- `'image_path'`: Path to saved figure
|
||||
- `'image_base64'`: Base64 encoded image
|
||||
- `'metrics'`: Quality metrics dict
|
||||
- `'unit_id'`: Unit ID
|
||||
|
||||
### analyze_unit_visually
|
||||
|
||||
```python
|
||||
npa.analyze_unit_visually(
|
||||
analyzer: SortingAnalyzer,
|
||||
unit_id: int,
|
||||
api_client: Any = None,
|
||||
model: str = 'claude-3-5-sonnet-20241022',
|
||||
task: str = 'quality_assessment',
|
||||
custom_prompt: str = None,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Analyze unit using vision-language model.
|
||||
|
||||
**Tasks:**
|
||||
- `'quality_assessment'`: Classify as good/mua/noise
|
||||
- `'merge_candidate'`: Check if units should merge
|
||||
- `'drift_assessment'`: Assess motion/drift
|
||||
|
||||
### batch_visual_curation
|
||||
|
||||
```python
|
||||
npa.batch_visual_curation(
|
||||
analyzer: SortingAnalyzer,
|
||||
unit_ids: list = None,
|
||||
api_client: Any = None,
|
||||
model: str = 'claude-3-5-sonnet-20241022',
|
||||
output_dir: str = None,
|
||||
progress_callback: callable = None,
|
||||
) -> dict
|
||||
```
|
||||
|
||||
Run visual curation on multiple units.
|
||||
|
||||
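For example, restricting the batch to borderline units and writing reports to disk (assuming `client` and `metrics` from the earlier examples):

```python
uncertain = metrics.query('snr > 3 and snr < 8').index.tolist()
results = npa.batch_visual_curation(
    analyzer,
    unit_ids=uncertain,
    api_client=client,
    output_dir='ai_curation/',
)
```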
### CurationSession
|
||||
|
||||
```python
|
||||
session = npa.CurationSession.create(
|
||||
analyzer: SortingAnalyzer,
|
||||
output_dir: str,
|
||||
session_id: str = None,
|
||||
unit_ids: list = None,
|
||||
sort_by_confidence: bool = True,
|
||||
)
|
||||
|
||||
# Navigation
|
||||
session.current_unit() -> UnitCuration
|
||||
session.next_unit() -> UnitCuration
|
||||
session.prev_unit() -> UnitCuration
|
||||
session.go_to_unit(unit_id: int) -> UnitCuration
|
||||
|
||||
# Decisions
|
||||
session.set_decision(unit_id, decision, notes='')
|
||||
session.set_ai_classification(unit_id, classification)
|
||||
|
||||
# Export
|
||||
session.get_final_labels() -> dict
|
||||
session.export_decisions(output_path) -> DataFrame
|
||||
session.get_summary() -> dict
|
||||
|
||||
# Persistence
|
||||
session.save()
|
||||
session = npa.CurationSession.load(session_dir)
|
||||
```
|
||||
|
||||
## Visualization Module
|
||||
|
||||
### plot_drift
|
||||
|
||||
```python
|
||||
npa.plot_drift(
|
||||
recording: Recording,
|
||||
motion: dict = None,
|
||||
output: str = None,
|
||||
figsize: tuple = (12, 8),
|
||||
)
|
||||
```
|
||||
|
||||
Plot drift/motion map.
|
||||
|
||||
### plot_quality_metrics
|
||||
|
||||
```python
|
||||
npa.plot_quality_metrics(
|
||||
analyzer: SortingAnalyzer,
|
||||
metrics: DataFrame = None,
|
||||
output: str = None,
|
||||
)
|
||||
```
|
||||
|
||||
Plot quality metrics overview.
|
||||
|
||||
### plot_unit_summary
|
||||
|
||||
```python
|
||||
npa.plot_unit_summary(
|
||||
analyzer: SortingAnalyzer,
|
||||
unit_id: int,
|
||||
output: str = None,
|
||||
)
|
||||
```
|
||||
|
||||
Plot comprehensive unit summary.
|
||||
|
||||
## SpikeInterface Integration
|
||||
|
||||
All neuropixels_analysis functions work with SpikeInterface objects:
|
||||
|
||||
```python
|
||||
import spikeinterface.full as si
|
||||
import neuropixels_analysis as npa
|
||||
|
||||
# SpikeInterface recording works with npa functions
|
||||
recording = si.read_spikeglx('/path/')
|
||||
rec = npa.preprocess(recording)
|
||||
|
||||
# Access SpikeInterface directly for advanced usage
|
||||
rec_filtered = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
|
||||
```
|
||||
|
||||
## Common Parameters
|
||||
|
||||
### Recording parameters
|
||||
- `freq_min`: Highpass cutoff (Hz)
|
||||
- `freq_max`: Lowpass cutoff (Hz)
|
||||
- `n_jobs`: Parallel jobs (-1 = all cores)
|
||||
|
||||
### Sorting parameters
|
||||
- `output_folder`: Where to save results
|
||||
- `sorter_params`: Dict of sorter-specific params
|
||||
|
||||
### Quality metric thresholds
|
||||
- `snr_threshold`: SNR cutoff (typically 5)
|
||||
- `isi_threshold`: ISI violations cutoff (typically 0.01)
|
||||
- `presence_threshold`: Presence ratio cutoff (typically 0.9)
|
||||
@@ -0,0 +1,454 @@
|
||||
# Plotting Guide
|
||||
|
||||
Comprehensive guide for creating publication-quality visualizations from Neuropixels data.
|
||||
|
||||
## Setup
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import spikeinterface.full as si
|
||||
import spikeinterface.widgets as sw
|
||||
import neuropixels_analysis as npa
|
||||
|
||||
# High-quality settings
|
||||
plt.rcParams['figure.dpi'] = 150
|
||||
plt.rcParams['savefig.dpi'] = 300
|
||||
plt.rcParams['font.size'] = 10
|
||||
plt.rcParams['font.family'] = 'sans-serif'
|
||||
```
|
||||
|
||||
## Drift and Motion Plots
|
||||
|
||||
### Basic Drift Map
|
||||
|
||||
```python
|
||||
# Using npa
|
||||
npa.plot_drift(recording, output='drift_map.png')
|
||||
|
||||
# Using SpikeInterface widgets
|
||||
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
|
||||
|
||||
peaks = detect_peaks(recording, method='locally_exclusive')
|
||||
peak_locations = localize_peaks(recording, peaks, method='center_of_mass')
|
||||
|
||||
sw.plot_drift_raster_map(
|
||||
peaks=peaks,
|
||||
peak_locations=peak_locations,
|
||||
recording=recording,
|
||||
clim=(-50, 50),
|
||||
)
|
||||
plt.savefig('drift_raster.png', bbox_inches='tight')
|
||||
```
|
||||
|
||||
### Motion Estimate Visualization
|
||||
|
||||
```python
|
||||
motion_info = npa.estimate_motion(recording)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
|
||||
|
||||
# Motion over time
|
||||
ax = axes[0]
|
||||
for i in range(motion_info['motion'].shape[1]):
|
||||
ax.plot(motion_info['temporal_bins'], motion_info['motion'][:, i], alpha=0.5)
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Motion (um)')
|
||||
ax.set_title('Estimated Motion')
|
||||
|
||||
# Motion histogram
|
||||
ax = axes[1]
|
||||
ax.hist(motion_info['motion'].flatten(), bins=50, edgecolor='black')
|
||||
ax.set_xlabel('Motion (um)')
|
||||
ax.set_ylabel('Count')
|
||||
ax.set_title('Motion Distribution')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('motion_analysis.png', dpi=300)
|
||||
```
|
||||
|
||||
## Waveform Plots
|
||||
|
||||
### Single Unit Waveforms
|
||||
|
||||
```python
|
||||
unit_id = 0
|
||||
|
||||
# Basic waveforms
|
||||
sw.plot_unit_waveforms(analyzer, unit_ids=[unit_id])
|
||||
plt.savefig(f'unit_{unit_id}_waveforms.png')
|
||||
|
||||
# With density map
|
||||
sw.plot_unit_waveform_density_map(analyzer, unit_ids=[unit_id])
|
||||
plt.savefig(f'unit_{unit_id}_density.png')
|
||||
```
|
||||
|
||||
### Template Comparison
|
||||
|
||||
```python
|
||||
# Compare multiple units
|
||||
unit_ids = [0, 1, 2, 3]
|
||||
sw.plot_unit_templates(analyzer, unit_ids=unit_ids)
|
||||
plt.savefig('template_comparison.png')
|
||||
```
|
||||
|
||||
### Waveforms on Probe
|
||||
|
||||
```python
|
||||
# Show waveforms spatially on probe
|
||||
sw.plot_unit_waveforms_on_probe(
|
||||
analyzer,
|
||||
unit_ids=[unit_id],
|
||||
plot_channels=True,
|
||||
)
|
||||
plt.savefig(f'unit_{unit_id}_probe.png')
|
||||
```
|
||||
|
||||
## Quality Metrics Visualization
|
||||
|
||||
### Metrics Overview
|
||||
|
||||
```python
|
||||
npa.plot_quality_metrics(analyzer, metrics, output='quality_overview.png')
|
||||
```
|
||||
|
||||
### Metrics Distribution
|
||||
|
||||
```python
|
||||
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
|
||||
|
||||
metric_names = ['snr', 'isi_violations_ratio', 'presence_ratio',
|
||||
'amplitude_cutoff', 'firing_rate', 'amplitude_cv']
|
||||
|
||||
for ax, metric in zip(axes.flat, metric_names):
|
||||
if metric in metrics.columns:
|
||||
values = metrics[metric].dropna()
|
||||
ax.hist(values, bins=30, edgecolor='black', alpha=0.7)
|
||||
ax.axvline(values.median(), color='red', linestyle='--', label='median')
|
||||
ax.set_xlabel(metric)
|
||||
ax.set_ylabel('Count')
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('metrics_distribution.png', dpi=300)
|
||||
```
|
||||
|
||||
### Metrics Scatter Matrix
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
key_metrics = ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']
|
||||
pd.plotting.scatter_matrix(
|
||||
metrics[key_metrics],
|
||||
figsize=(10, 10),
|
||||
alpha=0.5,
|
||||
diagonal='hist',
|
||||
)
|
||||
plt.savefig('metrics_scatter.png', dpi=300)
|
||||
```
|
||||
|
||||
### Metrics vs Labels
|
||||
|
||||
```python
|
||||
labels_series = pd.Series(labels)
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
|
||||
|
||||
for ax, metric in zip(axes, ['snr', 'isi_violations_ratio', 'presence_ratio']):
|
||||
for label in ['good', 'mua', 'noise']:
|
||||
mask = labels_series == label
|
||||
if mask.any():
|
||||
ax.hist(metrics.loc[mask.index[mask], metric],
|
||||
alpha=0.5, label=label, bins=20)
|
||||
ax.set_xlabel(metric)
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('metrics_by_label.png', dpi=300)
|
||||
```
|
||||
|
||||
## Correlogram Plots
|
||||
|
||||
### Autocorrelogram
|
||||
|
||||
```python
|
||||
sw.plot_autocorrelograms(
|
||||
analyzer,
|
||||
unit_ids=[unit_id],
|
||||
window_ms=50,
|
||||
bin_ms=1,
|
||||
)
|
||||
plt.savefig(f'unit_{unit_id}_acg.png')
|
||||
```
|
||||
|
||||
### Cross-correlograms
|
||||
|
||||
```python
|
||||
unit_pairs = [(0, 1), (0, 2), (1, 2)]
|
||||
sw.plot_crosscorrelograms(
|
||||
analyzer,
|
||||
unit_pairs=unit_pairs,
|
||||
window_ms=50,
|
||||
bin_ms=1,
|
||||
)
|
||||
plt.savefig('crosscorrelograms.png')
|
||||
```
|
||||
|
||||
### Correlogram Matrix
|
||||
|
||||
```python
|
||||
sw.plot_autocorrelograms(
|
||||
analyzer,
|
||||
unit_ids=analyzer.sorting.unit_ids[:10], # First 10 units
|
||||
)
|
||||
plt.savefig('acg_matrix.png')
|
||||
```
|
||||
|
||||
## Spike Train Plots
|
||||
|
||||
### Raster Plot
|
||||
|
||||
```python
|
||||
sw.plot_rasters(
|
||||
sorting,
|
||||
time_range=(0, 30), # First 30 seconds
|
||||
unit_ids=unit_ids[:5],
|
||||
)
|
||||
plt.savefig('raster.png')
|
||||
```
|
||||
|
||||
### Firing Rate Over Time
|
||||
|
||||
```python
|
||||
unit_id = 0
|
||||
spike_train = sorting.get_unit_spike_train(unit_id)
|
||||
fs = recording.get_sampling_frequency()
|
||||
times = spike_train / fs
|
||||
|
||||
# Compute firing rate histogram
|
||||
bin_width = 1.0 # seconds
|
||||
bins = np.arange(0, recording.get_total_duration(), bin_width)
|
||||
hist, _ = np.histogram(times, bins=bins)
|
||||
firing_rate = hist / bin_width
|
||||
|
||||
plt.figure(figsize=(12, 3))
|
||||
plt.bar(bins[:-1], firing_rate, width=bin_width, edgecolor='none')
|
||||
plt.xlabel('Time (s)')
|
||||
plt.ylabel('Firing rate (Hz)')
|
||||
plt.title(f'Unit {unit_id} firing rate')
|
||||
plt.savefig(f'unit_{unit_id}_firing_rate.png', dpi=300)
|
||||
```
|
||||
|
||||
## Probe and Location Plots
|
||||
|
||||
### Probe Layout
|
||||
|
||||
```python
|
||||
sw.plot_probe_map(recording, with_channel_ids=True)
|
||||
plt.savefig('probe_layout.png')
|
||||
```
|
||||
|
||||
### Unit Locations on Probe
|
||||
|
||||
```python
|
||||
sw.plot_unit_locations(analyzer, with_channel_ids=True)
|
||||
plt.savefig('unit_locations.png')
|
||||
```
|
||||
|
||||
### Spike Locations
|
||||
|
||||
```python
|
||||
sw.plot_spike_locations(analyzer, unit_ids=[unit_id])
|
||||
plt.savefig(f'unit_{unit_id}_spike_locations.png')
|
||||
```
|
||||
|
||||
## Amplitude Plots
|
||||
|
||||
### Amplitudes Over Time
|
||||
|
||||
```python
|
||||
sw.plot_amplitudes(
|
||||
analyzer,
|
||||
unit_ids=[unit_id],
|
||||
plot_histograms=True,
|
||||
)
|
||||
plt.savefig(f'unit_{unit_id}_amplitudes.png')
|
||||
```
|
||||
|
||||
### Amplitude Distribution
|
||||
|
||||
```python
|
||||
amplitudes = analyzer.get_extension('spike_amplitudes').get_data()
|
||||
spike_vector = sorting.to_spike_vector()
|
||||
unit_idx = list(sorting.unit_ids).index(unit_id)
|
||||
unit_mask = spike_vector['unit_index'] == unit_idx
|
||||
unit_amps = amplitudes[unit_mask]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
ax.hist(unit_amps, bins=50, edgecolor='black', alpha=0.7)
|
||||
ax.axvline(np.median(unit_amps), color='red', linestyle='--', label='median')
|
||||
ax.set_xlabel('Amplitude (uV)')
|
||||
ax.set_ylabel('Count')
|
||||
ax.set_title(f'Unit {unit_id} Amplitude Distribution')
|
||||
ax.legend()
|
||||
plt.savefig(f'unit_{unit_id}_amp_dist.png', dpi=300)
|
||||
```
|
||||
|
||||
## ISI Plots
|
||||
|
||||
### ISI Histogram
|
||||
|
||||
```python
|
||||
sw.plot_isi_distribution(
|
||||
analyzer,
|
||||
unit_ids=[unit_id],
|
||||
window_ms=100,
|
||||
bin_ms=1,
|
||||
)
|
||||
plt.savefig(f'unit_{unit_id}_isi.png')
|
||||
```
|
||||
|
||||
### ISI with Refractory Markers
|
||||
|
||||
```python
|
||||
spike_train = sorting.get_unit_spike_train(unit_id)
|
||||
fs = recording.get_sampling_frequency()
|
||||
isis = np.diff(spike_train) / fs * 1000 # ms
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.hist(isis[isis < 100], bins=100, edgecolor='black', alpha=0.7)
|
||||
ax.axvline(1.5, color='red', linestyle='--', label='1.5ms refractory')
|
||||
ax.axvline(3.0, color='orange', linestyle='--', label='3ms threshold')
|
||||
ax.set_xlabel('ISI (ms)')
|
||||
ax.set_ylabel('Count')
|
||||
ax.set_title(f'Unit {unit_id} ISI Distribution')
|
||||
ax.legend()
|
||||
plt.savefig(f'unit_{unit_id}_isi_detailed.png', dpi=300)
|
||||
```
|
||||
|
||||
## Summary Plots
|
||||
|
||||
### Unit Summary Panel
|
||||
|
||||
```python
|
||||
npa.plot_unit_summary(analyzer, unit_id, output=f'unit_{unit_id}_summary.png')
|
||||
```
|
||||
|
||||
### Manual Multi-Panel Summary
|
||||
|
||||
```python
|
||||
fig = plt.figure(figsize=(16, 12))
|
||||
|
||||
# Waveforms
|
||||
ax1 = fig.add_subplot(2, 3, 1)
|
||||
wfs = analyzer.get_extension('waveforms').get_waveforms(unit_id)
|
||||
for i in range(min(50, wfs.shape[0])):
|
||||
ax1.plot(wfs[i, :, 0], 'k', alpha=0.1, linewidth=0.5)
|
||||
template = wfs.mean(axis=0)[:, 0]
|
||||
ax1.plot(template, 'b', linewidth=2)
|
||||
ax1.set_title('Waveforms')
|
||||
|
||||
# Template
|
||||
ax2 = fig.add_subplot(2, 3, 2)
|
||||
templates_ext = analyzer.get_extension('templates')
|
||||
template = templates_ext.get_unit_template(unit_id, operator='average')
|
||||
template_std = templates_ext.get_unit_template(unit_id, operator='std')
|
||||
x = range(template.shape[0])
|
||||
ax2.plot(x, template[:, 0], 'b', linewidth=2)
|
||||
ax2.fill_between(x, template[:, 0] - template_std[:, 0],
|
||||
template[:, 0] + template_std[:, 0], alpha=0.3)
|
||||
ax2.set_title('Template')
|
||||
|
||||
# Autocorrelogram
|
||||
ax3 = fig.add_subplot(2, 3, 3)
|
||||
correlograms = analyzer.get_extension('correlograms')
|
||||
ccg, bins = correlograms.get_data()
|
||||
unit_idx = list(sorting.unit_ids).index(unit_id)
|
||||
ax3.bar(bins[:-1], ccg[unit_idx, unit_idx, :], width=bins[1]-bins[0], color='gray')
|
||||
ax3.axvline(0, color='r', linestyle='--', alpha=0.5)
|
||||
ax3.set_title('Autocorrelogram')
|
||||
|
||||
# Amplitudes
|
||||
ax4 = fig.add_subplot(2, 3, 4)
|
||||
amps_ext = analyzer.get_extension('spike_amplitudes')
|
||||
amps = amps_ext.get_data()
|
||||
spike_vector = sorting.to_spike_vector()
|
||||
unit_mask = spike_vector['unit_index'] == unit_idx
|
||||
unit_times = spike_vector['sample_index'][unit_mask] / fs
|
||||
unit_amps = amps[unit_mask]
|
||||
ax4.scatter(unit_times, unit_amps, s=1, alpha=0.3)
|
||||
ax4.set_xlabel('Time (s)')
|
||||
ax4.set_ylabel('Amplitude')
|
||||
ax4.set_title('Amplitudes')
|
||||
|
||||
# ISI
|
||||
ax5 = fig.add_subplot(2, 3, 5)
|
||||
isis = np.diff(sorting.get_unit_spike_train(unit_id)) / fs * 1000
|
||||
ax5.hist(isis[isis < 100], bins=50, color='gray', edgecolor='black')
|
||||
ax5.axvline(1.5, color='r', linestyle='--')
|
||||
ax5.set_xlabel('ISI (ms)')
|
||||
ax5.set_title('ISI Distribution')
|
||||
|
||||
# Metrics
|
||||
ax6 = fig.add_subplot(2, 3, 6)
|
||||
unit_metrics = metrics.loc[unit_id]
|
||||
text_lines = [f"{k}: {v:.4f}" for k, v in unit_metrics.items() if not np.isnan(v)]
|
||||
ax6.text(0.1, 0.9, '\n'.join(text_lines[:8]), transform=ax6.transAxes,
|
||||
verticalalignment='top', fontsize=10, family='monospace')
|
||||
ax6.axis('off')
|
||||
ax6.set_title('Metrics')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'unit_{unit_id}_full_summary.png', dpi=300)
|
||||
```
|
||||
|
||||
## Publication-Quality Settings
|
||||
|
||||
### Figure Sizes
|
||||
|
||||
```python
|
||||
# Single column (3.5 inches)
|
||||
fig, ax = plt.subplots(figsize=(3.5, 3))
|
||||
|
||||
# Double column (7 inches)
|
||||
fig, ax = plt.subplots(figsize=(7, 4))
|
||||
|
||||
# Full page
|
||||
fig, ax = plt.subplots(figsize=(7, 9))
|
||||
```
|
||||
|
||||
### Font Settings
|
||||
|
||||
```python
|
||||
plt.rcParams.update({
|
||||
'font.size': 8,
|
||||
'axes.titlesize': 9,
|
||||
'axes.labelsize': 8,
|
||||
'xtick.labelsize': 7,
|
||||
'ytick.labelsize': 7,
|
||||
'legend.fontsize': 7,
|
||||
'font.family': 'Arial',
|
||||
})
|
||||
```
|
||||
|
||||
### Export Settings
|
||||
|
||||
```python
|
||||
# For publications
|
||||
plt.savefig('figure.pdf', format='pdf', bbox_inches='tight')
|
||||
plt.savefig('figure.svg', format='svg', bbox_inches='tight')
|
||||
|
||||
# High-res PNG
|
||||
plt.savefig('figure.png', dpi=600, bbox_inches='tight', facecolor='white')
|
||||
```
|
||||
|
||||
### Color Palettes
|
||||
|
||||
```python
|
||||
# Colorblind-friendly
|
||||
colors = ['#0072B2', '#E69F00', '#009E73', '#CC79A7', '#F0E442']
|
||||
|
||||
# For good/mua/noise
|
||||
label_colors = {'good': '#2ecc71', 'mua': '#f39c12', 'noise': '#e74c3c'}
|
||||
```
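
A minimal sketch applying the label palette, assuming the `metrics` DataFrame and `labels` dict produced in the quality-metrics and curation steps of this guide:

```python
import pandas as pd
import matplotlib.pyplot as plt

label_colors = {'good': '#2ecc71', 'mua': '#f39c12', 'noise': '#e74c3c'}
labels_series = pd.Series(labels)  # unit_id -> label from the curation step

fig, ax = plt.subplots(figsize=(5, 4))
for label, color in label_colors.items():
    unit_ids = labels_series.index[labels_series == label]
    sub = metrics.loc[metrics.index.intersection(unit_ids)]
    ax.scatter(sub['snr'], sub['isi_violations_ratio'],
               s=15, color=color, alpha=0.7, label=label)
ax.set_xlabel('SNR')
ax.set_ylabel('ISI violations ratio')
ax.legend(title='Curation label')
plt.savefig('snr_vs_isi_by_label.png', dpi=300, bbox_inches='tight')
```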
|
||||
@@ -0,0 +1,385 @@
|
||||
# Standard Neuropixels Analysis Workflow
|
||||
|
||||
Complete step-by-step guide for analyzing Neuropixels recordings from raw data to curated units.
|
||||
|
||||
## Overview
|
||||
|
||||
This reference documents the complete analysis pipeline:
|
||||
|
||||
```
|
||||
Raw Recording → Preprocessing → Motion Correction → Spike Sorting →
|
||||
Postprocessing → Quality Metrics → Curation → Export
|
||||
```
|
||||
|
||||
## 1. Data Loading
|
||||
|
||||
### Supported Formats
|
||||
|
||||
```python
|
||||
import spikeinterface.full as si
|
||||
import neuropixels_analysis as npa
|
||||
|
||||
# SpikeGLX (most common)
|
||||
recording = si.read_spikeglx('/path/to/run/', stream_id='imec0.ap')
|
||||
|
||||
# Open Ephys
|
||||
recording = si.read_openephys('/path/to/experiment/')
|
||||
|
||||
# NWB format
|
||||
recording = si.read_nwb('/path/to/file.nwb')
|
||||
|
||||
# Or use our convenience wrapper
|
||||
recording = npa.load_recording('/path/to/data/', format='spikeglx')
|
||||
```
|
||||
|
||||
### Verify Recording Properties
|
||||
|
||||
```python
|
||||
# Basic properties
|
||||
print(f"Channels: {recording.get_num_channels()}")
|
||||
print(f"Duration: {recording.get_total_duration():.1f}s")
|
||||
print(f"Sampling rate: {recording.get_sampling_frequency()}Hz")
|
||||
|
||||
# Probe geometry
|
||||
print(f"Probe: {recording.get_probe().name}")
|
||||
|
||||
# Channel locations
|
||||
locations = recording.get_channel_locations()
|
||||
```
|
||||
|
||||
## 2. Preprocessing
|
||||
|
||||
### Standard Preprocessing Chain
|
||||
|
||||
```python
|
||||
# Option 1: Full pipeline (recommended)
|
||||
rec_preprocessed = npa.preprocess(recording)
|
||||
|
||||
# Option 2: Step-by-step control
|
||||
rec = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
|
||||
rec = si.phase_shift(rec) # Correct ADC phase
|
||||
bad_channels = si.detect_bad_channels(rec)
|
||||
rec = rec.remove_channels(bad_channels)
|
||||
rec = si.common_reference(rec, operator='median')
|
||||
rec_preprocessed = rec
|
||||
```
|
||||
|
||||
### IBL-Style Destriping
|
||||
|
||||
For recordings with strong artifacts:
|
||||
|
||||
```python
|
||||
from ibldsp.voltage import decompress_destripe_cbin  # IBL's reference implementation (not used below)
|
||||
|
||||
# IBL destriping (very effective)
|
||||
rec = si.highpass_filter(recording, freq_min=400)
|
||||
rec = si.phase_shift(rec)
|
||||
rec = si.highpass_spatial_filter(rec) # Destriping
|
||||
rec = si.common_reference(rec, reference='global', operator='median')
|
||||
```
|
||||
|
||||
### Save Preprocessed Data
|
||||
|
||||
```python
|
||||
# Save for reuse (speeds up iteration)
|
||||
rec_preprocessed.save(folder='preprocessed/', n_jobs=4)
|
||||
```
|
||||
|
||||
## 3. Motion/Drift Correction
|
||||
|
||||
### Check if Correction Needed
|
||||
|
||||
```python
|
||||
# Estimate motion
|
||||
motion_info = npa.estimate_motion(rec_preprocessed, preset='kilosort_like')
|
||||
|
||||
# Visualize drift
|
||||
npa.plot_drift(rec_preprocessed, motion_info, output='drift_map.png')
|
||||
|
||||
# Check magnitude
|
||||
if motion_info['motion'].max() > 10: # microns
|
||||
print("Significant drift detected - correction recommended")
|
||||
```
|
||||
|
||||
### Apply Correction
|
||||
|
||||
```python
|
||||
# DREDge-based correction (default)
|
||||
rec_corrected = npa.correct_motion(
|
||||
rec_preprocessed,
|
||||
preset='nonrigid_accurate', # or 'kilosort_like' for speed
|
||||
)
|
||||
|
||||
# Or full control
|
||||
from spikeinterface.preprocessing import correct_motion
|
||||
|
||||
rec_corrected, motion = correct_motion(  # with output_motion=True this returns (recording, motion)
|
||||
rec_preprocessed,
|
||||
preset='nonrigid_accurate',
|
||||
folder='motion_output/',
|
||||
output_motion=True,
|
||||
)
|
||||
```
|
||||
|
||||
## 4. Spike Sorting
|
||||
|
||||
### Recommended: Kilosort4
|
||||
|
||||
```python
|
||||
# Run Kilosort4 (requires GPU)
|
||||
sorting = npa.run_sorting(
|
||||
rec_corrected,
|
||||
sorter='kilosort4',
|
||||
output_folder='sorting_KS4/',
|
||||
)
|
||||
|
||||
# With custom parameters
|
||||
sorting = npa.run_sorting(
|
||||
rec_corrected,
|
||||
sorter='kilosort4',
|
||||
output_folder='sorting_KS4/',
|
||||
sorter_params={
|
||||
'batch_size': 30000,
|
||||
'nblocks': 5, # For nonrigid drift
|
||||
'Th_learned': 8, # Detection threshold
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Alternative Sorters
|
||||
|
||||
```python
|
||||
# SpykingCircus2 (CPU-based)
|
||||
sorting = npa.run_sorting(rec_corrected, sorter='spykingcircus2')
|
||||
|
||||
# Mountainsort5 (fast, good for short recordings)
|
||||
sorting = npa.run_sorting(rec_corrected, sorter='mountainsort5')
|
||||
```
|
||||
|
||||
### Compare Multiple Sorters
|
||||
|
||||
```python
|
||||
# Run multiple sorters
|
||||
sortings = {}
|
||||
for sorter in ['kilosort4', 'spykingcircus2']:
|
||||
sortings[sorter] = npa.run_sorting(rec_corrected, sorter=sorter)
|
||||
|
||||
# Compare results
|
||||
comparison = npa.compare_sorters(list(sortings.values()))
|
||||
agreement_matrix = comparison.get_agreement_matrix()
|
||||
```
|
||||
|
||||
## 5. Postprocessing
|
||||
|
||||
### Create Analyzer
|
||||
|
||||
```python
|
||||
# Create sorting analyzer (central object for all postprocessing)
|
||||
analyzer = npa.create_analyzer(
|
||||
sorting,
|
||||
rec_corrected,
|
||||
output_folder='analyzer/',
|
||||
)
|
||||
|
||||
# Compute all standard extensions
|
||||
analyzer = npa.postprocess(
|
||||
sorting,
|
||||
rec_corrected,
|
||||
output_folder='analyzer/',
|
||||
compute_all=True, # Waveforms, templates, metrics, etc.
|
||||
)
|
||||
```
|
||||
|
||||
### Compute Individual Extensions
|
||||
|
||||
```python
|
||||
# Waveforms
|
||||
analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0, max_spikes_per_unit=500)
|
||||
|
||||
# Templates
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
|
||||
# Spike amplitudes
|
||||
analyzer.compute('spike_amplitudes')
|
||||
|
||||
# Correlograms
|
||||
analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
|
||||
|
||||
# Unit locations
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
|
||||
# Spike locations
|
||||
analyzer.compute('spike_locations', method='center_of_mass')
|
||||
```
|
||||
|
||||
## 6. Quality Metrics
|
||||
|
||||
### Compute All Metrics
|
||||
|
||||
```python
|
||||
# Compute comprehensive metrics
|
||||
metrics = npa.compute_quality_metrics(
|
||||
analyzer,
|
||||
metric_names=[
|
||||
'snr',
|
||||
'isi_violations_ratio',
|
||||
'presence_ratio',
|
||||
'amplitude_cutoff',
|
||||
'firing_rate',
|
||||
'amplitude_cv',
|
||||
'sliding_rp_violation',
|
||||
'd_prime',
|
||||
'nearest_neighbor',
|
||||
],
|
||||
)
|
||||
|
||||
# View metrics
|
||||
print(metrics.head())
|
||||
```
|
||||
|
||||
### Key Metrics Explained
|
||||
|
||||
| Metric | Good Value | Description |
|
||||
|--------|------------|-------------|
|
||||
| `snr` | > 5 | Signal-to-noise ratio |
|
||||
| `isi_violations_ratio` | < 0.01 | Refractory period violations |
|
||||
| `presence_ratio` | > 0.9 | Fraction of recording with spikes |
|
||||
| `amplitude_cutoff` | < 0.1 | Estimated missed spikes |
|
||||
| `firing_rate` | > 0.1 Hz | Average firing rate |
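
As a quick sanity check before formal curation, these cutoffs can be applied directly to the `metrics` DataFrame computed above; a minimal sketch using the values from the table (column names as listed):

```python
import pandas as pd

thresholds = {
    'snr': ('>', 5),
    'isi_violations_ratio': ('<', 0.01),
    'presence_ratio': ('>', 0.9),
    'amplitude_cutoff': ('<', 0.1),
    'firing_rate': ('>', 0.1),
}

passes = pd.DataFrame(index=metrics.index)
for col, (op, value) in thresholds.items():
    if col in metrics.columns:
        passes[col] = metrics[col] > value if op == '>' else metrics[col] < value

# Units passing every criterion, plus the per-metric pass rate
candidate_good = passes.all(axis=1)
print(f"Units passing all criteria: {candidate_good.sum()}/{len(passes)}")
print(passes.mean().round(2))  # fraction of units passing each metric
```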
|
||||
|
||||
## 7. Curation
|
||||
|
||||
### Automated Curation
|
||||
|
||||
```python
|
||||
# Allen Institute criteria
|
||||
labels = npa.curate(metrics, method='allen')
|
||||
|
||||
# IBL criteria
|
||||
labels = npa.curate(metrics, method='ibl')
|
||||
|
||||
# Custom thresholds
|
||||
labels = npa.curate(
|
||||
metrics,
|
||||
snr_threshold=5,
|
||||
isi_violations_threshold=0.01,
|
||||
presence_threshold=0.9,
|
||||
)
|
||||
```
|
||||
|
||||
### AI-Assisted Curation
|
||||
|
||||
```python
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Setup API
|
||||
client = Anthropic()
|
||||
|
||||
# Visual analysis for uncertain units
|
||||
uncertain = metrics.query('snr > 3 and snr < 8').index.tolist()
|
||||
|
||||
for unit_id in uncertain:
|
||||
result = npa.analyze_unit_visually(analyzer, unit_id, api_client=client)
|
||||
labels[unit_id] = result['classification']
|
||||
```
|
||||
|
||||
### Interactive Curation Session
|
||||
|
||||
```python
|
||||
# Create session
|
||||
session = npa.CurationSession.create(analyzer, output_dir='curation/')
|
||||
|
||||
# Review units
|
||||
while session.current_unit():
|
||||
unit = session.current_unit()
|
||||
report = npa.generate_unit_report(analyzer, unit.unit_id)
|
||||
|
||||
# Your decision
|
||||
decision = input(f"Unit {unit.unit_id}: ")
|
||||
session.set_decision(unit.unit_id, decision)
|
||||
session.next_unit()
|
||||
|
||||
# Export
|
||||
labels = session.get_final_labels()
|
||||
```
|
||||
|
||||
## 8. Export Results
|
||||
|
||||
### Export to Phy
|
||||
|
||||
```python
|
||||
from spikeinterface.exporters import export_to_phy
|
||||
|
||||
export_to_phy(
|
||||
analyzer,
|
||||
output_folder='phy_export/',
|
||||
copy_binary=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Export to NWB
|
||||
|
||||
```python
|
||||
from spikeinterface.exporters import export_to_nwb
|
||||
|
||||
export_to_nwb(
|
||||
analyzer,
|
||||
nwbfile_path='results.nwb',
|
||||
metadata={
|
||||
'session_description': 'Neuropixels recording',
|
||||
'experimenter': 'Lab Name',
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Save Quality Summary
|
||||
|
||||
```python
|
||||
# Save metrics CSV
|
||||
metrics.to_csv('quality_metrics.csv')
|
||||
|
||||
# Save labels
|
||||
import json
|
||||
with open('curation_labels.json', 'w') as f:
|
||||
json.dump(labels, f, indent=2)
|
||||
|
||||
# Generate summary report
|
||||
npa.plot_quality_metrics(analyzer, metrics, output='quality_summary.png')
|
||||
```
|
||||
|
||||
## Full Pipeline Example
|
||||
|
||||
```python
|
||||
import neuropixels_analysis as npa
|
||||
|
||||
# Load
|
||||
recording = npa.load_recording('/data/experiment/', format='spikeglx')
|
||||
|
||||
# Preprocess
|
||||
rec = npa.preprocess(recording)
|
||||
|
||||
# Motion correction
|
||||
rec = npa.correct_motion(rec)
|
||||
|
||||
# Sort
|
||||
sorting = npa.run_sorting(rec, sorter='kilosort4')
|
||||
|
||||
# Postprocess
|
||||
analyzer, metrics = npa.postprocess(sorting, rec)
|
||||
|
||||
# Curate
|
||||
labels = npa.curate(metrics, method='allen')
|
||||
|
||||
# Export good units
|
||||
good_units = [uid for uid, label in labels.items() if label == 'good']
|
||||
print(f"Good units: {len(good_units)}/{len(labels)}")
|
||||
```
|
||||
|
||||
## Tips for Success
|
||||
|
||||
1. **Always visualize drift** before deciding on motion correction
|
||||
2. **Save preprocessed data** to avoid recomputing
|
||||
3. **Compare multiple sorters** for critical experiments
|
||||
4. **Review uncertain units manually** - don't trust automated curation blindly
|
||||
5. **Document your parameters** for reproducibility (see the sketch after this list)
|
||||
6. **Use GPU** for Kilosort4 (10-50x faster than CPU alternatives)
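
A minimal sketch of tip 5: saving the exact settings used for a run alongside the results (the parameter names here are illustrative, not a fixed schema):

```python
import json
from pathlib import Path

# Illustrative record of the settings used for this run
run_params = {
    'preprocessing': {'freq_min': 300, 'freq_max': 6000, 'common_reference': 'median'},
    'motion_correction': {'preset': 'nonrigid_accurate'},
    'sorting': {'sorter': 'kilosort4', 'Th_learned': 8},
    'curation': {'method': 'allen'},
}

output_dir = Path('results/')
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / 'run_parameters.json', 'w') as f:
    json.dump(run_params, f, indent=2)
```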
|
||||
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Compute quality metrics and curate units.
|
||||
|
||||
Usage:
|
||||
python compute_metrics.py sorting/ preprocessed/ --output metrics/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
# Curation criteria presets
|
||||
CURATION_CRITERIA = {
|
||||
'allen': {
|
||||
'snr': 3.0,
|
||||
'isi_violations_ratio': 0.1,
|
||||
'presence_ratio': 0.9,
|
||||
'amplitude_cutoff': 0.1,
|
||||
},
|
||||
'ibl': {
|
||||
'snr': 4.0,
|
||||
'isi_violations_ratio': 0.5,
|
||||
'presence_ratio': 0.5,
|
||||
'amplitude_cutoff': None,
|
||||
},
|
||||
'strict': {
|
||||
'snr': 5.0,
|
||||
'isi_violations_ratio': 0.01,
|
||||
'presence_ratio': 0.95,
|
||||
'amplitude_cutoff': 0.05,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
sorting_path: str,
|
||||
recording_path: str,
|
||||
output_dir: str,
|
||||
curation_method: str = 'allen',
|
||||
n_jobs: int = -1,
|
||||
):
|
||||
"""Compute quality metrics and apply curation."""
|
||||
|
||||
print(f"Loading sorting from: {sorting_path}")
|
||||
sorting = si.load_extractor(Path(sorting_path) / 'sorting')
|
||||
|
||||
print(f"Loading recording from: {recording_path}")
|
||||
recording = si.load_extractor(Path(recording_path) / 'preprocessed')
|
||||
|
||||
print(f"Units: {len(sorting.unit_ids)}")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create analyzer
|
||||
print("Creating SortingAnalyzer...")
|
||||
analyzer = si.create_sorting_analyzer(
|
||||
sorting,
|
||||
recording,
|
||||
format='binary_folder',
|
||||
folder=output_path / 'analyzer',
|
||||
sparse=True,
|
||||
)
|
||||
|
||||
# Compute extensions
|
||||
print("Computing waveforms...")
|
||||
analyzer.compute('random_spikes', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.0, ms_after=2.0)
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
|
||||
print("Computing additional extensions...")
|
||||
analyzer.compute('noise_levels')
|
||||
analyzer.compute('spike_amplitudes')
|
||||
analyzer.compute('correlograms', window_ms=50.0, bin_ms=1.0)
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
|
||||
# Compute quality metrics
|
||||
print("Computing quality metrics...")
|
||||
metrics = si.compute_quality_metrics(
|
||||
analyzer,
|
||||
metric_names=[
|
||||
'snr',
|
||||
'isi_violations_ratio',
|
||||
'presence_ratio',
|
||||
'amplitude_cutoff',
|
||||
'firing_rate',
|
||||
'amplitude_cv',
|
||||
'sliding_rp_violation',
|
||||
],
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
|
||||
# Save metrics
|
||||
metrics.to_csv(output_path / 'quality_metrics.csv')
|
||||
print(f"Saved metrics to: {output_path / 'quality_metrics.csv'}")
|
||||
|
||||
# Apply curation
|
||||
criteria = CURATION_CRITERIA.get(curation_method, CURATION_CRITERIA['allen'])
|
||||
print(f"\nApplying {curation_method} curation criteria: {criteria}")
|
||||
|
||||
labels = {}
|
||||
for unit_id in metrics.index:
|
||||
row = metrics.loc[unit_id]
|
||||
|
||||
# Check each criterion
|
||||
is_good = True
|
||||
|
||||
if criteria.get('snr') and row.get('snr', 0) < criteria['snr']:
|
||||
is_good = False
|
||||
|
||||
if criteria.get('isi_violations_ratio') and row.get('isi_violations_ratio', 1) > criteria['isi_violations_ratio']:
|
||||
is_good = False
|
||||
|
||||
if criteria.get('presence_ratio') and row.get('presence_ratio', 0) < criteria['presence_ratio']:
|
||||
is_good = False
|
||||
|
||||
if criteria.get('amplitude_cutoff') and row.get('amplitude_cutoff', 1) > criteria['amplitude_cutoff']:
|
||||
is_good = False
|
||||
|
||||
# Classify
|
||||
if is_good:
|
||||
labels[int(unit_id)] = 'good'
|
||||
elif row.get('snr', 0) < 2:
|
||||
labels[int(unit_id)] = 'noise'
|
||||
else:
|
||||
labels[int(unit_id)] = 'mua'
|
||||
|
||||
# Save labels
|
||||
with open(output_path / 'curation_labels.json', 'w') as f:
|
||||
json.dump(labels, f, indent=2)
|
||||
|
||||
# Summary
|
||||
label_counts = {}
|
||||
for label in labels.values():
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
|
||||
print(f"\nCuration summary:")
|
||||
print(f" Good: {label_counts.get('good', 0)}")
|
||||
print(f" MUA: {label_counts.get('mua', 0)}")
|
||||
print(f" Noise: {label_counts.get('noise', 0)}")
|
||||
print(f" Total: {len(labels)}")
|
||||
|
||||
# Metrics summary
|
||||
print(f"\nMetrics summary:")
|
||||
for col in ['snr', 'isi_violations_ratio', 'presence_ratio', 'firing_rate']:
|
||||
if col in metrics.columns:
|
||||
print(f" {col}: {metrics[col].median():.4f} (median)")
|
||||
|
||||
return analyzer, metrics, labels
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Compute quality metrics')
|
||||
parser.add_argument('sorting', help='Path to sorting directory')
|
||||
parser.add_argument('recording', help='Path to preprocessed recording')
|
||||
parser.add_argument('--output', '-o', default='metrics/', help='Output directory')
|
||||
parser.add_argument('--curation', '-c', default='allen',
|
||||
choices=['allen', 'ibl', 'strict'])
|
||||
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
compute_metrics(
|
||||
args.sorting,
|
||||
args.recording,
|
||||
args.output,
|
||||
curation_method=args.curation,
|
||||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick exploration of Neuropixels recording.
|
||||
|
||||
Usage:
|
||||
python explore_recording.py /path/to/spikeglx/data
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import spikeinterface.full as si
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def explore_recording(data_path: str, stream_id: str = 'imec0.ap'):
|
||||
"""Explore a Neuropixels recording."""
|
||||
|
||||
print(f"Loading: {data_path}")
|
||||
recording = si.read_spikeglx(data_path, stream_id=stream_id)
|
||||
|
||||
# Basic info
|
||||
print("\n" + "="*50)
|
||||
print("RECORDING INFO")
|
||||
print("="*50)
|
||||
print(f"Channels: {recording.get_num_channels()}")
|
||||
print(f"Duration: {recording.get_total_duration():.2f} s ({recording.get_total_duration()/60:.2f} min)")
|
||||
print(f"Sampling rate: {recording.get_sampling_frequency()} Hz")
|
||||
print(f"Total samples: {recording.get_num_samples()}")
|
||||
|
||||
# Probe info
|
||||
probe = recording.get_probe()
|
||||
print(f"\nProbe: {probe.manufacturer} {probe.model_name if hasattr(probe, 'model_name') else ''}")
|
||||
print(f"Probe shape: {probe.ndim}D")
|
||||
|
||||
# Channel groups
|
||||
if recording.get_channel_groups() is not None:
|
||||
groups = np.unique(recording.get_channel_groups())
|
||||
print(f"Channel groups (shanks): {len(groups)}")
|
||||
|
||||
# Check for bad channels
|
||||
print("\n" + "="*50)
|
||||
print("BAD CHANNEL DETECTION")
|
||||
print("="*50)
|
||||
bad_ids, labels = si.detect_bad_channels(recording)
|
||||
if len(bad_ids) > 0:
|
||||
print(f"Bad channels found: {len(bad_ids)}")
|
||||
for ch, label in zip(bad_ids, labels):
|
||||
print(f" Channel {ch}: {label}")
|
||||
else:
|
||||
print("No bad channels detected")
|
||||
|
||||
# Sample traces
|
||||
print("\n" + "="*50)
|
||||
print("SIGNAL STATISTICS")
|
||||
print("="*50)
|
||||
|
||||
# Get 1 second of data
|
||||
n_samples = int(recording.get_sampling_frequency())
|
||||
traces = recording.get_traces(start_frame=0, end_frame=n_samples)
|
||||
|
||||
print(f"Sample mean: {np.mean(traces):.2f}")
|
||||
print(f"Sample std: {np.std(traces):.2f}")
|
||||
print(f"Sample min: {np.min(traces):.2f}")
|
||||
print(f"Sample max: {np.max(traces):.2f}")
|
||||
|
||||
return recording
|
||||
|
||||
|
||||
def plot_probe(recording, output_path=None):
|
||||
"""Plot probe layout."""
|
||||
fig, ax = plt.subplots(figsize=(4, 12))
|
||||
si.plot_probe_map(recording, ax=ax, with_channel_ids=False)
|
||||
ax.set_title('Probe Layout')
|
||||
|
||||
if output_path:
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved: {output_path}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_traces(recording, duration=1.0, output_path=None):
|
||||
"""Plot raw traces."""
|
||||
n_samples = int(duration * recording.get_sampling_frequency())
|
||||
traces = recording.get_traces(start_frame=0, end_frame=n_samples)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
|
||||
# Plot subset of channels
|
||||
n_channels = min(20, recording.get_num_channels())
|
||||
channel_idx = np.linspace(0, recording.get_num_channels()-1, n_channels, dtype=int)
|
||||
|
||||
time = np.arange(n_samples) / recording.get_sampling_frequency()
|
||||
|
||||
for i, ch in enumerate(channel_idx):
|
||||
offset = i * 200 # Offset for visibility
|
||||
ax.plot(time, traces[:, ch] + offset, 'k', linewidth=0.5)
|
||||
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Channel (offset)')
|
||||
ax.set_title(f'Raw Traces ({n_channels} channels)')
|
||||
|
||||
if output_path:
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved: {output_path}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_power_spectrum(recording, output_path=None):
|
||||
"""Plot power spectrum."""
|
||||
from scipy import signal
|
||||
|
||||
# Get data from middle channel
|
||||
mid_ch = recording.get_num_channels() // 2
|
||||
n_samples = min(int(10 * recording.get_sampling_frequency()), recording.get_num_samples())
|
||||
|
||||
traces = recording.get_traces(
|
||||
start_frame=0,
|
||||
end_frame=n_samples,
|
||||
channel_ids=[recording.channel_ids[mid_ch]]
|
||||
).flatten()
|
||||
|
||||
fs = recording.get_sampling_frequency()
|
||||
|
||||
# Compute power spectrum
|
||||
freqs, psd = signal.welch(traces, fs, nperseg=4096)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 5))
|
||||
ax.semilogy(freqs, psd)
|
||||
ax.set_xlabel('Frequency (Hz)')
|
||||
ax.set_ylabel('Power Spectral Density')
|
||||
ax.set_title(f'Power Spectrum (Channel {mid_ch})')
|
||||
ax.set_xlim(0, 7000)  # wide enough to show both band-edge markers
|
||||
ax.axvline(300, color='r', linestyle='--', alpha=0.5, label='300 Hz')
|
||||
ax.axvline(6000, color='r', linestyle='--', alpha=0.5, label='6000 Hz')
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
if output_path:
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Saved: {output_path}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Explore Neuropixels recording')
|
||||
parser.add_argument('data_path', help='Path to SpikeGLX recording')
|
||||
parser.add_argument('--stream', default='imec0.ap', help='Stream ID')
|
||||
parser.add_argument('--plot', action='store_true', help='Generate plots')
|
||||
parser.add_argument('--output', default=None, help='Output directory for plots')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
recording = explore_recording(args.data_path, args.stream)
|
||||
|
||||
if args.plot:
|
||||
import os
|
||||
if args.output:
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
plot_probe(recording, f"{args.output}/probe_map.png")
|
||||
plot_traces(recording, output_path=f"{args.output}/raw_traces.png")
|
||||
plot_power_spectrum(recording, f"{args.output}/power_spectrum.png")
|
||||
else:
|
||||
plot_probe(recording)
|
||||
plot_traces(recording)
|
||||
plot_power_spectrum(recording)
|
||||
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Export sorting results to Phy for manual curation.
|
||||
|
||||
Usage:
|
||||
python export_to_phy.py metrics/analyzer --output phy_export/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
from spikeinterface.exporters import export_to_phy
|
||||
|
||||
|
||||
def export_phy(
|
||||
analyzer_path: str,
|
||||
output_dir: str,
|
||||
copy_binary: bool = True,
|
||||
compute_amplitudes: bool = True,
|
||||
compute_pc_features: bool = True,
|
||||
n_jobs: int = -1,
|
||||
):
|
||||
"""Export to Phy format."""
|
||||
|
||||
print(f"Loading analyzer from: {analyzer_path}")
|
||||
analyzer = si.load_sorting_analyzer(analyzer_path)
|
||||
|
||||
print(f"Units: {len(analyzer.sorting.unit_ids)}")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
|
||||
# Compute required extensions if missing
|
||||
if compute_amplitudes and analyzer.get_extension('spike_amplitudes') is None:
|
||||
print("Computing spike amplitudes...")
|
||||
analyzer.compute('spike_amplitudes')
|
||||
|
||||
if compute_pc_features and analyzer.get_extension('principal_components') is None:
|
||||
print("Computing principal components...")
|
||||
analyzer.compute('principal_components', n_components=5, mode='by_channel_local')
|
||||
|
||||
print(f"Exporting to Phy: {output_path}")
|
||||
export_to_phy(
|
||||
analyzer,
|
||||
output_folder=output_path,
|
||||
copy_binary=copy_binary,
|
||||
compute_amplitudes=compute_amplitudes,
|
||||
compute_pc_features=compute_pc_features,
|
||||
n_jobs=n_jobs,
|
||||
)
|
||||
|
||||
print("\nExport complete!")
|
||||
print(f"To open in Phy, run:")
|
||||
print(f" phy template-gui {output_path / 'params.py'}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Export to Phy')
|
||||
parser.add_argument('analyzer', help='Path to sorting analyzer')
|
||||
parser.add_argument('--output', '-o', default='phy_export/', help='Output directory')
|
||||
parser.add_argument('--no-binary', action='store_true', help='Skip copying binary file')
|
||||
parser.add_argument('--no-amplitudes', action='store_true', help='Skip amplitude computation')
|
||||
parser.add_argument('--no-pc', action='store_true', help='Skip PC feature computation')
|
||||
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export_phy(
|
||||
args.analyzer,
|
||||
args.output,
|
||||
copy_binary=not args.no_binary,
|
||||
compute_amplitudes=not args.no_amplitudes,
|
||||
compute_pc_features=not args.no_pc,
|
||||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,432 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neuropixels Data Analysis Pipeline (Best Practices Version)
|
||||
|
||||
Based on SpikeInterface, Allen Institute, and IBL recommendations.
|
||||
|
||||
Usage:
|
||||
python neuropixels_pipeline.py /path/to/spikeglx/data /path/to/output
|
||||
|
||||
References:
|
||||
- https://spikeinterface.readthedocs.io/en/stable/how_to/analyze_neuropixels.html
|
||||
- https://github.com/AllenInstitute/ecephys_spike_sorting
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import json
|
||||
import spikeinterface.full as si
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_recording(data_path: str, stream_name: str = 'imec0.ap') -> si.BaseRecording:
|
||||
"""Load a SpikeGLX or Open Ephys recording."""
|
||||
|
||||
data_path = Path(data_path)
|
||||
|
||||
# Auto-detect format
|
||||
if any(data_path.rglob('*.ap.bin')) or any(data_path.rglob('*.ap.meta')):
|
||||
# SpikeGLX format
|
||||
streams, _ = si.get_neo_streams('spikeglx', data_path)
|
||||
print(f"Available streams: {streams}")
|
||||
recording = si.read_spikeglx(data_path, stream_name=stream_name)
|
||||
elif any(data_path.rglob('*.oebin')):
|
||||
# Open Ephys format
|
||||
recording = si.read_openephys(data_path)
|
||||
else:
|
||||
raise ValueError(f"Unknown format in {data_path}")
|
||||
|
||||
print(f"Loaded recording:")
|
||||
print(f" Channels: {recording.get_num_channels()}")
|
||||
print(f" Duration: {recording.get_total_duration():.2f} s")
|
||||
print(f" Sampling rate: {recording.get_sampling_frequency()} Hz")
|
||||
|
||||
return recording
|
||||
|
||||
|
||||
def preprocess(
|
||||
recording: si.BaseRecording,
|
||||
apply_phase_shift: bool = True,
|
||||
freq_min: float = 400.,
|
||||
) -> tuple:
|
||||
"""
|
||||
Apply standard Neuropixels preprocessing.
|
||||
|
||||
Following SpikeInterface recommendations:
|
||||
1. High-pass filter at 400 Hz (not 300)
|
||||
2. Detect and remove bad channels
|
||||
3. Phase shift (NP 1.0 only)
|
||||
4. Common median reference
|
||||
"""
|
||||
print("Preprocessing...")
|
||||
|
||||
# Step 1: High-pass filter
|
||||
rec = si.highpass_filter(recording, freq_min=freq_min)
|
||||
print(f" Applied high-pass filter at {freq_min} Hz")
|
||||
|
||||
# Step 2: Detect bad channels
|
||||
bad_channel_ids, channel_labels = si.detect_bad_channels(rec)
|
||||
if len(bad_channel_ids) > 0:
|
||||
print(f" Detected {len(bad_channel_ids)} bad channels: {bad_channel_ids}")
|
||||
rec = rec.remove_channels(bad_channel_ids)
|
||||
else:
|
||||
print(" No bad channels detected")
|
||||
|
||||
# Step 3: Phase shift (for Neuropixels 1.0)
|
||||
if apply_phase_shift:
|
||||
rec = si.phase_shift(rec)
|
||||
print(" Applied phase shift correction")
|
||||
|
||||
# Step 4: Common median reference
|
||||
rec = si.common_reference(rec, operator='median', reference='global')
|
||||
print(" Applied common median reference")
|
||||
|
||||
return rec, bad_channel_ids
|
||||
|
||||
|
||||
def check_drift(recording: si.BaseRecording, output_folder: str) -> dict:
|
||||
"""
|
||||
Detect peaks and check for drift before spike sorting.
|
||||
"""
|
||||
print("Checking for drift...")
|
||||
|
||||
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
|
||||
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
|
||||
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
|
||||
# Get noise levels
|
||||
noise_levels = si.get_noise_levels(recording, return_in_uV=False)
|
||||
|
||||
# Detect peaks
|
||||
peaks = detect_peaks(
|
||||
recording,
|
||||
method='locally_exclusive',
|
||||
noise_levels=noise_levels,
|
||||
detect_threshold=5,
|
||||
radius_um=50.,
|
||||
**job_kwargs
|
||||
)
|
||||
print(f" Detected {len(peaks)} peaks")
|
||||
|
||||
# Localize peaks
|
||||
peak_locations = localize_peaks(
|
||||
recording, peaks,
|
||||
method='center_of_mass',
|
||||
**job_kwargs
|
||||
)
|
||||
|
||||
# Save drift plot
|
||||
import matplotlib.pyplot as plt
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# Subsample for plotting
|
||||
n_plot = min(100000, len(peaks))
|
||||
idx = np.random.choice(len(peaks), n_plot, replace=False)
|
||||
|
||||
ax.scatter(
|
||||
peaks['sample_index'][idx] / recording.get_sampling_frequency(),
|
||||
peak_locations['y'][idx],
|
||||
s=1, alpha=0.1, c='k'
|
||||
)
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Depth (μm)')
|
||||
ax.set_title('Peak Activity (Check for Drift)')
|
||||
|
||||
plt.savefig(f'{output_folder}/drift_check.png', dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f" Saved drift plot to {output_folder}/drift_check.png")
|
||||
|
||||
# Estimate drift magnitude (rough proxy: 5th-95th percentile spread of peak depths, not a time-resolved drift measure)
|
||||
y_positions = peak_locations['y']
|
||||
drift_estimate = np.percentile(y_positions, 95) - np.percentile(y_positions, 5)
|
||||
print(f" Estimated drift range: {drift_estimate:.1f} μm")
|
||||
|
||||
return {
|
||||
'peaks': peaks,
|
||||
'peak_locations': peak_locations,
|
||||
'drift_estimate': drift_estimate
|
||||
}
|
||||
|
||||
|
||||
def correct_motion(
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str,
|
||||
preset: str = 'nonrigid_fast_and_accurate'
|
||||
) -> si.BaseRecording:
|
||||
"""Apply motion correction if needed."""
|
||||
print(f"Applying motion correction (preset: {preset})...")
|
||||
|
||||
rec_corrected, motion_info = si.correct_motion(  # output_motion_info=True returns (recording, motion_info)
|
||||
recording,
|
||||
preset=preset,
|
||||
folder=f'{output_folder}/motion',
|
||||
output_motion_info=True,
|
||||
n_jobs=8,
|
||||
chunk_duration='1s',
|
||||
progress_bar=True
|
||||
)
|
||||
|
||||
print(" Motion correction complete")
|
||||
return rec_corrected
|
||||
|
||||
|
||||
def run_spike_sorting(
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str,
|
||||
sorter: str = 'kilosort4'
|
||||
) -> si.BaseSorting:
|
||||
"""Run spike sorting."""
|
||||
print(f"Running spike sorting with {sorter}...")
|
||||
|
||||
sorter_folder = f'{output_folder}/sorting_{sorter}'
|
||||
|
||||
sorting = si.run_sorter(
|
||||
sorter,
|
||||
recording,
|
||||
output_folder=sorter_folder,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
print(f" Found {len(sorting.unit_ids)} units")
|
||||
print(f" Total spikes: {sorting.get_total_num_spikes()}")
|
||||
|
||||
return sorting
|
||||
|
||||
|
||||
def postprocess(
|
||||
sorting: si.BaseSorting,
|
||||
recording: si.BaseRecording,
|
||||
output_folder: str
|
||||
) -> tuple:
|
||||
"""Run post-processing and compute quality metrics."""
|
||||
print("Post-processing...")
|
||||
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
|
||||
# Create analyzer
|
||||
analyzer = si.create_sorting_analyzer(
|
||||
sorting, recording,
|
||||
sparse=True,
|
||||
format='binary_folder',
|
||||
folder=f'{output_folder}/analyzer'
|
||||
)
|
||||
|
||||
# Compute extensions (order matters)
|
||||
print(" Computing waveforms...")
|
||||
analyzer.compute('random_spikes', method='uniform', max_spikes_per_unit=500)
|
||||
analyzer.compute('waveforms', ms_before=1.5, ms_after=2.0, **job_kwargs)
|
||||
analyzer.compute('templates', operators=['average', 'std'])
|
||||
analyzer.compute('noise_levels')
|
||||
|
||||
print(" Computing spike features...")
|
||||
analyzer.compute('spike_amplitudes', **job_kwargs)
|
||||
analyzer.compute('correlograms', window_ms=100, bin_ms=1)
|
||||
analyzer.compute('unit_locations', method='monopolar_triangulation')
|
||||
analyzer.compute('template_similarity')
|
||||
|
||||
print(" Computing quality metrics...")
|
||||
analyzer.compute('quality_metrics')
|
||||
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
|
||||
return analyzer, qm
|
||||
|
||||
|
||||
def curate_units(qm, method: str = 'allen') -> dict:
|
||||
"""
|
||||
Classify units based on quality metrics.
|
||||
|
||||
Methods:
|
||||
'allen': Allen Institute defaults (more permissive)
|
||||
'ibl': IBL standards
|
||||
'strict': Strict single-unit criteria
|
||||
"""
|
||||
print(f"Curating units (method: {method})...")
|
||||
|
||||
labels = {}
|
||||
|
||||
for unit_id in qm.index:
|
||||
row = qm.loc[unit_id]
|
||||
|
||||
# Noise detection (universal)
|
||||
if row['snr'] < 1.5:
|
||||
labels[unit_id] = 'noise'
|
||||
continue
|
||||
|
||||
if method == 'allen':
|
||||
# Allen Institute defaults
|
||||
if (row['presence_ratio'] > 0.9 and
|
||||
row['isi_violations_ratio'] < 0.5 and
|
||||
row['amplitude_cutoff'] < 0.1):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.5:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
elif method == 'ibl':
|
||||
# IBL standards
|
||||
if (row['presence_ratio'] > 0.9 and
|
||||
row['isi_violations_ratio'] < 0.1 and
|
||||
row['amplitude_cutoff'] < 0.1 and
|
||||
row['firing_rate'] > 0.1):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.1:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
elif method == 'strict':
|
||||
# Strict single-unit
|
||||
if (row['snr'] > 5 and
|
||||
row['presence_ratio'] > 0.95 and
|
||||
row['isi_violations_ratio'] < 0.01 and
|
||||
row['amplitude_cutoff'] < 0.01):
|
||||
labels[unit_id] = 'good'
|
||||
elif row['isi_violations_ratio'] > 0.05:
|
||||
labels[unit_id] = 'mua'
|
||||
else:
|
||||
labels[unit_id] = 'unsorted'
|
||||
|
||||
# Summary
|
||||
from collections import Counter
|
||||
counts = Counter(labels.values())
|
||||
print(f" Classification: {dict(counts)}")
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def export_results(
|
||||
analyzer,
|
||||
sorting,
|
||||
recording,
|
||||
labels: dict,
|
||||
output_folder: str
|
||||
):
|
||||
"""Export results to various formats."""
|
||||
print("Exporting results...")
|
||||
|
||||
# Get good units
|
||||
good_ids = [u for u, l in labels.items() if l == 'good']
|
||||
sorting_good = sorting.select_units(good_ids)
|
||||
|
||||
# Export to Phy
|
||||
phy_folder = f'{output_folder}/phy_export'
|
||||
si.export_to_phy(analyzer, phy_folder,
|
||||
compute_pc_features=True,
|
||||
compute_amplitudes=True)
|
||||
print(f" Phy export: {phy_folder}")
|
||||
|
||||
# Generate report
|
||||
report_folder = f'{output_folder}/report'
|
||||
si.export_report(analyzer, report_folder, format='png')
|
||||
print(f" Report: {report_folder}")
|
||||
|
||||
# Save quality metrics
|
||||
qm = analyzer.get_extension('quality_metrics').get_data()
|
||||
qm.to_csv(f'{output_folder}/quality_metrics.csv')
|
||||
|
||||
# Save labels
|
||||
with open(f'{output_folder}/unit_labels.json', 'w') as f:
|
||||
json.dump({str(k): v for k, v in labels.items()}, f, indent=2)
|
||||
|
||||
# Save summary
|
||||
summary = {
|
||||
'total_units': len(sorting.unit_ids),
|
||||
'good_units': len(good_ids),
|
||||
'total_spikes': int(sorting.get_total_num_spikes()),
|
||||
'duration_s': float(recording.get_total_duration()),
|
||||
'n_channels': int(recording.get_num_channels()),
|
||||
}
|
||||
with open(f'{output_folder}/summary.json', 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
print(f" Summary: {summary}")
|
||||
|
||||
|
||||
def run_pipeline(
|
||||
data_path: str,
|
||||
output_path: str,
|
||||
sorter: str = 'kilosort4',
|
||||
stream_name: str = 'imec0.ap',
|
||||
apply_motion_correction: bool = True,
|
||||
curation_method: str = 'allen'
|
||||
):
|
||||
"""Run complete Neuropixels analysis pipeline."""
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. Load data
|
||||
recording = load_recording(data_path, stream_name)
|
||||
|
||||
# 2. Preprocess
|
||||
rec_preprocessed, bad_channels = preprocess(recording)
|
||||
|
||||
# Save preprocessed
|
||||
preproc_folder = output_path / 'preprocessed'
|
||||
job_kwargs = dict(n_jobs=8, chunk_duration='1s', progress_bar=True)
|
||||
rec_preprocessed = rec_preprocessed.save(
|
||||
folder=str(preproc_folder),
|
||||
format='binary',
|
||||
**job_kwargs
|
||||
)
|
||||
|
||||
# 3. Check drift
|
||||
drift_info = check_drift(rec_preprocessed, str(output_path))
|
||||
|
||||
# 4. Motion correction (if needed)
|
||||
if apply_motion_correction and drift_info['drift_estimate'] > 20:
|
||||
print(f"Drift > 20 μm detected, applying motion correction...")
|
||||
rec_final = correct_motion(rec_preprocessed, str(output_path))
|
||||
else:
|
||||
print("Skipping motion correction (low drift)")
|
||||
rec_final = rec_preprocessed
|
||||
|
||||
# 5. Spike sorting
|
||||
sorting = run_spike_sorting(rec_final, str(output_path), sorter)
|
||||
|
||||
# 6. Post-processing
|
||||
analyzer, qm = postprocess(sorting, rec_final, str(output_path))
|
||||
|
||||
# 7. Curation
|
||||
labels = curate_units(qm, method=curation_method)
|
||||
|
||||
# 8. Export
|
||||
export_results(analyzer, sorting, rec_final, labels, str(output_path))
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("Pipeline complete!")
|
||||
print(f"Output directory: {output_path}")
|
||||
print("="*50)
|
||||
|
||||
return analyzer, sorting, qm, labels
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Neuropixels analysis pipeline (best practices)'
|
||||
)
|
||||
parser.add_argument('data_path', help='Path to SpikeGLX/OpenEphys recording')
|
||||
parser.add_argument('output_path', help='Output directory')
|
||||
parser.add_argument('--sorter', default='kilosort4',
|
||||
choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'],
|
||||
help='Spike sorter to use')
|
||||
parser.add_argument('--stream', default='imec0.ap', help='Stream name')
|
||||
parser.add_argument('--no-motion-correction', action='store_true',
|
||||
help='Skip motion correction')
|
||||
parser.add_argument('--curation', default='allen',
|
||||
choices=['allen', 'ibl', 'strict'],
|
||||
help='Curation method')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
run_pipeline(
|
||||
args.data_path,
|
||||
args.output_path,
|
||||
sorter=args.sorter,
|
||||
stream_name=args.stream,
|
||||
apply_motion_correction=not args.no_motion_correction,
|
||||
curation_method=args.curation
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Preprocess Neuropixels recording.
|
||||
|
||||
Usage:
|
||||
python preprocess_recording.py /path/to/data --output preprocessed/ --format spikeglx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
def preprocess_recording(
|
||||
input_path: str,
|
||||
output_dir: str,
|
||||
format: str = 'auto',
|
||||
stream_id: str = None,
|
||||
freq_min: float = 300,
|
||||
freq_max: float = 6000,
|
||||
phase_shift: bool = True,
|
||||
common_ref: bool = True,
|
||||
detect_bad: bool = True,
|
||||
n_jobs: int = -1,
|
||||
):
|
||||
"""Preprocess a Neuropixels recording."""
|
||||
|
||||
print(f"Loading recording from: {input_path}")
|
||||
|
||||
# Load recording
|
||||
if format == 'spikeglx' or (format == 'auto' and 'imec' in str(input_path).lower()):
|
||||
recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
|
||||
elif format == 'openephys':
|
||||
recording = si.read_openephys(input_path)
|
||||
elif format == 'nwb':
|
||||
recording = si.read_nwb(input_path)
|
||||
else:
|
||||
# Try auto-detection
|
||||
try:
|
||||
recording = si.read_spikeglx(input_path, stream_id=stream_id or 'imec0.ap')
|
||||
except Exception:
|
||||
recording = si.load_extractor(input_path)
|
||||
|
||||
print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")
|
||||
|
||||
# Preprocessing chain
|
||||
rec = recording
|
||||
|
||||
# Bandpass filter
|
||||
print(f"Applying bandpass filter ({freq_min}-{freq_max} Hz)...")
|
||||
rec = si.bandpass_filter(rec, freq_min=freq_min, freq_max=freq_max)
|
||||
|
||||
# Phase shift correction (for Neuropixels ADC)
|
||||
if phase_shift:
|
||||
print("Applying phase shift correction...")
|
||||
rec = si.phase_shift(rec)
|
||||
|
||||
# Bad channel detection
|
||||
if detect_bad:
|
||||
print("Detecting bad channels...")
|
||||
bad_channel_ids, bad_labels = si.detect_bad_channels(rec)
|
||||
if len(bad_channel_ids) > 0:
|
||||
print(f" Removing {len(bad_channel_ids)} bad channels: {bad_channel_ids[:10]}...")
|
||||
rec = rec.remove_channels(bad_channel_ids)
|
||||
|
||||
# Common median reference
|
||||
if common_ref:
|
||||
print("Applying common median reference...")
|
||||
rec = si.common_reference(rec, operator='median', reference='global')
|
||||
|
||||
# Save preprocessed
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Saving preprocessed recording to: {output_path}")
|
||||
rec.save(folder=output_path / 'preprocessed', n_jobs=n_jobs)
|
||||
|
||||
# Save probe info
|
||||
probe = rec.get_probe()
|
||||
if probe is not None:
|
||||
from probeinterface import write_probeinterface
|
||||
write_probeinterface(output_path / 'probe.json', probe)
|
||||
|
||||
print("Done!")
|
||||
print(f" Output channels: {rec.get_num_channels()}")
|
||||
print(f" Output duration: {rec.get_total_duration():.1f}s")
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Preprocess Neuropixels recording')
|
||||
parser.add_argument('input', help='Path to input recording')
|
||||
parser.add_argument('--output', '-o', default='preprocessed/', help='Output directory')
|
||||
parser.add_argument('--format', '-f', default='auto', choices=['auto', 'spikeglx', 'openephys', 'nwb'])
|
||||
parser.add_argument('--stream-id', default=None, help='Stream ID for multi-probe recordings')
|
||||
parser.add_argument('--freq-min', type=float, default=300, help='Highpass cutoff (Hz)')
|
||||
parser.add_argument('--freq-max', type=float, default=6000, help='Lowpass cutoff (Hz)')
|
||||
parser.add_argument('--no-phase-shift', action='store_true', help='Skip phase shift correction')
|
||||
parser.add_argument('--no-cmr', action='store_true', help='Skip common median reference')
|
||||
parser.add_argument('--no-bad-channel', action='store_true', help='Skip bad channel detection')
|
||||
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
preprocess_recording(
|
||||
args.input,
|
||||
args.output,
|
||||
format=args.format,
|
||||
stream_id=args.stream_id,
|
||||
freq_min=args.freq_min,
|
||||
freq_max=args.freq_max,
|
||||
phase_shift=not args.no_phase_shift,
|
||||
common_ref=not args.no_cmr,
|
||||
detect_bad=not args.no_bad_channel,
|
||||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Run spike sorting on preprocessed recording.
|
||||
|
||||
Usage:
|
||||
python run_sorting.py preprocessed/ --sorter kilosort4 --output sorting/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import spikeinterface.full as si
|
||||
|
||||
|
||||
# Default parameters for each sorter
|
||||
SORTER_DEFAULTS = {
|
||||
'kilosort4': {
|
||||
'batch_size': 30000,
|
||||
'nblocks': 1,
|
||||
'Th_learned': 8,
|
||||
'Th_universal': 9,
|
||||
},
|
||||
'kilosort3': {
|
||||
'do_CAR': False, # Already done in preprocessing
|
||||
},
|
||||
'spykingcircus2': {
|
||||
'apply_preprocessing': False,
|
||||
},
|
||||
'mountainsort5': {
|
||||
'filter': False,
|
||||
'whiten': False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_sorting(
|
||||
input_path: str,
|
||||
output_dir: str,
|
||||
sorter: str = 'kilosort4',
|
||||
sorter_params: dict = None,
|
||||
n_jobs: int = -1,
|
||||
):
|
||||
"""Run spike sorting."""
|
||||
|
||||
print(f"Loading preprocessed recording from: {input_path}")
|
||||
recording = si.load_extractor(Path(input_path) / 'preprocessed')
|
||||
|
||||
print(f"Recording: {recording.get_num_channels()} channels, {recording.get_total_duration():.1f}s")
|
||||
|
||||
# Get sorter parameters
|
||||
params = SORTER_DEFAULTS.get(sorter, {}).copy()
|
||||
if sorter_params:
|
||||
params.update(sorter_params)
|
||||
|
||||
print(f"Running {sorter} with params: {params}")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
|
||||
# Run sorter (note: parameter is 'folder' not 'output_folder' in newer SpikeInterface)
|
||||
sorting = si.run_sorter(
|
||||
sorter,
|
||||
recording,
|
||||
folder=output_path / f'{sorter}_output',
|
||||
verbose=True,
|
||||
**params,
|
||||
)
|
||||
|
||||
print(f"\nSorting complete!")
|
||||
print(f" Units found: {len(sorting.unit_ids)}")
|
||||
print(f" Total spikes: {sum(len(sorting.get_unit_spike_train(uid)) for uid in sorting.unit_ids)}")
|
||||
|
||||
# Save sorting
|
||||
sorting.save(folder=output_path / 'sorting')
|
||||
print(f" Saved to: {output_path / 'sorting'}")
|
||||
|
||||
return sorting
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run spike sorting')
|
||||
parser.add_argument('input', help='Path to preprocessed recording')
|
||||
parser.add_argument('--output', '-o', default='sorting/', help='Output directory')
|
||||
parser.add_argument('--sorter', '-s', default='kilosort4',
|
||||
choices=['kilosort4', 'kilosort3', 'spykingcircus2', 'mountainsort5'])
|
||||
parser.add_argument('--n-jobs', type=int, default=-1, help='Number of parallel jobs')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
run_sorting(
|
||||
args.input,
|
||||
args.output,
|
||||
sorter=args.sorter,
|
||||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||