Skip to content

Commit

Permalink
Merge pull request #2601 from chrishalcrow/fix_template_similarity_di…
Browse files Browse the repository at this point in the history
…ff_units

Fix template similarity when there are multiple sizes of units
  • Loading branch information
alejoe91 authored Mar 19, 2024
2 parents e990d53 + cff7886 commit 9041249
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def compute_template_similarity(
similarity = tmc.get_data()
return similarity
else:
return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method)
return _compute_template_similarity(
waveform_extractor=waveform_extractor, waveform_extractor_other=waveform_extractor_other, method=method
)


def check_equal_template_with_distribution_overlap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite

from spikeinterface.core import extract_waveforms
from spikeinterface.extractors import toy_example
from spikeinterface.comparison import compare_templates

import numpy as np


class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase):
extension_class = TemplateSimilarityCalculator
Expand All @@ -21,8 +27,31 @@ def test_check_equal_template_with_distribution_overlap(self):
check_equal_template_with_distribution_overlap(waveforms0, waveforms1)


def test_compare_multiple_templates_different_units():

duration = 5
num_channels = 4

num_units_1 = 5
num_units_2 = 10

rec1, sort1 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_1)

rec2, sort2 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_2)

# compute waveforms
we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory")
we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory")

# paired comparison
temp_cmp = compare_templates(we1, we2)

assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2)


if __name__ == "__main__":
test = SimilarityExtensionTest()
test.setUp()
test.test_extension()
test.test_check_equal_template_with_distribution_overlap()
test_compare_multiple_templates_different_units()

0 comments on commit 9041249

Please sign in to comment.