Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 19, 2024
1 parent 2eea2d5 commit cff7886
Showing 1 changed file with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np


class SimilarityExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase):
extension_class = TemplateSimilarityCalculator
extension_data_names = ["similarity"]
Expand All @@ -26,21 +27,18 @@ def test_check_equal_template_with_distribution_overlap(self):
check_equal_template_with_distribution_overlap(waveforms0, waveforms1)



def test_compare_multiple_templates_different_units():

duration = 5
num_channels = 4

num_units_1 = 5
num_units_2 = 10

rec1, sort1 = toy_example(duration=duration, num_segments=1,
num_channels=num_channels, num_units=num_units_1)

rec2, sort2 = toy_example(duration=duration, num_segments=1,
num_channels=num_channels, num_units=num_units_2)

rec1, sort1 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_1)

rec2, sort2 = toy_example(duration=duration, num_segments=1, num_channels=num_channels, num_units=num_units_2)

# compute waveforms
we1 = extract_waveforms(rec1, sort1, n_jobs=1, mode="memory")
we2 = extract_waveforms(rec2, sort2, n_jobs=1, mode="memory")
Expand All @@ -50,6 +48,7 @@ def test_compare_multiple_templates_different_units():

assert np.shape(temp_cmp.agreement_scores) == (num_units_1, num_units_2)


if __name__ == "__main__":
test = SimilarityExtensionTest()
test.setUp()
Expand Down

0 comments on commit cff7886

Please sign in to comment.