diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 51064fafd2..53949d3a56 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -112,7 +112,9 @@ def compute_template_similarity( similarity = tmc.get_data() return similarity else: - return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method) + return _compute_template_similarity( + waveform_extractor=waveform_extractor, waveform_extractor_other=waveform_extractor_other, method=method + ) def check_equal_template_with_distribution_overlap( diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 210954bbc4..dd80c4aed8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -4,6 +4,12 @@ from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.core import extract_waveforms +from spikeinterface.extractors import toy_example +from spikeinterface.comparison import compare_templates + +import numpy as np + class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = TemplateSimilarityCalculator @@ -21,8 +27,31 @@ def test_check_equal_template_with_distribution_overlap(self): check_equal_template_with_distribution_overlap(waveforms0, waveforms1) +def test_compare_multiple_templates_different_units(): + + duration = 5 + num_channels = 4 + + num_units_1 = 5 + num_units_2 = 10 + + rec1, sort1 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_1) + + rec2, sort2 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_2) + + # compute waveforms + we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory") + we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory") + + # paired comparison + temp_cmp = compare_templates(we1, we2) + + assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2) + + if __name__ == "__main__": test = SimilarityExtensionTest() test.setUp() test.test_extension() test.test_check_equal_template_with_distribution_overlap() + test_compare_multiple_templates_different_units()