# compute_synchrony_metrics update #2605
```diff
@@ -496,7 +496,51 @@ def compute_sliding_rp_violations(
     )


-def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
+def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
+    """Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`
+
+    Parameters
+    ----------
+    spikes : np.array
+        Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
+    synchrony_sizes : numpy array
```
**Review thread on `synchrony_sizes`:**

> **Reviewer:** Is this given in seconds, samples, or milliseconds?

> **Author:** I've made the description more accurate (since it's actually a structured numpy array). One of the fields is "sample_index", which clarifies the unit.

> **Reviewer:** I was talking about `synchrony_sizes`.

> **Author:** Oh, sorry. These are the sizes of the synchronous events you want to count. So if you want to see when two or four spikes fire at the same time, you use `synchrony_sizes = (2, 4)`. So it's an integer, and I think it will be clear to anyone familiar enough with the metric to be using it.

> **Reviewer:** Thanks. I guess my lack of familiarity with the metric is evident, then. I somehow imagined it was the window in which an event would be counted as synchronous. Thanks for explaining it.
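A toy numeric reading of that explanation (all numbers here are hypothetical, not from the PR): the resulting metrics are per-unit rates, i.e. synchronous-event counts divided by the unit's total spike count.

```python
# Hypothetical unit with 10 spikes: 3 coincide with at least one other spike
# at the exact same sample index, and 1 is part of a >= 4-spike event.
num_spikes = 10
sync_spike_2 = 3 / num_spikes  # 0.3
sync_spike_4 = 1 / num_spikes  # 0.1
```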
```diff
+        The synchrony sizes to compute. Should be pre-sorted.
+    unit_ids : list or None, default: None
+        List of unit ids to compute the synchrony metrics. Expecting all units.
+
+    Returns
+    -------
+    synchrony_counts : dict
+        The synchrony counts for the synchrony sizes.
+
+    References
+    ----------
+    Based on concepts described in [Gruen]_
+    This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
+    """
+
+    synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64)
```
**Review comment on this line:**

> **Reviewer:** With this, the synchrony is highly dependent on the sampling rate, no?
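A quick numeric illustration of the reviewer's point (the sampling rates are hypothetical examples): spikes only count as synchronous when they share an exact sample index, so the implicit coincidence window is one sample wide.

```python
# One sample at 30 kHz lasts ~0.033 ms; at 10 kHz it lasts 0.1 ms.
# The higher the sampling rate, the narrower the synchrony window.
for sampling_frequency_hz in (10_000, 30_000):
    window_ms = 1000 / sampling_frequency_hz
    print(f"{sampling_frequency_hz} Hz -> {window_ms:.3f} ms window")
```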
```diff
+
+    # compute the occurrence of each sample_index. Count >2 means there's synchrony
+    _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True)
+
+    sync_indices = unique_spike_index[counts >= 2]
+    sync_counts = counts[counts >= 2]
+
+    for i, sync_index in enumerate(sync_indices):
+
+        num_of_syncs = sync_counts[i]
+        units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)]
+
+        # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added
+        # to the 2 simultaneous spike bins.
+        how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
+        synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1
+
+    return synchrony_counts
+
+
+def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
     """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
     "synchrony_size" spikes at the exact same sample index.
```
```diff
@@ -521,49 +565,39 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
     """
     assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
-    sorting = sorting_analyzer.sorting
-    spikes = sorting.to_spike_vector(concatenated=False)
+    # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
+    synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
+    synchrony_sizes_np.sort()

-    if unit_ids is None:
-        unit_ids = sorting_analyzer.unit_ids
+    res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np])

-    # Pre-allocate synchrony counts
-    synchrony_counts = {}
-    for synchrony_size in synchrony_sizes:
-        synchrony_counts[synchrony_size] = np.zeros(len(sorting_analyzer.unit_ids), dtype=np.int64)
+    sorting = sorting_analyzer.sorting

-    all_unit_ids = list(sorting.unit_ids)
-    for segment_index in range(sorting.get_num_segments()):
-        spikes_in_segment = spikes[segment_index]
+    spike_counts = sorting.count_num_spikes_per_unit(outputs="dict")

-        # we compute just by counting the occurrence of each sample_index
-        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)
+    spikes = sorting.to_spike_vector()
+    all_unit_ids = sorting.unit_ids
+    synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)

-        # add counts for this segment
-        for unit_id in unit_ids:
-            unit_index = all_unit_ids.index(unit_id)
-            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
-            # some segments/units might have no spikes
-            if len(spikes_per_unit) == 0:
-                continue
-            spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
-            for synchrony_size in synchrony_sizes:
-                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)
+    synchrony_metrics_dict = {}
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+        sync_id_metrics_dict = {}
+        for i, unit_id in enumerate(all_unit_ids):
+            if spike_counts[unit_id] != 0:
+                sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
+            else:
+                sync_id_metrics_dict[unit_id] = 0
+        synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict

-    # add counts for this segment
-    synchrony_metrics_dict = {
-        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id]
-            for unit_id in unit_ids
-        }
-        for synchrony_size in synchrony_sizes
-    }
-
-    # Convert dict to named tuple
-    synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
-    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
-    return synchrony_metrics
+    if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
+        return res(**synchrony_metrics_dict)
+    else:
+        reduced_synchrony_metrics_dict = {}
+        for key in synchrony_metrics_dict:
+            reduced_synchrony_metrics_dict[key] = {
+                unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids
+            }
+        return res(**reduced_synchrony_metrics_dict)


 _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
```
**Review comment:**

> **Reviewer:** I would expect a delta in the signature, no?
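For context, a hedged sketch of how the updated metric might be invoked after this change. The setup assumes SpikeInterface's generator utilities from the same era (`generate_ground_truth_recording`, `create_sorting_analyzer`); it is not part of the PR itself.

```python
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.qualitymetrics import compute_synchrony_metrics

# Hypothetical setup: a small simulated recording/sorting pair.
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)
sorting_analyzer = create_sorting_analyzer(sorting, recording)

# New signature per the diff: **kwargs is gone; unit_ids optionally restricts output.
metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8))
print(metrics.sync_spike_2)  # {unit_id: fraction of spikes in >= 2-spike events}
```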