diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index ec1788350f..962de2dfd8 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -35,11 +35,12 @@ For more details about each metric and it's availability and use within SpikeInt qualitymetrics/isolation_distance qualitymetrics/l_ratio qualitymetrics/nearest_neighbor + qualitymetrics/noise_cutoff qualitymetrics/presence_ratio + qualitymetrics/sd_ratio + qualitymetrics/silhouette_score qualitymetrics/sliding_rp_violations qualitymetrics/snr - qualitymetrics/noise_cutoff - qualitymetrics/silhouette_score qualitymetrics/synchrony diff --git a/doc/modules/qualitymetrics/references.rst b/doc/modules/qualitymetrics/references.rst index 4f10c7b2b7..f5236cff66 100644 --- a/doc/modules/qualitymetrics/references.rst +++ b/doc/modules/qualitymetrics/references.rst @@ -21,6 +21,8 @@ References .. [Llobet] Llobet Victor, Wyngaard Aurélien and Barbour Boris. “Automatic post-processing and merging of multiple spike-sorting analyses with Lussac“. BioRxiv (2022). +.. [Pouzat] Pouzat Christophe, Mazor Ofer and Laurent Gilles. “Using noise signature to optimize spike-sorting and to assess neuronal classification quality“. Journal of Neuroscience Methods (2002). + .. [Rousseeuw] Peter J Rousseeuw. Silhouettes: A graphical aid to the interpretation and validation of cluster analysis. Journal of computational and applied mathematics, 20(C):53–65, 1987. .. [Schmitzer-Torbert] Schmitzer-Torbert, Neil, and A. David Redish. “Neuronal Activity in the Rodent Dorsal Striatum in Sequential Navigation: Separation of Spatial and Reward Responses on the Multiple T Task.” Journal of neurophysiology 91.5 (2004): 2259–2272. Web. diff --git a/doc/modules/qualitymetrics/sd_ratio.rst b/doc/modules/qualitymetrics/sd_ratio.rst new file mode 100644 index 0000000000..701050afcb --- /dev/null +++ b/doc/modules/qualitymetrics/sd_ratio.rst @@ -0,0 +1,38 @@ +Standard Deviation (SD) ratio (:code:`sd_ratio`) +============================================== + +Calculation +----------- + +All spikes from the same neuron should have the same shape. This means that at the peak of the spike, the standard deviation of the voltage should be the same as that of noise. If spikes from multiple neurons are grouped into a single unit, the standard deviation of spike amplitudes would likely be increased. + +This metric, first described [Pouzat]_ then adapted by Wyngaard, Llobet & Barbour (in preparation), returns the ratio between both standard deviations: + +.. math:: + S = \frac{\sigma_{\mathrm{unit}}}{\sigma_{\mathrm{noise}}} + +To remove the effect of drift on spikes amplitude, :math:`\sigma_{\mathrm{unit}}` is computed by subtracting each spike amplitude, and dividing the resulting standard deviation by :math:`\sqrt{2}`. +Also to remove the effect of bursts (which can have lower amplitudes), you can specify a censored period (by default 4.0 ms) where spikes happening less than this period after another spike will not be considered. + + +Expectation and use +------------------- + +For a unit representing a single neuron, this metric should return a value close to one. However for units that are contaminated, the value can be significantly higher. + + +Example code +------------ + +.. code-block:: python + + import spikeinterface.qualitymetrics as sqm + + sd_ratio = sqm.compute_sd_ratio(wvf_extractor, censored_period_ms=4.0) + + +Literature +---------- + +Introduced by [Pouzat]_ (2002). +Expanded by Wyngaard, Llobet and Barbour (in preparation). diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 9dab06124b..1e33965db3 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -13,8 +13,8 @@ import numpy as np import warnings -from ..postprocessing import correlogram_for_one_segment -from ..core import get_noise_levels +from ..postprocessing import compute_spike_amplitudes, correlogram_for_one_segment +from ..core import WaveformExtractor, get_noise_levels from ..core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, @@ -1365,3 +1365,104 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters, spike_train = spike_trains[spike_clusters == i] n_v = _compute_nb_violations_numba(spike_train, t_r) nb_rp_violations[i] += n_v + + +def compute_sd_ratio( + wvf_extractor: WaveformExtractor, + censored_period_ms: float = 4.0, + correct_for_drift: bool = True, + correct_for_template_itself: bool = True, + unit_ids=None, + **kwargs, +): + """ + Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise. + In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit. + (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself). + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + censored_period_ms : float, default: 4.0 + The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. + correct_for_drift: bool, default: True + If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift. + correct_for_template_itself: bool, default: True + If true, will take into account that the template itself impacts the standard deviation of the noise, + and will make a rough estimation of what that impact is (and remove it). + unit_ids : list or None, default: None + The list of unit ids to compute this metric. If None, all units are used. + **kwargs: + Keyword arguments for computing spike amplitudes and extremum channel. + TODO: Take jitter into account. + + Returns + ------- + num_spikes : dict + The number of spikes, across all segments, for each unit ID. + """ + + from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative + + censored_period = int(round(censored_period_ms * 1e-3 * wvf_extractor.sampling_frequency)) + if unit_ids is None: + unit_ids = wvf_extractor.unit_ids + + if not wvf_extractor.has_recording(): + warnings.warn( + "The `sd_ratio` metric cannot work with a recordless WaveformExtractor object" + "SD ratio metric will be set to NaN" + ) + return {unit_id: np.nan for unit_id in unit_ids} + + if wvf_extractor.is_extension("spike_amplitudes"): + amplitudes_ext = wvf_extractor.load_extension("spike_amplitudes") + spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit") + else: + warnings.warn( + "The `sd_ratio` metric require the `spike_amplitudes` waveform extension. " + "Use the `postprocessing.compute_spike_amplitudes()` functions. " + "SD ratio metric will be set to NaN" + ) + return {unit_id: np.nan for unit_id in unit_ids} + + noise_levels = get_noise_levels( + wvf_extractor.recording, return_scaled=amplitudes_ext._params["return_scaled"], method="std" + ) + best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs) + n_spikes = wvf_extractor.sorting.count_num_spikes_per_unit() + + sd_ratio = {} + for unit_id in unit_ids: + spk_amp = [] + + for segment_index in range(wvf_extractor.get_num_segments()): + spike_train = wvf_extractor.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( + np.int64 + ) + censored_indices = _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period) + spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices)) + spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))]) + + if correct_for_drift: + unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2) + else: + unit_std = np.std(spk_amp) + + best_channel = best_channels[unit_id] + std_noise = noise_levels[best_channel] + + if correct_for_template_itself: + template = wvf_extractor.get_template(unit_id, force_dense=True)[:, best_channel] + + # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template. + # TODO: Take into account that templates for different segments might differ. + p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples() + total_variance = p * np.mean(template**2) - p**2 * np.mean(template) + + std_noise = np.sqrt(std_noise**2 - total_variance) + + sd_ratio[unit_id] = unit_std / std_noise + + return sd_ratio diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 97f14ec6f4..b2f7fe7fc9 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -14,6 +14,7 @@ compute_synchrony_metrics, compute_firing_ranges, compute_amplitude_cv_metrics, + compute_sd_ratio, ) from .pca_metrics import ( @@ -46,4 +47,5 @@ "synchrony": compute_synchrony_metrics, "firing_range": compute_firing_ranges, "drift": compute_drift_metrics, + "sd_ratio": compute_sd_ratio, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 8a32c4cee8..b1fae6f621 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -34,6 +34,7 @@ compute_synchrony_metrics, compute_firing_ranges, compute_amplitude_cv_metrics, + compute_sd_ratio, ) @@ -70,7 +71,7 @@ def _simulated_data(): def _waveform_extractor_simple(): - recording, sorting = toy_example(duration=50, seed=10) + recording, sorting = toy_example(duration=80, seed=10, firing_rate=6.0) recording = recording.save(folder=cache_folder / "rec1") sorting = sorting.save(folder=cache_folder / "sort1") folder = cache_folder / "waveform_folder1" @@ -86,6 +87,7 @@ def _waveform_extractor_simple(): overwrite=True, ) _ = compute_principal_components(we, n_components=5, mode="by_channel_local") + _ = compute_spike_amplitudes(we, return_scaled=True) return we @@ -227,7 +229,7 @@ def test_calculate_firing_range(waveform_extractor_simple): def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple - spike_amps = compute_spike_amplitudes(we) + spike_amps = we.load_extension("spike_amplitudes").get_data() amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10) print(amp_cuts) @@ -238,7 +240,7 @@ def test_calculate_amplitude_cutoff(waveform_extractor_simple): def test_calculate_amplitude_median(waveform_extractor_simple): we = waveform_extractor_simple - spike_amps = compute_spike_amplitudes(we) + spike_amps = we.load_extension("spike_amplitudes").get_data() amp_medians = compute_amplitude_medians(we) print(spike_amps, amp_medians) @@ -249,7 +251,6 @@ def test_calculate_amplitude_median(waveform_extractor_simple): def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): we = waveform_extractor_simple - spike_amps = compute_spike_amplitudes(we) amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) print(amp_cv_median) print(amp_cv_range) @@ -379,6 +380,13 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05) +def test_calculate_sd_ratio(waveform_extractor_simple): + sd_ratio = compute_sd_ratio(waveform_extractor_simple) + + assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids) + assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0) + + if __name__ == "__main__": sim_data = _simulated_data() we = _waveform_extractor_simple() @@ -390,5 +398,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # test_calculate_sliding_rp_violations(we) # test_calculate_drift_metrics(we) # test_synchrony_metrics(we) - test_calculate_firing_range(we) + # test_calculate_firing_range(we) # test_calculate_amplitude_cv_metrics(we) + test_calculate_sd_ratio(we) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index b601e5d6d8..8ab07740d9 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -279,6 +279,9 @@ def test_recordingless(self): # check metrics are the same for metric_name in qm_rec.columns: + if metric_name == "sd_ratio": + continue + # rtol is addedd for sliding_rp_violation, for a reason I do not have to explore now. Sam. assert np.allclose(qm_rec[metric_name].values, qm_no_rec[metric_name].values, rtol=1e-02)