Commit
Merge pull request #2605 from chrishalcrow/sync_counts_update
compute_synchrony_metrics update
alejoe91 authored Apr 16, 2024
2 parents 3963101 + 16b344c commit fa57fee
Showing 5 changed files with 194 additions and 43 deletions.
19 changes: 19 additions & 0 deletions src/spikeinterface/core/tests/test_generate.py
@@ -19,6 +19,7 @@
generate_unit_locations,
generate_ground_truth_recording,
generate_sorting_to_inject,
synthesize_random_firings,
)

from spikeinterface.core.numpyextractors import NumpySorting
@@ -555,6 +556,24 @@ def test_generate_sorting_to_inject():
assert num_injected_spikes[unit_id] <= num_spikes[unit_id]


def test_synthesize_random_firings_length():

firing_rates = [2.0, 3.0]
duration = 2
num_units = 2

spike_times, spike_units = synthesize_random_firings(
num_units=num_units, duration=duration, firing_rates=firing_rates
)

assert len(spike_times) == int(np.sum(firing_rates) * duration)

units, counts = np.unique(spike_units, return_counts=True)

assert len(units) == num_units
assert np.sum(counts) == int(np.sum(firing_rates) * duration)


if __name__ == "__main__":
strategy = "tile_pregenerated"
# strategy = "on_the_fly"
1 change: 1 addition & 0 deletions src/spikeinterface/qualitymetrics/__init__.py
@@ -6,3 +6,4 @@
get_default_qm_params,
)
from .pca_metrics import get_quality_pca_metric_list
from .misc_metrics import get_synchrony_counts
112 changes: 73 additions & 39 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -496,7 +496,51 @@ def compute_sliding_rp_violations(
)


-def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
"""Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`
Parameters
----------
spikes : np.array
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
synchrony_sizes : numpy array
The synchrony sizes to compute. Should be pre-sorted.
unit_ids : list or None, default: None
List of unit ids to compute the synchrony metrics. Expecting all units.
Returns
-------
synchrony_counts : dict
The synchrony counts for the synchrony sizes.
References
----------
Based on concepts described in [Gruen]_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""

synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64)

    # compute the occurrence of each sample_index. A count >= 2 means there's synchrony
_, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True)

sync_indices = unique_spike_index[counts >= 2]
sync_counts = counts[counts >= 2]

for i, sync_index in enumerate(sync_indices):

num_of_syncs = sync_counts[i]
units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)]

# Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added
# to the 2 simultaneous spike bins.
how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1

return synchrony_counts
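
For illustration, a minimal usage sketch of the new helper on a hand-built spike vector (the dtype and field names follow the tests added below; the spike values themselves are made up):

import numpy as np
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.qualitymetrics import get_synchrony_counts

# three spikes: unit 0 fires alone at sample 10, units 0 and 1 fire together at sample 100
spikes = np.zeros(3, dtype=minimum_spike_dtype)
spikes["sample_index"] = [10, 100, 100]
spikes["unit_index"] = [0, 0, 1]

sync_counts = get_synchrony_counts(spikes, np.array([2, 4]), all_unit_ids=[0, 1])
# sync_counts[0] == [1, 1]: each unit takes part in one size-2 synchronous event
# sync_counts[1] == [0, 0]: there are no size-4 events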


def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
"""Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.
@@ -521,49 +565,39 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
"""
assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
-    sorting = sorting_analyzer.sorting
-    spikes = sorting.to_spike_vector(concatenated=False)
-
-    if unit_ids is None:
-        unit_ids = sorting_analyzer.unit_ids
-
-    # Pre-allocate synchrony counts
-    synchrony_counts = {}
-    for synchrony_size in synchrony_sizes:
-        synchrony_counts[synchrony_size] = np.zeros(len(sorting_analyzer.unit_ids), dtype=np.int64)
-
-    all_unit_ids = list(sorting.unit_ids)
-    for segment_index in range(sorting.get_num_segments()):
-        spikes_in_segment = spikes[segment_index]
-
-        # we compute just by counting the occurrence of each sample_index
-        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)
-
-        # add counts for this segment
-        for unit_id in unit_ids:
-            unit_index = all_unit_ids.index(unit_id)
-            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
-            # some segments/units might have no spikes
-            if len(spikes_per_unit) == 0:
-                continue
-            spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
-            for synchrony_size in synchrony_sizes:
-                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)
-
-    # add counts for this segment
-    synchrony_metrics_dict = {
-        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id]
-            for unit_id in unit_ids
-        }
-        for synchrony_size in synchrony_sizes
-    }
-
-    # Convert dict to named tuple
-    synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
-    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
-    return synchrony_metrics
+    # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
+    synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
+    synchrony_sizes_np.sort()
+
+    res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])
+
+    sorting = sorting_analyzer.sorting
+
+    spike_counts = sorting.count_num_spikes_per_unit(outputs="dict")
+
+    spikes = sorting.to_spike_vector()
+    all_unit_ids = sorting.unit_ids
+    synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
+
+    synchrony_metrics_dict = {}
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+        sync_id_metrics_dict = {}
+        for i, unit_id in enumerate(all_unit_ids):
+            if spike_counts[unit_id] != 0:
+                sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
+            else:
+                sync_id_metrics_dict[unit_id] = 0
+        synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict
+
+    if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
+        return res(**synchrony_metrics_dict)
+    else:
+        reduced_synchrony_metrics_dict = {}
+        for key in synchrony_metrics_dict:
+            reduced_synchrony_metrics_dict[key] = {
+                unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids
+            }
+        return res(**reduced_synchrony_metrics_dict)


_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
@@ -8,6 +8,7 @@
add_synchrony_to_sorting,
generate_ground_truth_recording,
create_sorting_analyzer,
synthesize_random_firings,
)

# from spikeinterface.extractors.toy_example import toy_example
@@ -35,8 +36,10 @@
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
get_synchrony_counts,
)

from spikeinterface.core.basesorting import minimum_spike_dtype

# if hasattr(pytest, "global_test_folder"):
# cache_folder = pytest.global_test_folder / "qualitymetrics"
@@ -129,6 +132,76 @@ def sorting_analyzer_violations():
return _sorting_analyzer_violations()


def test_synchrony_counts_no_sync():

spike_times, spike_units = synthesize_random_firings(num_units=1, duration=1, firing_rates=1.0)

one_spike = np.zeros(len(spike_times), minimum_spike_dtype)
one_spike["sample_index"] = spike_times
one_spike["unit_index"] = spike_units

sync_count = get_synchrony_counts(one_spike, np.array((2)), [0])

assert np.all(sync_count[0] == np.array([0]))


def test_synchrony_counts_one_sync():
# a spike train containing two synchronized spikes
spike_indices, spike_labels = synthesize_random_firings(
num_units=2,
duration=1,
firing_rates=1.0,
)

added_spikes_indices = [100, 100]
added_spikes_labels = [1, 0]

two_spikes = np.zeros(len(spike_indices) + 2, minimum_spike_dtype)
two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1])

assert np.all(sync_count[0] == np.array([1, 1]))


def test_synchrony_counts_one_quad_sync():
# a spike train containing four synchronized spikes
spike_indices, spike_labels = synthesize_random_firings(
num_units=4,
duration=1,
firing_rates=1.0,
)

added_spikes_indices = [100, 100, 100, 100]
added_spikes_labels = [0, 1, 2, 3]

four_spikes = np.zeros(len(spike_indices) + 4, minimum_spike_dtype)
four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3])

assert np.all(sync_count[0] == np.array([1, 1, 1, 1]))
assert np.all(sync_count[1] == np.array([1, 1, 1, 1]))


def test_synchrony_counts_not_all_units():
# a spike train containing two synchronized spikes
spike_indices, spike_labels = synthesize_random_firings(num_units=3, duration=1, firing_rates=1.0)

added_spikes_indices = [50, 100, 100]
added_spikes_labels = [0, 1, 2]

three_spikes = np.zeros(len(spike_indices) + 3, minimum_spike_dtype)
three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices))
three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels))

sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2])

assert np.all(sync_count[0] == np.array([0, 1, 1]))


def test_mahalanobis_metrics():
all_pcs1, all_labels1 = create_ground_truth_pc_distributions([1, -1], [1000, 1000])
all_pcs2, all_labels2 = create_ground_truth_pc_distributions(
@@ -358,6 +431,28 @@ def test_synchrony_metrics(sorting_analyzer_simple):
previous_sorting_analyzer = sorting_analyzer_sync


def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple):

unit_ids_subset = [3, 7]

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(
sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset
)

assert list(synchrony_metrics.keys()) == [3, 7]


def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple):

# all_unit_ids = sorting_analyzer_simple.sorting.unit_ids

synchrony_sizes = (2,)
(synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes)

assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids)


@pytest.mark.sortingcomponents
def test_calculate_drift_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
@@ -400,7 +495,9 @@ def test_calculate_sd_ratio(sorting_analyzer_simple):
# test_calculate_amplitude_median(sorting_analyzer)
# test_calculate_sliding_rp_violations(sorting_analyzer)
# test_calculate_drift_metrics(sorting_analyzer)
-    # test_synchrony_metrics(sorting_analyzer)
+    test_synchrony_metrics(sorting_analyzer)
+    test_synchrony_metrics_unit_id_subset(sorting_analyzer)
+    test_synchrony_metrics_no_unit_ids(sorting_analyzer)
# test_calculate_firing_range(sorting_analyzer)
# test_calculate_amplitude_cv_metrics(sorting_analyzer)
test_calculate_sd_ratio(sorting_analyzer)
@@ -5,6 +5,8 @@
import numpy as np
import shutil

from pandas import isnull

from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
@@ -117,8 +119,6 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):
sorting_analyzer_norec._recording = None
assert not sorting_analyzer_norec.has_recording()

-    print(sorting_analyzer_norec)

metrics_norec = compute_quality_metrics(
sorting_analyzer_norec,
metric_names=None,
@@ -161,7 +161,7 @@ def test_empty_units(sorting_analyzer_simple):
)

for empty_unit_id in sorting_empty.get_empty_unit_ids():
-        assert np.all(np.isnan(metrics_empty.loc[empty_unit_id]))
+        assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))


# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()