From e1ecf86d4c243a414565d07ed7c43ac85531830e Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 24 May 2024 12:27:16 +0200
Subject: [PATCH 01/13] Fix the bug with too many peaks detected

---
 .../sortingcomponents/peak_detection.py | 65 +++++++++++--------
 1 file changed, 39 insertions(+), 26 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 508a033c41..ed1faa4133 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -696,10 +696,14 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         conv_traces /= self.abs_thresholds[:, None]
         conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin]
         traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size]
+
         num_z_factors = len(self.z_factors)
-        num_channels = conv_traces.shape[0] // num_z_factors
+        num_templates = traces.shape[1]
+        traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1])
+        conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1])
         peak_mask = traces_center > 1
+
         peak_mask = _numba_detect_peak_matched_filtering(
             conv_traces,
             traces_center,
@@ -708,15 +712,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             self.abs_thresholds,
             self.peak_sign,
             self.neighbours_mask,
-            num_channels,
+            num_templates,
         )
 
         # Find peaks and correct for time shift
-        peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)
-
-        # If we do not want to estimate the z accurately
-        z = self.z_factors[peak_chan_ind // num_channels]
-        peak_chan_ind = peak_chan_ind % num_channels
+        z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)
 
         # If we want to estimate z
         # peak_chan_ind = peak_chan_ind % num_channels
@@ -738,7 +738,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         local_peaks["channel_index"] = peak_chan_ind
         local_peaks["amplitude"] = peak_amplitude
         local_peaks["segment_index"] = segment_index
-        local_peaks["z"] = z
+        local_peaks["z"] = z_ind
 
         # return is always a tuple
         return (local_peaks,)
@@ -876,27 +876,40 @@ def _numba_detect_peak_neg(
 
     @numba.jit(nopython=True, parallel=False)
     def _numba_detect_peak_matched_filtering(
-        traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
+        traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates
     ):
-        num_chans = traces_center.shape[0]
-        for chan_ind in range(num_chans):
-            for s in range(peak_mask.shape[1]):
-                if not peak_mask[chan_ind, s]:
-                    continue
-                for neighbour in range(num_chans):
-                    if not neighbours_mask[chan_ind % num_channels, neighbour % num_channels]:
+        num_z = traces_center.shape[0]
+        for template_ind in range(num_templates):
+            for z in range(num_z):
+                for s in range(peak_mask.shape[2]):
+                    if not peak_mask[z, template_ind, s]:
                         continue
-                    for i in range(exclude_sweep_size):
-                        if chan_ind != neighbour:
-                            peak_mask[chan_ind, s] &= traces_center[chan_ind, s] >= traces_center[neighbour, s]
-                        peak_mask[chan_ind, s] &= traces_center[chan_ind, s] > traces[neighbour, s + i]
-                        peak_mask[chan_ind, s] &= (
-                            traces_center[chan_ind, s] >= traces[neighbour, exclude_sweep_size + s + i + 1]
-                        )
-                        if not peak_mask[chan_ind, s]:
+                    for neighbour in range(num_templates):
+                        if not neighbours_mask[template_ind, neighbour]:
+                            continue
+                        for j in range(num_z):
+                            for i in range(exclude_sweep_size):
+                                if (template_ind >= neighbour):
+                                    if (z >= j):
+                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
+                                    else:
+                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                elif (template_ind < neighbour):
+                                    if (z > j):
+                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                    else:
+                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
+                                peak_mask[z, template_ind, s] &= (
+                                    traces_center[z, template_ind, s] >= traces[j, neighbour, exclude_sweep_size + s + i + 1]
+                                )
+                                if not peak_mask[z, template_ind, s]:
+                                    break
+                            if not peak_mask[z, template_ind, s]:
+                                break
+                        if not peak_mask[z, template_ind, s]:
                             break
-                if not peak_mask[chan_ind, s]:
-                    break
+
     return peak_mask
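The indexing change above is easiest to see in isolation. Below is a minimal NumPy sketch (toy shapes, not SpikeInterface code) showing that reshaping the stacked convolved traces to (num_z, num_templates, num_samples) before np.nonzero yields the same (z, template) decomposition that the old // and % bookkeeping computed by hand:

    import numpy as np

    num_z, num_templates, num_samples = 3, 4, 100
    rng = np.random.default_rng(0)
    conv_traces = rng.standard_normal((num_z * num_templates, num_samples))

    # old decoding: flat row index -> (z, template) via integer division / modulo
    flat_rows, flat_cols = np.nonzero(conv_traces > 1)
    z_old = flat_rows // num_templates
    template_old = flat_rows % num_templates

    # new decoding: reshape first, np.nonzero returns the three indices directly
    z_ind, template_ind, sample_ind = np.nonzero(
        conv_traces.reshape(num_z, num_templates, num_samples) > 1
    )

    assert np.array_equal(z_old, z_ind)
    assert np.array_equal(template_old, template_ind)
    assert np.array_equal(flat_cols, sample_ind)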
From 41359e2149fd1cda0fbb404193d2d150ac7a58ac Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 24 May 2024 13:40:15 +0200
Subject: [PATCH 02/13] Fixing asymmetric prototype

---
 .../sortingcomponents/peak_detection.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index ed1faa4133..26cfd3324a 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -613,6 +613,7 @@ def __init__(
         self,
         recording,
         prototype,
+        ms_before,
         peak_sign="neg",
         detect_threshold=5,
         exclude_sweep_ms=0.1,
@@ -644,6 +645,7 @@ def __init__(
             raise NotImplementedError("Matched filtering not working with peak_sign=both yet!")
 
         self.peak_sign = peak_sign
+        self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
         contact_locations = recording.get_channel_locations()
         dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2)
         weights, self.z_factors = get_convolution_weights(dist, **weight_method)
@@ -689,8 +691,6 @@ def get_trace_margin(self):
 
     def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
 
-        # peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, temporal, spatial, singular, z_factors = self.args
-
         assert HAVE_NUMBA, "You need to install numba"
         conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular)
         conv_traces /= self.abs_thresholds[:, None]
@@ -730,7 +730,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         if peak_sample_ind.size == 0 or peak_chan_ind.size == 0:
             return (np.zeros(0, dtype=self._dtype),)
 
-        peak_sample_ind += self.exclude_sweep_size + self.conv_margin
+        peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore
 
         peak_amplitude = traces[peak_sample_ind, peak_chan_ind]
         local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype)
@@ -747,10 +747,11 @@ def get_convolved_traces(self, traces, temporal, spatial, singular):
         import scipy.signal
 
         num_timesteps, num_templates = len(traces), temporal.shape[1]
-        scalar_products = np.zeros((num_templates, num_timesteps), dtype=np.float32)
+        num_peaks = num_timesteps - self.conv_margin + 1
+        scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32)
         spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :])
         scaled_filtered_data = spatially_filtered_data * singular
-        objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="same")
+        objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid")
         scalar_products += np.sum(objective_by_rank, axis=0)
         return scalar_products

From 3e08f7a950ba1f89c91c736a4a2fb6a1fed05e6b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 May 2024 11:45:38 +0000
Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sortingcomponents/peak_detection.py | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 26cfd3324a..61485900e3 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -703,7 +703,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1])
         conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1])
         peak_mask = traces_center > 1
-
+
         peak_mask = _numba_detect_peak_matched_filtering(
             conv_traces,
             traces_center,
@@ -890,19 +890,30 @@ def _numba_detect_peak_matched_filtering(
                             continue
                         for j in range(num_z):
                             for i in range(exclude_sweep_size):
-                                if (template_ind >= neighbour):
-                                    if (z >= j):
-                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
+                                if template_ind >= neighbour:
+                                    if z >= j:
+                                        peak_mask[z, template_ind, s] &= (
+                                            traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
+                                        )
                                     else:
-                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
-                                elif (template_ind < neighbour):
-                                    if (z > j):
-                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                        peak_mask[z, template_ind, s] &= (
+                                            traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                        )
+                                elif template_ind < neighbour:
+                                    if z > j:
+                                        peak_mask[z, template_ind, s] &= (
+                                            traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                        )
                                     else:
-                                        peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
-                                peak_mask[z, template_ind, s] &= traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
+                                        peak_mask[z, template_ind, s] &= (
+                                            traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                        )
+                                peak_mask[z, template_ind, s] &= (
+                                    traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
+                                )
                                 peak_mask[z, template_ind, s] &= (
-                                    traces_center[z, template_ind, s] >= traces[j, neighbour, exclude_sweep_size + s + i + 1]
+                                    traces_center[z, template_ind, s]
+                                    >= traces[j, neighbour, exclude_sweep_size + s + i + 1]
                                 )
                                 if not peak_mask[z, template_ind, s]:
                                     break
@@ -910,7 +921,7 @@ def _numba_detect_peak_matched_filtering(
                                 break
                         if not peak_mask[z, template_ind, s]:
                             break
-
+
     return peak_mask
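The mode="valid" switch in patch 02 is what makes the new buffer size num_timesteps - conv_margin + 1 correct. A short SciPy sketch of the length arithmetic (toy sizes only, assuming nothing about SpikeInterface internals): a "valid" convolution of N samples with a length-K kernel returns N - K + 1 samples, while "same" pads the result back to N:

    import numpy as np
    from scipy.signal import oaconvolve

    num_timesteps, kernel_size = 1000, 61
    data = np.random.randn(1, 4, num_timesteps)   # (rank, channels, time)
    kernel = np.random.randn(1, 4, kernel_size)

    same = oaconvolve(data, kernel, axes=2, mode="same")
    valid = oaconvolve(data, kernel, axes=2, mode="valid")

    assert same.shape[2] == num_timesteps                     # padded back to N
    assert valid.shape[2] == num_timesteps - kernel_size + 1  # N - K + 1, no edge padding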
From e81b96af446ad12549a540f021b8135a05a5f3b5 Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 24 May 2024 13:51:08 +0200
Subject: [PATCH 04/13] Fixes tests

---
 src/spikeinterface/sorters/internal/spyking_circus2.py       | 1 +
 .../sortingcomponents/tests/test_peak_detection.py           | 1 +
 2 files changed, 2 insertions(+)

diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index c1021e787a..05853b4c39 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -179,6 +179,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             if params["matched_filtering"]:
                 prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
                 detection_params["prototype"] = prototype
+                detection_params["ms_before"] = ms_before
 
                 for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
                     if value in detection_params:
diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
index 2ecccb421c..fbab084f0d 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -323,6 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs):
         detect_threshold=5,
         exclude_sweep_ms=0.1,
         prototype=prototype,
+        ms_before=1.0
         **job_kwargs,
     )
     assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

From 0bdb2ed9cd02b350e1bdcf2c2af305bf1182b38f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 May 2024 11:51:38 +0000
Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sortingcomponents/tests/test_peak_detection.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
index fbab084f0d..9d00ce24a7 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -323,8 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs):
         detect_threshold=5,
         exclude_sweep_ms=0.1,
         prototype=prototype,
-        ms_before=1.0
-        **job_kwargs,
+        ms_before=1.0**job_kwargs,
     )
     assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

From 63d80a5d33877c30307a026fe91ef273daa3222b Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Fri, 24 May 2024 14:09:27 +0200
Subject: [PATCH 06/13] typo

---
 .../sortingcomponents/tests/test_peak_detection.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
index fbab084f0d..fa30ba3483 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -323,7 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs):
         detect_threshold=5,
         exclude_sweep_ms=0.1,
         prototype=prototype,
-        ms_before=1.0
+        ms_before=1.0,
         **job_kwargs,
     )
     assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)
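With the three test fixes above applied, the matched-filtering entry point is called as below. This usage sketch simply mirrors the updated test; recording, prototype and job_kwargs are assumed to be defined as earlier in the test file:

    from spikeinterface.sortingcomponents.peak_detection import detect_peaks

    peaks = detect_peaks(
        recording,
        method="matched_filtering",
        peak_sign="neg",
        detect_threshold=5,
        exclude_sweep_ms=0.1,
        prototype=prototype,  # e.g. from get_prototype_spike(...)
        ms_before=1.0,        # newly required, so peak times can be shifted back to the trough
        **job_kwargs,
    )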
From 91a6707901001740b92af43a6a8cec03d3cac089 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 24 May 2024 18:02:39 +0200
Subject: [PATCH 07/13] Fix more verbosity propagation

---
 .../core/analyzer_extension_core.py           |  6 +++--
 src/spikeinterface/core/node_pipeline.py      |  2 ++
 src/spikeinterface/core/sortinganalyzer.py    | 23 ++++++++++---------
 src/spikeinterface/core/waveform_tools.py     |  9 +++++---
 .../postprocessing/amplitude_scalings.py      |  3 ++-
 .../postprocessing/principal_component.py     |  4 ++--
 .../postprocessing/spike_amplitudes.py        |  3 ++-
 .../postprocessing/spike_locations.py         |  3 ++-
 .../sortingcomponents/matching/main.py        |  4 ++--
 .../sortingcomponents/peak_detection.py       |  2 --
 10 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 066194725d..8766718304 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -145,7 +145,7 @@ def nbefore(self):
     def nafter(self):
         return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         self.data.clear()
 
         recording = self.sorting_analyzer.recording
@@ -183,6 +183,7 @@ def _run(self, **job_kwargs):
             sparsity_mask=sparsity_mask,
             copy=copy,
             job_name="compute_waveforms",
+            verbose=verbose,
             **job_kwargs,
         )
 
@@ -311,7 +312,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=None):
         )
         return params
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         self.data.clear()
 
         if self.sorting_analyzer.has_extension("waveforms"):
@@ -339,6 +340,7 @@ def _run(self, **job_kwargs):
                 self.nafter,
                 return_scaled=return_scaled,
                 return_std=return_std,
+                verbose=verbose,
                 **job_kwargs,
             )
 
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index ee6cf5268d..3585b07b23 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -473,6 +473,7 @@ def run_node_pipeline(
     squeeze_output=True,
     folder=None,
     names=None,
+    verbose=False,
 ):
     """
     Common function to run pipeline with peak detector or already detected peak.
@@ -499,6 +500,7 @@ def run_node_pipeline(
         init_args,
         gather_func=gather_func,
         job_name=job_name,
+        verbose=verbose,
         **job_kwargs,
     )
 
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index d9fcf44442..7cc491d3db 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -835,7 +835,7 @@ def get_num_units(self) -> int:
         return self.sorting.get_num_units()
 
     ## extensions zone
-    def compute(self, input, save=True, extension_params=None, **kwargs):
+    def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
         """
         Compute one extension or several extensions.
         Internally calls compute_one_extension() or compute_several_extensions() depending on the input type.
@@ -883,11 +883,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
         )
         """
         if isinstance(input, str):
-            return self.compute_one_extension(extension_name=input, save=save, **kwargs)
+            return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs)
         elif isinstance(input, dict):
             params_, job_kwargs = split_job_kwargs(kwargs)
             assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
-            self.compute_several_extensions(extensions=input, save=save, **job_kwargs)
+            self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs)
         elif isinstance(input, list):
             params_, job_kwargs = split_job_kwargs(kwargs)
             assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
@@ -898,11 +898,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
                     ext_name in input
                 ), f"SortingAnalyzer.compute(): Parameters specified for {ext_name}, which is not in the specified {input}"
                 extensions[ext_name] = ext_params
-            self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs)
+            self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
         else:
             raise ValueError("SortingAnalyzer.compute() need str, dict or list")
 
-    def compute_one_extension(self, extension_name, save=True, **kwargs):
+    def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
         """
         Compute one extension.
 
@@ -925,7 +925,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):
         Returns
         -------
         result_extension: AnalyzerExtension
-            Return the extension instance.
+            Return the extension insdef computance.
 
         Examples
         --------
@@ -967,7 +967,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):
 
         return extension_instance
 
-    def compute_several_extensions(self, extensions, save=True, **job_kwargs):
+    def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
         """
         Compute several extensions
 
@@ -1021,9 +1021,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
         for extension_name, extension_params in extensions_without_pipeline.items():
             extension_class = get_extension_class(extension_name)
             if extension_class.need_job_kwargs:
-                self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, **extension_params)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
         # then extensions with pipeline
         if len(extensions_with_pipeline) > 0:
             all_nodes = []
@@ -1053,6 +1053,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
                 job_name=job_name,
                 gather_mode="memory",
                 squeeze_output=False,
+                verbose=verbose
             )
 
             for r, result in enumerate(results):
@@ -1071,9 +1072,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
         for extension_name, extension_params in extensions_post_pipeline.items():
             extension_class = get_extension_class(extension_name)
             if extension_class.need_job_kwargs:
-                self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, **extension_params)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
 
     def get_saved_extension_names(self):
         """
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index 58966334db..8ecaac6163 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -221,6 +221,7 @@ def distribute_waveforms_to_buffers(
     mode="memmap",
     sparsity_mask=None,
     job_name=None,
+    verbose=False,
     **job_kwargs,
 ):
     """
@@ -281,7 +282,7 @@ def distribute_waveforms_to_buffers(
     )
     if job_name is None:
         job_name = f"extract waveforms {mode} multi buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
     processor.run()
 
 
@@ -410,6 +411,7 @@ def extract_waveforms_to_single_buffer(
     sparsity_mask=None,
     copy=True,
     job_name=None,
+    verbose=False,
     **job_kwargs,
 ):
     """
@@ -523,7 +525,7 @@ def extract_waveforms_to_single_buffer(
 
     if job_name is None:
         job_name = f"extract waveforms {mode} mono buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
     processor.run()
 
     if mode == "memmap":
@@ -783,6 +785,7 @@ def estimate_templates_with_accumulator(
     return_scaled: bool = True,
     job_name=None,
     return_std: bool = False,
+    verbose: bool = False,
     **job_kwargs,
 ):
     """
@@ -861,7 +864,7 @@ def estimate_templates_with_accumulator(
 
     if job_name is None:
         job_name = "estimate_templates_with_accumulator"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
     processor.run()
 
     # average
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index e2dcdd8e5a..d2b363e69a 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -181,7 +181,7 @@ def _get_pipeline_nodes(self):
         nodes = [spike_retriever_node, amplitude_scalings_node]
         return nodes
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         amp_scalings, collision_mask = run_node_pipeline(
@@ -190,6 +190,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="amplitude_scalings",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["amplitude_scalings"] = amp_scalings
         if self.params["handle_collisions"]:
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 76d7c1744e..60d517fdc1 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -295,7 +295,7 @@ def _run(self, **job_kwargs):
     def _get_data(self):
         return self.data["pca_projection"]
 
-    def run_for_all_spikes(self, file_path=None, **job_kwargs):
+    def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
         """
         Project all spikes from the sorting on the PCA model.
         This is a long computation because waveforms need to be extracted from each spike.
@@ -359,7 +359,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
             unit_channels,
             pca_model,
         )
-        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs)
+        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs)
         processor.run()
 
     def _fit_by_channel_local(self, n_jobs, progress_bar):
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index add9764790..cc1d4b26e9 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -107,7 +107,7 @@ def _get_pipeline_nodes(self):
         nodes = [spike_retriever_node, spike_amplitudes_node]
         return nodes
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         amps = run_node_pipeline(
@@ -116,6 +116,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="spike_amplitudes",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["amplitudes"] = amps
 
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index d1e1d38c6a..d468bd90ab 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -120,7 +120,7 @@ def _get_pipeline_nodes(self):
         )
         return nodes
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         spike_locations = run_node_pipeline(
@@ -129,6 +129,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="spike_locations",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["spike_locations"] = spike_locations
 
diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index 1c5c947b02..e3c5ca8222 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -7,7 +7,7 @@
 from spikeinterface.core import get_chunk_with_margin
 
 
-def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, **job_kwargs):
+def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs):
     """Find spikes from a recording from given templates.
 
     Parameters
@@ -53,7 +53,7 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, **job_kwargs):
     init_func = _init_worker_find_spikes
     init_args = (recording, method, method_kwargs_seralized)
     processor = ChunkRecordingExecutor(
-        recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", **job_kwargs
+        recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", verbose=verbose, **job_kwargs
     )
     spikes = processor.run()
 
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 508a033c41..4e1fa64961 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -9,10 +9,8 @@
 import numpy as np
 
 from spikeinterface.core.job_tools import (
-    ChunkRecordingExecutor,
     _shared_job_kwargs_doc,
     split_job_kwargs,
-    fix_job_kwargs,
 )
 
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks
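Every hunk in patch 07 applies the same pattern: verbose becomes an explicit keyword argument at each layer and is forwarded by name rather than being silently dropped along the way. A self-contained toy of the failure mode (hypothetical names, not the real executor):

    def executor(job_name, verbose=False):
        if verbose:
            print(f"[{job_name}] starting")

    def layer_before(job_name, **job_kwargs):
        # verbose is never declared here, so a caller's verbose=True lands in
        # job_kwargs and the executor runs with its default of False
        executor(job_name)

    def layer_after(job_name, verbose=False, **job_kwargs):
        executor(job_name, verbose=verbose)

    layer_before("compute_waveforms", verbose=True)  # silent
    layer_after("compute_waveforms", verbose=True)   # prints "[compute_waveforms] starting"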
From 4b1b80401f3495cb4daad84ff2c9e9b6198adcc2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 24 May 2024 16:03:21 +0000
Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/sortinganalyzer.py     | 10 +++++-----
 src/spikeinterface/core/waveform_tools.py      | 12 +++++++++---
 .../postprocessing/principal_component.py      |  4 +++-
 .../sortingcomponents/matching/main.py         | 13 +++++++++++--
 4 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 7cc491d3db..cadf4036fc 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -902,7 +902,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
         else:
             raise ValueError("SortingAnalyzer.compute() need str, dict or list")
 
-    def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs): 
+    def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
         """
         Compute one extension.
 
@@ -1023,7 +1023,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
             if extension_class.need_job_kwargs:
                 self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params) 
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
         # then extensions with pipeline
         if len(extensions_with_pipeline) > 0:
             all_nodes = []
@@ -1053,7 +1053,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
                 job_name=job_name,
                 gather_mode="memory",
                 squeeze_output=False,
-                verbose=verbose
+                verbose=verbose,
             )
 
             for r, result in enumerate(results):
@@ -1072,9 +1072,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
         for extension_name, extension_params in extensions_post_pipeline.items():
             extension_class = get_extension_class(extension_name)
             if extension_class.need_job_kwargs:
-                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs) 
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params) 
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
 
     def get_saved_extension_names(self):
         """
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index 8ecaac6163..acc368b2e5 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -282,7 +282,9 @@ def distribute_waveforms_to_buffers(
     )
     if job_name is None:
         job_name = f"extract waveforms {mode} multi buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()
 
 
@@ -525,7 +527,9 @@ def extract_waveforms_to_single_buffer(
 
     if job_name is None:
         job_name = f"extract waveforms {mode} mono buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()
 
     if mode == "memmap":
@@ -864,7 +868,9 @@ def estimate_templates_with_accumulator(
 
     if job_name is None:
         job_name = "estimate_templates_with_accumulator"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()
 
     # average
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 60d517fdc1..95b38f52f9 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -359,7 +359,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
             unit_channels,
             pca_model,
         )
-        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs)
+        processor = ChunkRecordingExecutor(
+            recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs
+        )
         processor.run()
 
     def _fit_by_channel_local(self, n_jobs, progress_bar):
diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index e3c5ca8222..9476a0df03 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -7,7 +7,9 @@
 from spikeinterface.core import get_chunk_with_margin
 
 
-def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs):
+def find_spikes_from_templates(
+    recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
+):
     """Find spikes from a recording from given templates.
 
     Parameters
@@ -53,7 +55,14 @@ def find_spikes_from_templates(
     init_func = _init_worker_find_spikes
     init_args = (recording, method, method_kwargs_seralized)
     processor = ChunkRecordingExecutor(
-        recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", verbose=verbose, **job_kwargs
+        recording,
+        func,
+        init_func,
+        init_args,
+        handle_returns=True,
+        job_name=f"find spikes ({method})",
+        verbose=verbose,
+        **job_kwargs,
     )
     spikes = processor.run()
 

From 2de6fb3e9e5106c236a43cf4ad2b9b96362cc73c Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 24 May 2024 18:22:28 +0200
Subject: [PATCH 09/13] oups

---
 src/spikeinterface/core/sortinganalyzer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 7cc491d3db..bfa64b0927 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -961,7 +961,10 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
 
         extension_instance = extension_class(self)
         extension_instance.set_params(save=save, **params)
-        extension_instance.run(save=save, **job_kwargs)
+        if extension_class.need_job_kwargs:
+            extension_instance.run(save=save, verbose=verbose, **job_kwargs)
+        else:
+            extension_instance.run(save=save)
 
         self.extensions[extension_name] = extension_instance
 
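Patch 10 below then has to touch every extension's _run(): once the analyzer forwards verbose unconditionally, any _run() that does not accept it raises a TypeError. A toy illustration with hypothetical classes:

    class OldExtension:
        def _run(self):
            pass

    class NewExtension:
        def _run(self, verbose=False):
            pass

    try:
        OldExtension()._run(verbose=True)
    except TypeError:
        print("old signature breaks once verbose is always forwarded")

    NewExtension()._run(verbose=True)  # accepted; the flag is simply unused here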
From 2b838ba7dc35ff954b526c89be4d13dafe4afaea Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 24 May 2024 19:25:10 +0200
Subject: [PATCH 10/13] verbose in all _run() for extension

---
 src/spikeinterface/core/analyzer_extension_core.py       | 7 +++----
 src/spikeinterface/core/sortinganalyzer.py               | 2 +-
 src/spikeinterface/postprocessing/correlograms.py        | 2 +-
 src/spikeinterface/postprocessing/principal_component.py | 2 +-
 src/spikeinterface/postprocessing/template_metrics.py    | 2 +-
 src/spikeinterface/postprocessing/template_similarity.py | 2 +-
 src/spikeinterface/postprocessing/unit_localization.py   | 2 +-
 7 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 8766718304..073708f353 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension):
     use_nodepipeline = False
     need_job_kwargs = False
 
-    def _run(
-        self,
-    ):
+    def _run(self, verbose=False):
+
         self.data["random_spikes_indices"] = random_spikes_selection(
             self.sorting_analyzer.sorting,
             num_samples=self.sorting_analyzer.rec_attributes["num_samples"],
@@ -583,7 +582,7 @@ def _select_extension_data(self, unit_ids):
         # this do not depend on units
         return self.data
 
-    def _run(self):
+    def _run(self, verbose=False):
        self.data["noise_levels"] = get_noise_levels(
             self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
         )
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 3e9a842d3a..99fb51c029 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -964,7 +964,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
         if extension_class.need_job_kwargs:
             extension_instance.run(save=save, verbose=verbose, **job_kwargs)
         else:
-            extension_instance.run(save=save)
+            extension_instance.run(save=save, verbose=verbose)
 
         self.extensions[extension_name] = extension_instance
 
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 3a01305d6b..f0bd151c68 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -70,7 +70,7 @@ def _select_extension_data(self, unit_ids):
         new_data = dict(ccgs=new_ccgs, bins=new_bins)
         return new_data
 
-    def _run(self):
+    def _run(self, verbose=False):
         ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
         self.data["ccgs"] = ccgs
         self.data["bins"] = bins
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 95b38f52f9..8eb375e90b 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -256,7 +256,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True):
         new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar)
         return new_projections
 
-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         """
         Compute the PCs on waveforms extracted by ComputeWaveforms.
         Projections are computed only on the waveforms sampled by the SortingAnalyzer.
diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 19e6a1b47a..d7179ffefa 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -150,7 +150,7 @@ def _select_extension_data(self, unit_ids):
         new_metrics = self.data["metrics"].loc[np.array(unit_ids)]
         return dict(metrics=new_metrics)
 
-    def _run(self):
+    def _run(self, verbose=False):
         import pandas as pd
         from scipy.signal import resample_poly
 
diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 4a45da0269..15a1fe34ce 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -42,7 +42,7 @@ def _select_extension_data(self, unit_ids):
         new_similarity = self.data["similarity"][unit_indices][:, unit_indices]
         return dict(similarity=new_similarity)
 
-    def _run(self):
+    def _run(self, verbose=False):
         templates_array = get_dense_templates_array(
             self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
         )
diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py
index d1d3b5075a..e40523e7e5 100644
--- a/src/spikeinterface/postprocessing/unit_localization.py
+++ b/src/spikeinterface/postprocessing/unit_localization.py
@@ -64,7 +64,7 @@ def _select_extension_data(self, unit_ids):
         new_unit_location = self.data["unit_locations"][unit_inds]
         return dict(unit_locations=new_unit_location)
 
-    def _run(self):
+    def _run(self, verbose=False):
         method = self.params["method"]
         method_kwargs = self.params["method_kwargs"]
 

From 26157caee3e3da614a0babc3266761a943c1aa97 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 24 May 2024 19:25:39 +0200
Subject: [PATCH 11/13] oups

---
 src/spikeinterface/postprocessing/isi.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py
index c7e850993f..3742cbfa96 100644
--- a/src/spikeinterface/postprocessing/isi.py
+++ b/src/spikeinterface/postprocessing/isi.py
@@ -56,7 +56,7 @@ def _select_extension_data(self, unit_ids):
         new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
         return new_extension_data
 
-    def _run(self):
+    def _run(self, verbose=False):
         isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params)
         self.data["isi_histograms"] = isi_histograms
         self.data["bins"] = bins

From ee4529f51eae58a22d8010c6f52f7a60e510914f Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Fri, 24 May 2024 12:09:13 -0600
Subject: [PATCH 12/13] Update src/spikeinterface/core/sortinganalyzer.py

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/core/sortinganalyzer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 99fb51c029..53e060262b 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -925,7 +925,7 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
         Returns
         -------
         result_extension: AnalyzerExtension
-            Return the extension insdef computance.
+            Return the extension instance
 
         Examples
         --------

From 790715c959898baaf71335dd7c27d233fd434cdc Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Sat, 25 May 2024 04:48:40 -0600
Subject: [PATCH 13/13] add zarr arguments (#2909)

---
 src/spikeinterface/extractors/nwbextractors.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index 66a2a65bfb..2aa34533a6 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -440,7 +440,8 @@ class NwbRecordingExtractor(BaseRecording):
     stream_cache_path: str, Path, or None, default: None
         Specifies the local path for caching the file. Relevant only if `cache` is True.
     storage_options: dict | None = None,
-        Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode.
+        These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
+        This is only used on the "zarr" stream_mode.
     use_pynwb: bool, default: False
         Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
         to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
@@ -861,8 +862,10 @@ def _fetch_main_properties_backend(self):
 
     @staticmethod
     def fetch_available_electrical_series_paths(
-        file_path: str | Path, stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None
-    ) -> List[str]:
+        file_path: str | Path,
+        stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None,
+        storage_options: dict | None = None,
+    ) -> list[str]:
         """
         Retrieves the paths to all ElectricalSeries objects within a neurodata file.
 
@@ -873,7 +876,9 @@ def fetch_available_electrical_series_paths(
         stream_mode : "fsspec" | "remfile" | "zarr" | None, optional
             Determines the streaming mode for reading the file. Use this for optimized reading from
             different sources, such as local disk or remote servers.
-
+        storage_options: dict | None = None,
+            These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
+            This is only used on the "zarr" stream_mode.
         Returns
         -------
         list of str
@@ -901,6 +906,7 @@ def fetch_available_electrical_series_paths(
         file_handle = read_file_from_backend(
             file_path=file_path,
             stream_mode=stream_mode,
+            storage_options=storage_options,
         )
 
         electrical_series_paths = _find_neurodata_type_from_backend(
@@ -988,7 +994,8 @@ class NwbSortingExtractor(BaseSorting):
         If True, the file is cached in the file passed to stream_cache_path
         if False, the file is not cached.
     storage_options: dict | None = None,
-        Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode.
+        These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
+        This is only used on the "zarr" stream_mode.
     use_pynwb: bool, default: False
         Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
         to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
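A usage sketch for the storage_options argument added in patch 13. The bucket URL and credential dict here are hypothetical; per the docstring above, the dict is forwarded to the zarr.open convenience function when stream_mode="zarr":

    from spikeinterface.extractors import NwbRecordingExtractor

    paths = NwbRecordingExtractor.fetch_available_electrical_series_paths(
        file_path="s3://example-bucket/session.nwb.zarr",  # hypothetical remote store
        stream_mode="zarr",
        storage_options={"anon": True},  # passed through to zarr.open
    )
    print(paths)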