Skip to content

Commit

Permalink
Redo tests to use namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Mar 20, 2024
1 parent b91f5f2 commit bfbc5e3
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def test_synchrony_metrics(sorting_analyzer_simple):

# check returns
for size in synchrony_sizes:
assert f"sync_spike_{size}" in synchrony_metrics
assert f"sync_spike_{size}" in synchrony_metrics._fields

# here we test that increasing added synchrony is captured by syncrhony metrics
added_synchrony_levels = (0.2, 0.5, 0.8)
Expand All @@ -386,8 +386,13 @@ def test_synchrony_metrics(sorting_analyzer_simple):
current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes)
print(current_synchrony_metrics)
# check that all values increased
for syncs in previous_synchrony_metrics:
assert list(previous_synchrony_metrics[syncs].values()) < list(current_synchrony_metrics[syncs].values())
for i, col in enumerate(previous_synchrony_metrics._fields):
assert np.all(
v_prev < v_curr
for (v_prev, v_curr) in zip(
previous_synchrony_metrics[i].values(), current_synchrony_metrics[i].values()
)
)

# set new previous waveform extractor
previous_sorting_analyzer = sorting_analyzer_sync
Expand All @@ -398,21 +403,21 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):
unit_ids_subset = [3, 7]

synchrony_sizes = (2,)
synchrony_metrics = compute_synchrony_metrics(
synchrony_metrics, = compute_synchrony_metrics(
sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
)

assert list(synchrony_metrics["sync_spike_2"].keys()) == [3, 7]
assert list(synchrony_metrics.keys()) == [3, 7]


def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):

all_unit_ids = sorting_analyzer_simple.sorting.unit_ids
#all_unit_ids = sorting_analyzer_simple.sorting.unit_ids

synchrony_sizes = (2,)
synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)
synchrony_metrics, = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)

assert np.all(list(synchrony_metrics["sync_spike_2"].keys()) == sorting_analyzer_simple.unit_ids)
assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)


@pytest.mark.sortingcomponents
Expand Down

0 comments on commit bfbc5e3

Please sign in to comment.