Make namedtuple and check div by zero
chrishalcrow committed Mar 20, 2024
1 parent 220add7 commit cb45927
Showing 2 changed files with 19 additions and 15 deletions.
29 changes: 16 additions & 13 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -569,6 +569,8 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
     synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int64)
     synchrony_sizes_np.sort()
 
+    res = namedtuple("synchrony", [f"sync_spike_{size}" for size in synchrony_sizes_np])
+
     sorting = sorting_analyzer.sorting
 
     spike_counts = sorting.count_num_spikes_per_unit(outputs="dict")
@@ -577,25 +579,26 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
     all_unit_ids = sorting.unit_ids
     synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
 
-    synchrony_metrics_dict = {
-        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[sync_idx][unit_id] / spike_counts[unit_id]
-            for i, unit_id in enumerate(all_unit_ids)
-        }
-        for sync_idx, synchrony_size in enumerate(synchrony_sizes_np)
-    }
-
-    if (unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
-        return synchrony_metrics_dict
+    synchrony_metrics_dict = {}
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+        synch_id_metrics_dict = {}
+        for i, unit_id in enumerate(all_unit_ids):
+            if spike_counts[unit_id] != 0:
+                synch_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
+            else:
+                synch_id_metrics_dict[unit_id] = 0
+        synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = synch_id_metrics_dict
+
+    if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
+        return res(**synchrony_metrics_dict)
     else:
         reduced_synchrony_metrics_dict = {}
         for key in synchrony_metrics_dict:
             reduced_synchrony_metrics_dict[key] = {
                 unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids
             }
-        return reduced_synchrony_metrics_dict
-
-
+        return res(**reduced_synchrony_metrics_dict)
+
 _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
 
 
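The change to compute_synchrony_metrics replaces the nested dict comprehension with an explicit loop, so units with zero spikes get a synchrony value of 0 instead of a division by zero, and the per-size dictionaries are wrapped in a namedtuple whose fields are named sync_spike_<size>. Below is a minimal, self-contained sketch of that pattern with made-up unit IDs and counts; it is not spikeinterface code.

# Sketch of the pattern above (hypothetical data, not the spikeinterface API).
from collections import namedtuple

import numpy as np

synchrony_sizes = np.array([2, 4, 8])           # hypothetical synchrony sizes
all_unit_ids = ["u0", "u1", "u2"]               # hypothetical unit ids
spike_counts = {"u0": 100, "u1": 0, "u2": 50}   # "u1" is an empty unit
# rows follow synchrony_sizes, columns follow all_unit_ids (hypothetical counts)
synchrony_counts = np.array([[10, 0, 5], [2, 0, 1], [0, 0, 0]])

res = namedtuple("synchrony", [f"sync_spike_{size}" for size in synchrony_sizes])

synchrony_metrics_dict = {}
for sync_idx, synchrony_size in enumerate(synchrony_sizes):
    per_unit = {}
    for i, unit_id in enumerate(all_unit_ids):
        if spike_counts[unit_id] != 0:
            per_unit[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
        else:
            per_unit[unit_id] = 0  # empty unit: avoid dividing by zero
    synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = per_unit

print(res(**synchrony_metrics_dict).sync_spike_2)  # {'u0': 0.1, 'u1': 0, 'u2': 0.1}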
@@ -5,6 +5,8 @@
 import numpy as np
 import shutil
 
+from pandas import isnull
+
 from spikeinterface.core import (
     generate_ground_truth_recording,
     create_sorting_analyzer,
@@ -117,7 +119,6 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):
     sorting_analyzer_norec._recording = None
     assert not sorting_analyzer_norec.has_recording()
 
-    print(sorting_analyzer_norec)
 
     metrics_norec = compute_quality_metrics(
         sorting_analyzer_norec,
@@ -161,7 +162,7 @@ def test_empty_units(sorting_analyzer_simple):
     )
 
     for empty_unit_id in sorting_empty.get_empty_unit_ids():
-        assert np.all(np.isnan(metrics_empty.loc[empty_unit_id]))
+        assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))
 
 
 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()
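In the test file, the empty-unit assertion switches from np.isnan to pandas.isnull (imported at the top of the file). A plausible motivation, sketched below with a made-up metrics row and hypothetical column names: np.isnan raises a TypeError on object-dtype data that mixes floats with None, while isnull flags NaN and None uniformly.

# Sketch: why pandas.isnull is safer than np.isnan here (hypothetical row).
import numpy as np
import pandas as pd
from pandas import isnull

row = pd.Series({"snr": np.nan, "num_spikes": None}, dtype=object)

print(np.all(isnull(row.values)))  # True: both entries count as missing
try:
    np.isnan(row.values)
except TypeError as err:
    print("np.isnan rejects object dtype:", err)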
