Skip to content

Commit

Permalink
adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed May 29, 2024
1 parent e1d87d0 commit 49ef51e
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,38 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)
)
assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

peaks_local_mf_filtering_both = detect_peaks(
recording,
method="matched_filtering",
peak_sign="both",
detect_threshold=5,
exclude_sweep_ms=0.1 ,
prototype=prototype,
ms_before=1.0,
**job_kwargs,
)
assert len(peaks_local_mf_filtering_both) > len(peaks_local_mf_filtering)

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

peaks = peaks_local_mf_filtering
peaks_local = peaks_by_channel_np
peaks_mf_neg = peaks_local_mf_filtering
peaks_mf_both = peaks_local_mf_filtering_both
labels = ['locally_exclusive', 'mf_neg', 'mf_both']

sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
fig, ax = plt.subplots()
chan_offset = 500
traces = recording.get_traces().copy()
traces += np.arange(traces.shape[1])[None, :] * chan_offset
fig, ax = plt.subplots()
ax.plot(traces, color="k")
ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r")

for count, peaks in enumerate([peaks_local, peaks_mf_neg, peaks_mf_both]):
sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count])

ax.legend()
plt.show()


Expand Down

0 comments on commit 49ef51e

Please sign in to comment.