From 2eea2d59f67c96a4dd109ff87895866edd6f2417 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 19 Mar 2024 12:50:57 +0000 Subject: [PATCH 1/2] Fix template similarity when there are multiple sizes of units --- .../postprocessing/template_similarity.py | 4 ++- .../tests/test_template_similarity.py | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 51064fafd2..53949d3a56 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -112,7 +112,9 @@ def compute_template_similarity( similarity = tmc.get_data() return similarity else: - return _compute_template_similarity(waveform_extractor, waveform_extractor_other, method) + return _compute_template_similarity( + waveform_extractor=waveform_extractor, waveform_extractor_other=waveform_extractor_other, method=method + ) def check_equal_template_with_distribution_overlap( diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 210954bbc4..36ad991db9 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -4,6 +4,11 @@ from spikeinterface.postprocessing.tests.common_extension_tests import WaveformExtensionCommonTestSuite +from spikeinterface.core import extract_waveforms +from spikeinterface.extractors import toy_example +from spikeinterface.comparison import compare_templates + +import numpy as np class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = TemplateSimilarityCalculator @@ -21,8 +26,33 @@ def test_check_equal_template_with_distribution_overlap(self): check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + +def test_compare_multiple_templates_different_units(): + + duration = 5 + num_channels = 4 + + num_units_1 = 5 + num_units_2 = 10 + + rec1, sort1 = toy_example(duration=duration, num_segments=1, + num_channels=num_channels, num_units=num_units_1) + + rec2, sort2 = toy_example(duration=duration, num_segments=1, + num_channels=num_channels, num_units=num_units_2) + + # compute waveforms + we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory") + we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory") + + # paired comparison + temp_cmp = compare_templates(we1, we2) + + assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2) + if __name__ == "__main__": test = SimilarityExtensionTest() test.setUp() test.test_extension() test.test_check_equal_template_with_distribution_overlap() + test_compare_multiple_templates_different_units() From cff7886fb43ad2507b6ff6ffc905d9ce2519be76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Mar 2024 12:54:07 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests/test_template_similarity.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 36ad991db9..dd80c4aed8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -10,6 +10,7 @@ import numpy as np + class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = TemplateSimilarityCalculator extension_data_names = ["similarity"] @@ -26,21 +27,18 @@ def test_check_equal_template_with_distribution_overlap(self): check_equal_template_with_distribution_overlap(waveforms0, waveforms1) - def test_compare_multiple_templates_different_units(): - + duration = 5 num_channels = 4 num_units_1 = 5 num_units_2 = 10 - rec1, sort1 = toy_example(duration=duration, num_segments=1, - num_channels=num_channels, num_units=num_units_1) - - rec2, sort2 = toy_example(duration=duration, num_segments=1, - num_channels=num_channels, num_units=num_units_2) - + rec1, sort1 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_1) + + rec2, sort2 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_2) + # compute waveforms we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory") we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory") @@ -50,6 +48,7 @@ def test_compare_multiple_templates_different_units(): assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2) + if __name__ == "__main__": test = SimilarityExtensionTest() test.setUp()