diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 36ad991db9..dd80c4aed8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -10,6 +10,7 @@ import numpy as np + class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase): extension_class = TemplateSimilarityCalculator extension_data_names = ["similarity"] @@ -26,21 +27,18 @@ def test_check_equal_template_with_distribution_overlap(self): check_equal_template_with_distribution_overlap(waveforms0, waveforms1) - def test_compare_multiple_templates_different_units(): - + duration = 5 num_channels = 4 num_units_1 = 5 num_units_2 = 10 - rec1, sort1 = toy_example(duration=duration, num_segments=1, - num_channels=num_channels, num_units=num_units_1) - - rec2, sort2 = toy_example(duration=duration, num_segments=1, - num_channels=num_channels, num_units=num_units_2) - + rec1, sort1 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_1) + + rec2, sort2 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_2) + # compute waveforms we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory") we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory") @@ -50,6 +48,7 @@ def test_compare_multiple_templates_different_units(): assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2) + if __name__ == "__main__": test = SimilarityExtensionTest() test.setUp()