Skip to content

Commit

Permalink
Merge pull request #2905 from yger/fix_matched_filtering
Browse files Browse the repository at this point in the history
Fix some bugs in matched filtering
  • Loading branch information
samuelgarcia authored May 25, 2024
2 parents f478f26 + 59a1c78 commit 63a8442
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if params["matched_filtering"]:
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before

for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in detection_params:
Expand Down
87 changes: 56 additions & 31 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def __init__(
self,
recording,
prototype,
ms_before,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
Expand Down Expand Up @@ -642,6 +643,7 @@ def __init__(
raise NotImplementedError("Matched filtering not working with peak_sign=both yet!")

self.peak_sign = peak_sign
self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
contact_locations = recording.get_channel_locations()
dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2)
weights, self.z_factors = get_convolution_weights(dist, **weight_method)
Expand Down Expand Up @@ -687,17 +689,19 @@ def get_trace_margin(self):

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

# peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, temporal, spatial, singular, z_factors = self.args

assert HAVE_NUMBA, "You need to install numba"
conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular)
conv_traces /= self.abs_thresholds[:, None]
conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin]
traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size]

num_z_factors = len(self.z_factors)
num_channels = conv_traces.shape[0] // num_z_factors
num_templates = traces.shape[1]

traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1])
conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1])
peak_mask = traces_center > 1

peak_mask = _numba_detect_peak_matched_filtering(
conv_traces,
traces_center,
Expand All @@ -706,15 +710,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
self.abs_thresholds,
self.peak_sign,
self.neighbours_mask,
num_channels,
num_templates,
)

# Find peaks and correct for time shift
peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)

# If we do not want to estimate the z accurately
z = self.z_factors[peak_chan_ind // num_channels]
peak_chan_ind = peak_chan_ind % num_channels
z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)

# If we want to estimate z
# peak_chan_ind = peak_chan_ind % num_channels
Expand All @@ -728,15 +728,15 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
if peak_sample_ind.size == 0 or peak_chan_ind.size == 0:
return (np.zeros(0, dtype=self._dtype),)

peak_sample_ind += self.exclude_sweep_size + self.conv_margin
peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore

peak_amplitude = traces[peak_sample_ind, peak_chan_ind]
local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype)
local_peaks["sample_index"] = peak_sample_ind
local_peaks["channel_index"] = peak_chan_ind
local_peaks["amplitude"] = peak_amplitude
local_peaks["segment_index"] = segment_index
local_peaks["z"] = z
local_peaks["z"] = z_ind

# return is always a tuple
return (local_peaks,)
Expand All @@ -745,10 +745,11 @@ def get_convolved_traces(self, traces, temporal, spatial, singular):
import scipy.signal

num_timesteps, num_templates = len(traces), temporal.shape[1]
scalar_products = np.zeros((num_templates, num_timesteps), dtype=np.float32)
num_peaks = num_timesteps - self.conv_margin + 1
scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32)
spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :])
scaled_filtered_data = spatially_filtered_data * singular
objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="same")
objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid")
scalar_products += np.sum(objective_by_rank, axis=0)
return scalar_products

Expand Down Expand Up @@ -874,27 +875,51 @@ def _numba_detect_peak_neg(

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_matched_filtering(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates
):
num_chans = traces_center.shape[0]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[1]):
if not peak_mask[chan_ind, s]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind % num_channels, neighbour % num_channels]:
num_z = traces_center.shape[0]
for template_ind in range(num_templates):
for z in range(num_z):
for s in range(peak_mask.shape[2]):
if not peak_mask[z, template_ind, s]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[chan_ind, s] &= traces_center[chan_ind, s] >= traces_center[neighbour, s]
peak_mask[chan_ind, s] &= traces_center[chan_ind, s] > traces[neighbour, s + i]
peak_mask[chan_ind, s] &= (
traces_center[chan_ind, s] >= traces[neighbour, exclude_sweep_size + s + i + 1]
)
if not peak_mask[chan_ind, s]:
for neighbour in range(num_templates):
if not neighbours_mask[template_ind, neighbour]:
continue
for j in range(num_z):
for i in range(exclude_sweep_size):
if template_ind >= neighbour:
if z >= j:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
)
else:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
elif template_ind < neighbour:
if z > j:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
else:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
)
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s]
>= traces[j, neighbour, exclude_sweep_size + s + i + 1]
)
if not peak_mask[z, template_ind, s]:
break
if not peak_mask[z, template_ind, s]:
break
if not peak_mask[z, template_ind, s]:
break
if not peak_mask[chan_ind, s]:
break

return peak_mask


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)
detect_threshold=5,
exclude_sweep_ms=0.1,
prototype=prototype,
ms_before=1.0,
**job_kwargs,
)
assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)
Expand Down

0 comments on commit 63a8442

Please sign in to comment.