From feff1569ca85c347c676e5bb80c81a61818976dd Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:44:26 +0000 Subject: [PATCH] Added get_random_spikes method to MockWaveformExtractor --- src/spikeinterface/core/analyzer_extension_core.py | 4 +++- .../core/waveforms_extractor_backwards_compatibility.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 108aceab7e..240d19f8f5 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -319,7 +319,9 @@ def _compute_and_append(self, operators): # spikes = self.sorting_analyzer.sorting.to_spike_vector() # some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] - assert self.sorting_analyzer.has_extension("random_spikes"), "compute templates requires the random_spikes extension." + assert ( + self.sorting_analyzer.has_extension("random_spikes") + ), "compute templates requires the random_spikes extension. You can run WaveformExtractor.get_random_spikes()" some_spikes = self.sorting_analyzer.get_extension("random_spikes").some_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index f06b701a5c..b6985e1471 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -131,7 +131,10 @@ def delete_waveforms(self) -> None: self.sorting_analyzer.delete_extension("waveforms") def delete_extension(self, extension) -> None: - self.sorting_analyzer.delete_extension(extension) + self.sorting_analyzer.delete_extension() + + def get_random_spikes(self) -> None: + self.sorting_analyzer.compute("random_spikes") @property def recording(self) -> BaseRecording: