Skip to content

Commit

Permalink
Merge pull request #2727 from zm711/neg-spikes-v2
Browse files Browse the repository at this point in the history
Make sure `has_exceeding_spikes` also checks for negative spikes.
  • Loading branch information
alejoe91 authored Apr 17, 2024
2 parents fa57fee + db2e255 commit 33d478a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ def has_exceeding_spikes(recording, sorting):
if len(spike_vector_seg) > 0:
if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
return True
if spike_vector_seg["sample_index"][0] < 0:
return True
return False


Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/curation/remove_excess_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def _custom_cache_spike_vector(self) -> None:
for segment_index in range(num_segments):
spike_vector = parent_spike_vector[segments_bounds[segment_index] : segments_bounds[segment_index + 1]]
end = np.searchsorted(spike_vector["sample_index"], self._num_samples[segment_index])
list_spike_vectors.append(spike_vector[:end])
start = np.searchsorted(spike_vector["sample_index"], 0, side="left")
list_spike_vectors.append(spike_vector[start:end])

spike_vector = np.concatenate(list_spike_vectors)
self._cached_spike_vector = spike_vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_remove_excess_spikes():
times_segment = np.array([], dtype=int)
labels_segment = np.array([], dtype=int)
for unit in range(num_units):
neg_spike_times = np.random.randint(-50, 0, num_neg_spike_times_per_segment)
neg_spike_times = np.random.randint(-50, -1, num_neg_spike_times_per_segment)
spike_times = np.random.randint(0, num_samples, num_spikes)
last_samples_spikes = (num_samples - 1) * np.ones(num_num_samples_spikes_per_segment, dtype=int)
num_samples_spike_times = num_samples * np.ones(num_num_samples_spikes_per_segment, dtype=int)
Expand Down

0 comments on commit 33d478a

Please sign in to comment.