diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 066194725d..073708f353 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def _run( - self, - ): + def _run(self, verbose=False): + self.data["random_spikes_indices"] = random_spikes_selection( self.sorting_analyzer.sorting, num_samples=self.sorting_analyzer.rec_attributes["num_samples"], @@ -145,7 +144,7 @@ def nbefore(self): def nafter(self): return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): self.data.clear() recording = self.sorting_analyzer.recording @@ -183,6 +182,7 @@ def _run(self, **job_kwargs): sparsity_mask=sparsity_mask, copy=copy, job_name="compute_waveforms", + verbose=verbose, **job_kwargs, ) @@ -311,7 +311,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N ) return params - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): self.data.clear() if self.sorting_analyzer.has_extension("waveforms"): @@ -339,6 +339,7 @@ def _run(self, **job_kwargs): self.nafter, return_scaled=return_scaled, return_std=return_std, + verbose=verbose, **job_kwargs, ) @@ -581,7 +582,7 @@ def _select_extension_data(self, unit_ids): # this do not depend on units return self.data - def _run(self): + def _run(self, verbose=False): self.data["noise_levels"] = get_noise_levels( self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params ) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ee6cf5268d..3585b07b23 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -473,6 +473,7 @@ def run_node_pipeline( squeeze_output=True, folder=None, names=None, + verbose=False, ): """ Common function to run pipeline with peak detector or already detected peak. @@ -499,6 +500,7 @@ def run_node_pipeline( init_args, gather_func=gather_func, job_name=job_name, + verbose=verbose, **job_kwargs, ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d9fcf44442..53e060262b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -835,7 +835,7 @@ def get_num_units(self) -> int: return self.sorting.get_num_units() ## extensions zone - def compute(self, input, save=True, extension_params=None, **kwargs): + def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs): """ Compute one extension or several extensiosn. Internally calls compute_one_extension() or compute_several_extensions() depending on the input type. 
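The hunks above and below thread a `verbose` flag from `SortingAnalyzer.compute()` down to the extension `_run()` methods and the chunked executors. A minimal usage sketch of that plumbing, assuming an existing `recording` and `sorting` pair (those names are placeholders, not part of this diff):

from spikeinterface.core import create_sorting_analyzer

# `recording` and `sorting` are assumed to already exist.
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

# Single extension: `verbose` is forwarded to the extension's _run().
analyzer.compute("random_spikes", verbose=True)

# Several extensions: `verbose` is forwarded to compute_several_extensions()
# together with the job kwargs.
analyzer.compute(
    {"waveforms": {"ms_before": 1.0, "ms_after": 2.0}, "templates": {}},
    verbose=True,
    n_jobs=2,
    chunk_duration="1s",
)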
@@ -883,11 +883,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs): ) """ if isinstance(input, str): - return self.compute_one_extension(extension_name=input, save=save, **kwargs) + return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs) elif isinstance(input, dict): params_, job_kwargs = split_job_kwargs(kwargs) assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" - self.compute_several_extensions(extensions=input, save=save, **job_kwargs) + self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs) elif isinstance(input, list): params_, job_kwargs = split_job_kwargs(kwargs) assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()" @@ -898,11 +898,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs): ext_name in input ), f"SortingAnalyzer.compute(): Parameters specified for {ext_name}, which is not in the specified {input}" extensions[ext_name] = ext_params - self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs) + self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs) else: raise ValueError("SortingAnalyzer.compute() need str, dict or list") - def compute_one_extension(self, extension_name, save=True, **kwargs): + def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs): """ Compute one extension. @@ -925,7 +925,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): Returns ------- result_extension: AnalyzerExtension - Return the extension instance. + Return the extension instance Examples -------- @@ -961,13 +961,16 @@ def compute_one_extension(self, extension_name, save=True, **kwargs): extension_instance = extension_class(self) extension_instance.set_params(save=save, **params) - extension_instance.run(save=save, **job_kwargs) + if extension_class.need_job_kwargs: + extension_instance.run(save=save, verbose=verbose, **job_kwargs) + else: + extension_instance.run(save=save, verbose=verbose) self.extensions[extension_name] = extension_instance return extension_instance - def compute_several_extensions(self, extensions, save=True, **job_kwargs): + def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs): """ Compute several extensions @@ -1021,9 +1024,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): for extension_name, extension_params in extensions_without_pipeline.items(): extension_class = get_extension_class(extension_name) if extension_class.need_job_kwargs: - self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs) + self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs) else: - self.compute_one_extension(extension_name, save=save, **extension_params) + self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params) # then extensions with pipeline if len(extensions_with_pipeline) > 0: all_nodes = [] @@ -1053,6 +1056,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): job_name=job_name, gather_mode="memory", squeeze_output=False, + verbose=verbose, ) for r, result in enumerate(results): @@ -1071,9 +1075,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): for extension_name, extension_params in extensions_post_pipeline.items(): extension_class = 
get_extension_class(extension_name) if extension_class.need_job_kwargs: - self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs) + self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs) else: - self.compute_one_extension(extension_name, save=save, **extension_params) + self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params) def get_saved_extension_names(self): """ diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 58966334db..acc368b2e5 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -221,6 +221,7 @@ def distribute_waveforms_to_buffers( mode="memmap", sparsity_mask=None, job_name=None, + verbose=False, **job_kwargs, ): """ @@ -281,7 +282,9 @@ def distribute_waveforms_to_buffers( ) if job_name is None: job_name = f"extract waveforms {mode} multi buffer" - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + ) processor.run() @@ -410,6 +413,7 @@ def extract_waveforms_to_single_buffer( sparsity_mask=None, copy=True, job_name=None, + verbose=False, **job_kwargs, ): """ @@ -523,7 +527,9 @@ def extract_waveforms_to_single_buffer( if job_name is None: job_name = f"extract waveforms {mode} mono buffer" - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + ) processor.run() if mode == "memmap": @@ -783,6 +789,7 @@ def estimate_templates_with_accumulator( return_scaled: bool = True, job_name=None, return_std: bool = False, + verbose: bool = False, **job_kwargs, ): """ @@ -861,7 +868,9 @@ def estimate_templates_with_accumulator( if job_name is None: job_name = "estimate_templates_with_accumulator" - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + ) processor.run() # average diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 66a2a65bfb..2aa34533a6 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -440,7 +440,8 @@ class NwbRecordingExtractor(BaseRecording): stream_cache_path: str, Path, or None, default: None Specifies the local path for caching the file. Relevant only if `cache` is True. storage_options: dict | None = None, - Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode. + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. use_pynwb: bool, default: False Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. 
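Following the reworded `storage_options` docstring, a hedged sketch of how those kwargs reach `zarr.open` when streaming a remote NWB-Zarr asset (the URL and credential values are placeholders):

from spikeinterface.extractors import NwbRecordingExtractor

# storage_options is only consulted when stream_mode="zarr"; it is passed
# straight to the zarr.open convenience function.
recording = NwbRecordingExtractor(
    file_path="s3://example-bucket/session.nwb.zarr",  # hypothetical remote asset
    stream_mode="zarr",
    storage_options={"anon": False, "key": "<aws-key>", "secret": "<aws-secret>"},
)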
@@ -861,8 +862,10 @@ def _fetch_main_properties_backend(self): @staticmethod def fetch_available_electrical_series_paths( - file_path: str | Path, stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None - ) -> List[str]: + file_path: str | Path, + stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, + storage_options: dict | None = None, + ) -> list[str]: """ Retrieves the paths to all ElectricalSeries objects within a neurodata file. @@ -873,7 +876,9 @@ def fetch_available_electrical_series_paths( stream_mode : "fsspec" | "remfile" | "zarr" | None, optional Determines the streaming mode for reading the file. Use this for optimized reading from different sources, such as local disk or remote servers. - + storage_options: dict | None = None, + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. Returns ------- list of str @@ -901,6 +906,7 @@ def fetch_available_electrical_series_paths( file_handle = read_file_from_backend( file_path=file_path, stream_mode=stream_mode, + storage_options=storage_options, ) electrical_series_paths = _find_neurodata_type_from_backend( @@ -988,7 +994,8 @@ class NwbSortingExtractor(BaseSorting): If True, the file is cached in the file passed to stream_cache_path if False, the file is not cached. storage_options: dict | None = None, - Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode. + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. use_pynwb: bool, default: False Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. 
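The static helper now forwards `storage_options` as well, so a remote Zarr store can be inspected before building an extractor. A sketch with a placeholder URL:

from spikeinterface.extractors import NwbRecordingExtractor

# List every ElectricalSeries path in the file without instantiating an extractor;
# storage_options is only used for stream_mode="zarr".
paths = NwbRecordingExtractor.fetch_available_electrical_series_paths(
    file_path="s3://example-bucket/session.nwb.zarr",  # hypothetical remote asset
    stream_mode="zarr",
    storage_options={"anon": True},
)
print(paths)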
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index e2dcdd8e5a..d2b363e69a 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -181,7 +181,7 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, amplitude_scalings_node] return nodes - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amp_scalings, collision_mask = run_node_pipeline( @@ -190,6 +190,7 @@ def _run(self, **job_kwargs): job_kwargs=job_kwargs, job_name="amplitude_scalings", gather_mode="memory", + verbose=verbose, ) self.data["amplitude_scalings"] = amp_scalings if self.params["handle_collisions"]: diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 3a01305d6b..f0bd151c68 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -70,7 +70,7 @@ def _select_extension_data(self, unit_ids): new_data = dict(ccgs=new_ccgs, bins=new_bins) return new_data - def _run(self): + def _run(self, verbose=False): ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs self.data["bins"] = bins diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index c7e850993f..3742cbfa96 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -56,7 +56,7 @@ def _select_extension_data(self, unit_ids): new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data - def _run(self): + def _run(self, verbose=False): isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params) self.data["isi_histograms"] = isi_histograms self.data["bins"] = bins diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 76d7c1744e..8eb375e90b 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -256,7 +256,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True): new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar) return new_projections - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): """ Compute the PCs on waveforms extacted within the by ComputeWaveforms. Projections are computed only on the waveforms sampled by the SortingAnalyzer. @@ -295,7 +295,7 @@ def _run(self, **job_kwargs): def _get_data(self): return self.data["pca_projection"] - def run_for_all_spikes(self, file_path=None, **job_kwargs): + def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): """ Project all spikes from the sorting on the PCA model. This is a long computation because waveform need to be extracted from each spikes. 
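`run_for_all_spikes()` now exposes `verbose` explicitly and hands it to the ChunkRecordingExecutor in the next hunk. A short sketch, continuing the `analyzer` from the earlier example and assuming its "random_spikes", "waveforms" and "principal_components" extensions have already been computed:

# `analyzer` is assumed to be an existing SortingAnalyzer with a fitted PCA model.
pca_ext = analyzer.get_extension("principal_components")
pca_ext.run_for_all_spikes("all_pcs.npy", verbose=True, n_jobs=4, progress_bar=True)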
@@ -359,7 +359,9 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): unit_channels, pca_model, ) - processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs + ) processor.run() def _fit_by_channel_local(self, n_jobs, progress_bar): diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index add9764790..cc1d4b26e9 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -107,7 +107,7 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, spike_amplitudes_node] return nodes - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() amps = run_node_pipeline( @@ -116,6 +116,7 @@ def _run(self, **job_kwargs): job_kwargs=job_kwargs, job_name="spike_amplitudes", gather_mode="memory", + verbose=verbose, ) self.data["amplitudes"] = amps diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d1e1d38c6a..d468bd90ab 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -120,7 +120,7 @@ def _get_pipeline_nodes(self): ) return nodes - def _run(self, **job_kwargs): + def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) nodes = self.get_pipeline_nodes() spike_locations = run_node_pipeline( @@ -129,6 +129,7 @@ def _run(self, **job_kwargs): job_kwargs=job_kwargs, job_name="spike_locations", gather_mode="memory", + verbose=verbose, ) self.data["spike_locations"] = spike_locations diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 19e6a1b47a..d7179ffefa 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -150,7 +150,7 @@ def _select_extension_data(self, unit_ids): new_metrics = self.data["metrics"].loc[np.array(unit_ids)] return dict(metrics=new_metrics) - def _run(self): + def _run(self, verbose=False): import pandas as pd from scipy.signal import resample_poly diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 4a45da0269..15a1fe34ce 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -42,7 +42,7 @@ def _select_extension_data(self, unit_ids): new_similarity = self.data["similarity"][unit_indices][:, unit_indices] return dict(similarity=new_similarity) - def _run(self): + def _run(self, verbose=False): templates_array = get_dense_templates_array( self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled ) diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d1d3b5075a..e40523e7e5 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -64,7 +64,7 @@ def _select_extension_data(self, unit_ids): new_unit_location = self.data["unit_locations"][unit_inds] return dict(unit_locations=new_unit_location) - def _run(self): + 
def _run(self, verbose=False): method = self.params["method"] method_kwargs = self.params["method_kwargs"] diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f78bff7a54..1b360b5dfa 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -179,6 +179,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["matched_filtering"]: prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) detection_params["prototype"] = prototype + detection_params["ms_before"] = ms_before for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in detection_params: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 1c5c947b02..9476a0df03 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -7,7 +7,9 @@ from spikeinterface.core import get_chunk_with_margin -def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, **job_kwargs): +def find_spikes_from_templates( + recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs +): """Find spike from a recording from given templates. Parameters @@ -53,7 +55,14 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr init_func = _init_worker_find_spikes init_args = (recording, method, method_kwargs_seralized) processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", **job_kwargs + recording, + func, + init_func, + init_args, + handle_returns=True, + job_name=f"find spikes ({method})", + verbose=verbose, + **job_kwargs, ) spikes = processor.run() diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 508a033c41..d23f0fec74 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -9,10 +9,8 @@ import numpy as np from spikeinterface.core.job_tools import ( - ChunkRecordingExecutor, _shared_job_kwargs_doc, split_job_kwargs, - fix_job_kwargs, ) from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks @@ -613,6 +611,7 @@ def __init__( self, recording, prototype, + ms_before, peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, @@ -644,6 +643,7 @@ def __init__( raise NotImplementedError("Matched filtering not working with peak_sign=both yet!") self.peak_sign = peak_sign + self.nbefore = int(ms_before * recording.sampling_frequency / 1000) contact_locations = recording.get_channel_locations() dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2) weights, self.z_factors = get_convolution_weights(dist, **weight_method) @@ -689,17 +689,19 @@ def get_trace_margin(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, temporal, spatial, singular, z_factors = self.args - assert HAVE_NUMBA, "You need to install numba" conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular) conv_traces /= self.abs_thresholds[:, None] conv_traces = 
conv_traces[:, self.conv_margin : -self.conv_margin] traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size] + num_z_factors = len(self.z_factors) - num_channels = conv_traces.shape[0] // num_z_factors + num_templates = traces.shape[1] + traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1]) + conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1]) peak_mask = traces_center > 1 + peak_mask = _numba_detect_peak_matched_filtering( conv_traces, traces_center, @@ -708,15 +710,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): self.abs_thresholds, self.peak_sign, self.neighbours_mask, - num_channels, + num_templates, ) # Find peaks and correct for time shift - peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) - - # If we do not want to estimate the z accurately - z = self.z_factors[peak_chan_ind // num_channels] - peak_chan_ind = peak_chan_ind % num_channels + z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) # If we want to estimate z # peak_chan_ind = peak_chan_ind % num_channels @@ -730,7 +728,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): if peak_sample_ind.size == 0 or peak_chan_ind.size == 0: return (np.zeros(0, dtype=self._dtype),) - peak_sample_ind += self.exclude_sweep_size + self.conv_margin + peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore peak_amplitude = traces[peak_sample_ind, peak_chan_ind] local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) @@ -738,7 +736,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): local_peaks["channel_index"] = peak_chan_ind local_peaks["amplitude"] = peak_amplitude local_peaks["segment_index"] = segment_index - local_peaks["z"] = z + local_peaks["z"] = z_ind # return is always a tuple return (local_peaks,) @@ -747,10 +745,11 @@ def get_convolved_traces(self, traces, temporal, spatial, singular): import scipy.signal num_timesteps, num_templates = len(traces), temporal.shape[1] - scalar_products = np.zeros((num_templates, num_timesteps), dtype=np.float32) + num_peaks = num_timesteps - self.conv_margin + 1 + scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32) spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) scaled_filtered_data = spatially_filtered_data * singular - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="same") + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid") scalar_products += np.sum(objective_by_rank, axis=0) return scalar_products @@ -876,27 +875,51 @@ def _numba_detect_peak_neg( @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_matched_filtering( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates ): - num_chans = traces_center.shape[0] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[1]): - if not peak_mask[chan_ind, s]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind % num_channels, neighbour % num_channels]: + num_z = traces_center.shape[0] + for template_ind in range(num_templates): + for z in range(num_z): + for s in range(peak_mask.shape[2]): + if not peak_mask[z, template_ind, s]: continue - for i in 
range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[chan_ind, s] &= traces_center[chan_ind, s] >= traces_center[neighbour, s] - peak_mask[chan_ind, s] &= traces_center[chan_ind, s] > traces[neighbour, s + i] - peak_mask[chan_ind, s] &= ( - traces_center[chan_ind, s] >= traces[neighbour, exclude_sweep_size + s + i + 1] - ) - if not peak_mask[chan_ind, s]: + for neighbour in range(num_templates): + if not neighbours_mask[template_ind, neighbour]: + continue + for j in range(num_z): + for i in range(exclude_sweep_size): + if template_ind >= neighbour: + if z >= j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + elif template_ind < neighbour: + if z > j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces[j, neighbour, s + i] + ) + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] + >= traces[j, neighbour, exclude_sweep_size + s + i + 1] + ) + if not peak_mask[z, template_ind, s]: + break + if not peak_mask[z, template_ind, s]: + break + if not peak_mask[z, template_ind, s]: break - if not peak_mask[chan_ind, s]: - break + return peak_mask diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 2ecccb421c..fa30ba3483 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -323,6 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) detect_threshold=5, exclude_sweep_ms=0.1, prototype=prototype, + ms_before=1.0, **job_kwargs, ) assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)
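To close, a hedged sketch mirroring the updated test above: the matched-filtering detector now requires `ms_before` so detected peak times can be shifted back relative to the prototype (`recording`, `peaks` and `job_kwargs` are assumed to exist; `get_prototype_spike` is used as in the spyking_circus2 hunk, and its import path is an assumption):

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.tools import get_prototype_spike  # assumed import path

ms_before, ms_after = 1.0, 2.0
prototype = get_prototype_spike(recording, peaks, ms_before, ms_after, **job_kwargs)

peaks_mf = detect_peaks(
    recording,
    method="matched_filtering",
    peak_sign="neg",
    detect_threshold=5,
    exclude_sweep_ms=0.1,
    prototype=prototype,
    ms_before=ms_before,  # new argument introduced in this diff
    **job_kwargs,
)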