Merge pull request #2146 from DradeAW/sd_test
Implemented `sd_ratio` as quality metric
alejoe91 authored Nov 22, 2023
2 parents 43ab557 + cd5eb3f commit 20974b3
Showing 7 changed files with 165 additions and 9 deletions.
5 changes: 3 additions & 2 deletions doc/modules/qualitymetrics.rst
@@ -35,11 +35,12 @@ For more details about each metric and it's availability and use within SpikeInt
qualitymetrics/isolation_distance
qualitymetrics/l_ratio
qualitymetrics/nearest_neighbor
qualitymetrics/noise_cutoff
qualitymetrics/presence_ratio
qualitymetrics/sd_ratio
qualitymetrics/silhouette_score
qualitymetrics/sliding_rp_violations
qualitymetrics/snr
qualitymetrics/noise_cutoff
qualitymetrics/silhouette_score
qualitymetrics/synchrony


2 changes: 2 additions & 0 deletions doc/modules/qualitymetrics/references.rst
@@ -21,6 +21,8 @@ References
.. [Llobet] Llobet Victor, Wyngaard Aurélien and Barbour Boris. “Automatic post-processing and merging of multiple spike-sorting analyses with Lussac“. BioRxiv (2022).
.. [Pouzat] Pouzat Christophe, Mazor Ofer and Laurent Gilles. “Using noise signature to optimize spike-sorting and to assess neuronal classification quality“. Journal of Neuroscience Methods (2002).
.. [Rousseeuw] Peter J Rousseeuw. Silhouettes: A graphical aid to the interpretation and validation of cluster analysis. Journal of computational and applied mathematics, 20(C):53–65, 1987.
.. [Schmitzer-Torbert] Schmitzer-Torbert, Neil, and A. David Redish. “Neuronal Activity in the Rodent Dorsal Striatum in Sequential Navigation: Separation of Spatial and Reward Responses on the Multiple T Task.” Journal of neurophysiology 91.5 (2004): 2259–2272. Web.
38 changes: 38 additions & 0 deletions doc/modules/qualitymetrics/sd_ratio.rst
@@ -0,0 +1,38 @@
Standard Deviation (SD) ratio (:code:`sd_ratio`)
==============================================

Calculation
-----------

All spikes from the same neuron should have the same shape. This means that at the peak of the spike, the standard deviation of the voltage should be the same as that of noise. If spikes from multiple neurons are grouped into a single unit, the standard deviation of spike amplitudes would likely be increased.

This metric, first described by [Pouzat]_ and then adapted by Wyngaard, Llobet & Barbour (in preparation), returns the ratio between the two standard deviations:

.. math::

    S = \frac{\sigma_{\mathrm{unit}}}{\sigma_{\mathrm{noise}}}

To remove the effect of drift on spike amplitudes, :math:`\sigma_{\mathrm{unit}}` is computed from the differences between consecutive spike amplitudes: the standard deviation of these differences is divided by :math:`\sqrt{2}`.
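
The role of the :math:`\sqrt{2}` factor can be checked numerically. The following is only a minimal sketch on synthetic data, not part of the library: the amplitude values, the linear drift and the variable names are all assumptions made for illustration. For independent noise of standard deviation :math:`\sigma`, the difference of two consecutive amplitudes has variance :math:`2\sigma^2`, while a slow drift contributes almost nothing to that difference.

```python
import numpy as np

rng = np.random.default_rng(0)
n_spikes = 20_000
true_sd = 5.0  # assumed noise SD at the spike peak (arbitrary units)

# Hypothetical amplitudes: a fixed spike amplitude, a slow linear drift, and noise.
drift = np.linspace(0.0, 40.0, n_spikes)
amplitudes = 100.0 + drift + rng.normal(0.0, true_sd, n_spikes)

# The plain SD is inflated by the drift.
naive_sd = np.std(amplitudes)

# Var(a[i+1] - a[i]) = 2 * sigma^2 for independent noise, hence the division by sqrt(2).
corrected_sd = np.std(np.diff(amplitudes)) / np.sqrt(2)

print(f"naive SD: {naive_sd:.2f}, drift-corrected SD: {corrected_sd:.2f}")
```

In this sketch the drift-corrected estimate stays close to ``true_sd``, whereas the plain standard deviation is dominated by the drift.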
To remove the effect of bursts (whose spikes can have lower amplitudes), you can also specify a censored period (4.0 ms by default): any spike occurring less than this period after another spike is ignored.
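
The censoring step can be sketched as follows. This is an illustrative re-implementation under stated assumptions, not the helper used by the library: the function name ``censored_spike_indices`` is hypothetical, and the sketch assumes a greedy "keep the first spike" rule in which each spike is compared with the last spike that was kept.

```python
import numpy as np


def censored_spike_indices(spike_train, censored_period):
    """Return indices of spikes occurring less than `censored_period` samples
    after the last kept spike (hypothetical helper, greedy keep-first rule)."""
    removed = []
    last_kept = None
    for i, t in enumerate(spike_train):
        if last_kept is not None and t - last_kept < censored_period:
            removed.append(i)  # too close to the previous kept spike: censor it
        else:
            last_kept = t  # this spike is kept and becomes the new reference
    return np.array(removed, dtype=np.int64)


# 4.0 ms at an assumed sampling frequency of 30 kHz is 120 samples.
censored_period = int(round(4.0 * 1e-3 * 30_000))
spike_train = np.array([0, 50, 130, 200, 400], dtype=np.int64)
amplitudes = np.array([-80.0, -60.0, -79.0, -61.0, -81.0])

removed = censored_spike_indices(spike_train, censored_period)
kept_amplitudes = np.delete(amplitudes, removed)
```

Here the second and fourth spikes fall within the censored period of the preceding kept spike, so only the amplitudes of the remaining three spikes enter the standard deviation.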


Expectation and use
-------------------

For a unit representing a single neuron, this metric should return a value close to one. However, for contaminated units, the value can be significantly higher.


Example code
------------

.. code-block:: python

    import spikeinterface.qualitymetrics as sqm

    sd_ratio = sqm.compute_sd_ratio(wvf_extractor, censored_period_ms=4.0)

Literature
----------

Introduced by [Pouzat]_ (2002).
Expanded by Wyngaard, Llobet and Barbour (in preparation).
105 changes: 103 additions & 2 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -13,8 +13,8 @@
import numpy as np
import warnings

from ..postprocessing import correlogram_for_one_segment
from ..core import get_noise_levels
from ..postprocessing import compute_spike_amplitudes, correlogram_for_one_segment
from ..core import WaveformExtractor, get_noise_levels
from ..core.template_tools import (
get_template_extremum_channel,
get_template_extremum_amplitude,
@@ -1365,3 +1365,104 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters,
spike_train = spike_trains[spike_clusters == i]
n_v = _compute_nb_violations_numba(spike_train, t_r)
nb_rp_violations[i] += n_v


def compute_sd_ratio(
    wvf_extractor: WaveformExtractor,
    censored_period_ms: float = 4.0,
    correct_for_drift: bool = True,
    correct_for_template_itself: bool = True,
    unit_ids=None,
    **kwargs,
):
    """
    Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compares it to the SD of noise.
    In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit.
    (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself).

    Parameters
    ----------
    wvf_extractor : WaveformExtractor
        The waveform extractor object.
    censored_period_ms : float, default: 4.0
        The censored period in milliseconds. This is to remove any potential bursts that could affect the SD.
    correct_for_drift: bool, default: True
        If True, will subtract the amplitudes sequentially to significantly reduce the impact of drift.
    correct_for_template_itself: bool, default: True
        If True, will take into account that the template itself impacts the standard deviation of the noise,
        and will make a rough estimation of what that impact is (and remove it).
    unit_ids : list or None, default: None
        The list of unit ids to compute this metric. If None, all units are used.
    **kwargs:
        Keyword arguments for computing spike amplitudes and extremum channel.
    TODO: Take jitter into account.

    Returns
    -------
    sd_ratio : dict
        The SD ratio for each unit ID.
    """

    from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative

    # Convert the censored period from milliseconds to samples.
    censored_period = int(round(censored_period_ms * 1e-3 * wvf_extractor.sampling_frequency))
    if unit_ids is None:
        unit_ids = wvf_extractor.unit_ids

    if not wvf_extractor.has_recording():
        warnings.warn(
            "The `sd_ratio` metric cannot work with a recordless WaveformExtractor object. "
            "SD ratio metric will be set to NaN"
        )
        return {unit_id: np.nan for unit_id in unit_ids}

    if wvf_extractor.is_extension("spike_amplitudes"):
        amplitudes_ext = wvf_extractor.load_extension("spike_amplitudes")
        spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit")
    else:
        warnings.warn(
            "The `sd_ratio` metric requires the `spike_amplitudes` waveform extension. "
            "Use the `postprocessing.compute_spike_amplitudes()` function. "
            "SD ratio metric will be set to NaN"
        )
        return {unit_id: np.nan for unit_id in unit_ids}

    noise_levels = get_noise_levels(
        wvf_extractor.recording, return_scaled=amplitudes_ext._params["return_scaled"], method="std"
    )
    best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs)
    n_spikes = wvf_extractor.sorting.count_num_spikes_per_unit()

    sd_ratio = {}
    for unit_id in unit_ids:
        spk_amp = []

        # Collect the amplitudes of all non-censored spikes, segment by segment.
        for segment_index in range(wvf_extractor.get_num_segments()):
            spike_train = wvf_extractor.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
                np.int64
            )
            censored_indices = _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period)
            spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices))
        spk_amp = np.concatenate(spk_amp)

        if correct_for_drift:
            # The difference of two independent amplitudes has variance 2 * sigma^2.
            unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2)
        else:
            unit_std = np.std(spk_amp)

        best_channel = best_channels[unit_id]
        std_noise = noise_levels[best_channel]

        if correct_for_template_itself:
            template = wvf_extractor.get_template(unit_id, force_dense=True)[:, best_channel]

            # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
            # TODO: Take into account that templates for different segments might differ.
            p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples()
            # Var(X) = E[X^2] - E[X]^2, so the mean of the template must be squared.
            total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2

            std_noise = np.sqrt(std_noise**2 - total_variance)

        sd_ratio[unit_id] = unit_std / std_noise

    return sd_ratio
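
The template correction rests on the variance of a trace that is zero everywhere except for non-overlapping copies of the template. Below is a minimal numerical check of that idea, offered as a sketch only: the template values, trace length and spike count are made up for illustration. With :math:`p` the fraction of samples covered by a template, the identity :math:`\mathrm{Var}(X) = E[X^2] - E[X]^2` gives a variance of :math:`p \cdot \mathrm{mean}(T^2) - p^2 \cdot \mathrm{mean}(T)^2`.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples_total = 200_000
n_spikes = 5_000
# Hypothetical template on the best channel (its mean is deliberately non-zero).
template = np.array([0.0, -2.0, -10.0, -4.0, 3.0, 1.0])

# Build a flat trace and paste non-overlapping copies of the template into it.
trace = np.zeros(n_samples_total)
slots = rng.choice(n_samples_total // template.size, size=n_spikes, replace=False)
for slot in slots:
    start = slot * template.size
    trace[start : start + template.size] = template

# Fraction of samples covered by a template.
p = template.size * n_spikes / n_samples_total

predicted_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
empirical_variance = np.var(trace)
```

Subtracting this quantity from the squared noise level is what removes the unit's own contribution to the noise estimate on its best channel.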
2 changes: 2 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -14,6 +14,7 @@
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
)

from .pca_metrics import (
@@ -46,4 +47,5 @@
"synchrony": compute_synchrony_metrics,
"firing_range": compute_firing_ranges,
"drift": compute_drift_metrics,
"sd_ratio": compute_sd_ratio,
}
19 changes: 14 additions & 5 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -34,6 +34,7 @@
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
)


@@ -70,7 +71,7 @@ def _simulated_data():


def _waveform_extractor_simple():
recording, sorting = toy_example(duration=50, seed=10)
recording, sorting = toy_example(duration=80, seed=10, firing_rate=6.0)
recording = recording.save(folder=cache_folder / "rec1")
sorting = sorting.save(folder=cache_folder / "sort1")
folder = cache_folder / "waveform_folder1"
@@ -86,6 +87,7 @@ def _waveform_extractor_simple():
overwrite=True,
)
_ = compute_principal_components(we, n_components=5, mode="by_channel_local")
_ = compute_spike_amplitudes(we, return_scaled=True)
return we


@@ -227,7 +229,7 @@ def test_calculate_firing_range(waveform_extractor_simple):

def test_calculate_amplitude_cutoff(waveform_extractor_simple):
we = waveform_extractor_simple
spike_amps = compute_spike_amplitudes(we)
spike_amps = we.load_extension("spike_amplitudes").get_data()
amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10)
print(amp_cuts)

@@ -238,7 +240,7 @@ def test_calculate_amplitude_cutoff(waveform_extractor_simple):

def test_calculate_amplitude_median(waveform_extractor_simple):
we = waveform_extractor_simple
spike_amps = compute_spike_amplitudes(we)
spike_amps = we.load_extension("spike_amplitudes").get_data()
amp_medians = compute_amplitude_medians(we)
print(spike_amps, amp_medians)

@@ -249,7 +251,6 @@ def test_calculate_amplitude_median(waveform_extractor_simple):

def test_calculate_amplitude_cv_metrics(waveform_extractor_simple):
we = waveform_extractor_simple
spike_amps = compute_spike_amplitudes(we)
amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20)
print(amp_cv_median)
print(amp_cv_range)
@@ -379,6 +380,13 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05)


def test_calculate_sd_ratio(waveform_extractor_simple):
    sd_ratio = compute_sd_ratio(waveform_extractor_simple)

    assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids)
    assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0)


if __name__ == "__main__":
sim_data = _simulated_data()
we = _waveform_extractor_simple()
@@ -390,5 +398,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# test_calculate_sliding_rp_violations(we)
# test_calculate_drift_metrics(we)
# test_synchrony_metrics(we)
test_calculate_firing_range(we)
# test_calculate_firing_range(we)
# test_calculate_amplitude_cv_metrics(we)
test_calculate_sd_ratio(we)
@@ -279,6 +279,9 @@ def test_recordingless(self):

# check metrics are the same
for metric_name in qm_rec.columns:
if metric_name == "sd_ratio":
continue

# rtol is added for sliding_rp_violation, for a reason I do not have to explore now. Sam.
assert np.allclose(qm_rec[metric_name].values, qm_no_rec[metric_name].values, rtol=1e-02)
