diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c1021e787a..05853b4c39 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -179,6 +179,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["matched_filtering"]: prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) detection_params["prototype"] = prototype + detection_params["ms_before"] = ms_before for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: if value in detection_params: diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 4e1fa64961..d23f0fec74 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -611,6 +611,7 @@ def __init__( self, recording, prototype, + ms_before, peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, @@ -642,6 +643,7 @@ def __init__( raise NotImplementedError("Matched filtering not working with peak_sign=both yet!") self.peak_sign = peak_sign + self.nbefore = int(ms_before * recording.sampling_frequency / 1000) contact_locations = recording.get_channel_locations() dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2) weights, self.z_factors = get_convolution_weights(dist, **weight_method) @@ -687,17 +689,19 @@ def get_trace_margin(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, temporal, spatial, singular, z_factors = self.args - assert HAVE_NUMBA, "You need to install numba" conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular) conv_traces /= self.abs_thresholds[:, None] conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size] + num_z_factors = len(self.z_factors) - num_channels = conv_traces.shape[0] // num_z_factors + num_templates = traces.shape[1] + traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1]) + conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1]) peak_mask = traces_center > 1 + peak_mask = _numba_detect_peak_matched_filtering( conv_traces, traces_center, @@ -706,15 +710,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): self.abs_thresholds, self.peak_sign, self.neighbours_mask, - num_channels, + num_templates, ) # Find peaks and correct for time shift - peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) - - # If we do not want to estimate the z accurately - z = self.z_factors[peak_chan_ind // num_channels] - peak_chan_ind = peak_chan_ind % num_channels + z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) # If we want to estimate z # peak_chan_ind = peak_chan_ind % num_channels @@ -728,7 +728,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): if peak_sample_ind.size == 0 or peak_chan_ind.size == 0: return (np.zeros(0, dtype=self._dtype),) - peak_sample_ind += self.exclude_sweep_size + self.conv_margin + peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore peak_amplitude = traces[peak_sample_ind, peak_chan_ind] local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) @@ -736,7 +736,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): local_peaks["channel_index"] = peak_chan_ind local_peaks["amplitude"] = peak_amplitude local_peaks["segment_index"] = segment_index - local_peaks["z"] = z + local_peaks["z"] = z_ind # return is always a tuple return (local_peaks,) @@ -745,10 +745,11 @@ def get_convolved_traces(self, traces, temporal, spatial, singular): import scipy.signal num_timesteps, num_templates = len(traces), temporal.shape[1] - scalar_products = np.zeros((num_templates, num_timesteps), dtype=np.float32) + num_peaks = num_timesteps - self.conv_margin + 1 + scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32) spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) scaled_filtered_data = spatially_filtered_data * singular - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="same") + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid") scalar_products += np.sum(objective_by_rank, axis=0) return scalar_products @@ -874,27 +875,51 @@ def _numba_detect_peak_neg( @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_matched_filtering( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates ): - num_chans = traces_center.shape[0] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[1]): - if not peak_mask[chan_ind, s]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind % num_channels, neighbour % num_channels]: + num_z = traces_center.shape[0] + for template_ind in range(num_templates): + for z in range(num_z): + for s in range(peak_mask.shape[2]): + if not peak_mask[z, template_ind, s]: continue - for i in range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[chan_ind, s] &= traces_center[chan_ind, s] >= traces_center[neighbour, s] - peak_mask[chan_ind, s] &= traces_center[chan_ind, s] > traces[neighbour, s + i] - peak_mask[chan_ind, s] &= ( - traces_center[chan_ind, s] >= traces[neighbour, exclude_sweep_size + s + i + 1] - ) - if not peak_mask[chan_ind, s]: + for neighbour in range(num_templates): + if not neighbours_mask[template_ind, neighbour]: + continue + for j in range(num_z): + for i in range(exclude_sweep_size): + if template_ind >= neighbour: + if z >= j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + elif template_ind < neighbour: + if z > j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces[j, neighbour, s + i] + ) + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] + >= traces[j, neighbour, exclude_sweep_size + s + i + 1] + ) + if not peak_mask[z, template_ind, s]: + break + if not peak_mask[z, template_ind, s]: + break + if not peak_mask[z, template_ind, s]: break - if not peak_mask[chan_ind, s]: - break + return peak_mask diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 2ecccb421c..fa30ba3483 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -323,6 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) detect_threshold=5, exclude_sweep_ms=0.1, prototype=prototype, + ms_before=1.0, **job_kwargs, ) assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)