Skip to content

Commit

Permalink
Merge pull request #3208 from samuelgarcia/apply_curation
Browse files Browse the repository at this point in the history
Start apply_curation()
  • Loading branch information
samuelgarcia authored Jul 19, 2024
2 parents 63b295c + 0eacd1a commit 9e84a62
Show file tree
Hide file tree
Showing 9 changed files with 410 additions and 56 deletions.
12 changes: 9 additions & 3 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,21 @@ spikeinterface.curation
------------------------
.. automodule:: spikeinterface.curation

.. autoclass:: CurationSorting
.. autoclass:: MergeUnitsSorting
.. autoclass:: SplitUnitSorting
.. autofunction:: apply_curation
.. autofunction:: get_potential_auto_merge
.. autofunction:: find_redundant_units
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes

Deprecated
~~~~~~~~~~
.. automodule:: spikeinterface.curation

.. autofunction:: apply_sortingview_curation
.. autoclass:: CurationSorting
.. autoclass:: MergeUnitsSorting
.. autoclass:: SplitUnitSorting


spikeinterface.generation
Expand Down
22 changes: 17 additions & 5 deletions doc/modules/curation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,23 @@ format is the definition; the second part of the format is manual action):
}
.. note::
The curation format was recently introduced (v0.101.0), and we are still working on
properly integrating it into the SpikeInterface ecosystem.
Soon there will be functions available, in the curation module, to apply this
standardized curation format to ``SortingAnalyzer`` and a ``BaseSorting`` objects.
The curation format can be loaded into a dictionary and directly applied to
a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function.

.. code-block:: python
from spikeinterface.curation import apply_curation
# load the curation JSON file
curation_json = "path/to/curation.json"
with open(curation_json, 'r') as f:
curation_dict = json.load(f)
# apply the curation to the sorting output
clean_sorting = apply_curation(sorting, curation_dict=curation_dict)
# apply the curation to the sorting analyzer
clean_sorting_analyzer = apply_curation(sorting_analyzer, curation_dict=curation_dict)
Using the ``SpikeInterface GUI``
Expand Down
20 changes: 14 additions & 6 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def random_spikes_selection(


def apply_merges_to_sorting(
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append"
):
"""
Apply a resolved representation of the merges to a sorting object.
Expand All @@ -250,8 +250,8 @@ def apply_merges_to_sorting(
merged units will have the first unit_id of every lists of merges.
censor_ms: float | None, default: None
When applying the merges, should we discard consecutive spikes violating a given refractory period
return_kept : bool, default: False
If True, also return a boolean mask of kept spikes.
return_extra : bool, default: False
If True, also return a boolean mask of kept spikes and the new_unit_ids.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
Expand Down Expand Up @@ -316,8 +316,8 @@ def apply_merges_to_sorting(
spikes = spikes[keep_mask]
sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)

if return_kept:
return sorting, keep_mask
if return_extra:
return sorting, keep_mask, new_unit_ids
else:
return sorting

Expand Down Expand Up @@ -384,11 +384,13 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
new_unit_ids : list | None, default: None
Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`.
If None, new ids will be generated.
new_id_strategy : "append" | "take_first", default: "append"
new_id_strategy : "append" | "take_first" | "join", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorging.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
* "join" : new_unit_ids will join unit_ids of groups with a "-".
Only works if unit_ids are str otherwise switch to "append"
Returns
-------
Expand Down Expand Up @@ -423,6 +425,12 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
else:
# dtype int
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
elif new_id_strategy == "join":
if np.issubdtype(dtype, np.character):
new_unit_ids = ["-".join(group) for group in merge_unit_groups]
else:
# dtype int
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
raise ValueError("wrong new_id_strategy")

Expand Down
14 changes: 10 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,12 +732,12 @@ def _save_or_select_or_merge(
else:
from spikeinterface.core.sorting_tools import apply_merges_to_sorting

sorting_provenance, keep_mask = apply_merges_to_sorting(
sorting_provenance, keep_mask, _ = apply_merges_to_sorting(
sorting=sorting_provenance,
merge_unit_groups=merge_unit_groups,
new_unit_ids=new_unit_ids,
censor_ms=censor_ms,
return_kept=True,
return_extra=True,
)
if censor_ms is None:
# in this case having keep_mask None is faster instead of having a vector of ones
Expand Down Expand Up @@ -885,6 +885,7 @@ def merge_units(
merging_mode="soft",
sparsity_overlap=0.75,
new_id_strategy="append",
return_new_unit_ids=False,
format="memory",
folder=None,
verbose=False,
Expand Down Expand Up @@ -917,14 +918,15 @@ def merge_units(
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorting.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
return_new_unit_ids : bool, default: False
Also return new_unit_ids, which are the ids of the new units.
folder : Path | None, default: None
The new folder where the analyzer with merged units is copied for `format` "binary_folder" or "zarr"
format : "memory" | "binary_folder" | "zarr", default: "memory"
The format of SortingAnalyzer
verbose : bool, default: False
Whether to display calculations (such as sparsity estimation)
Returns
-------
analyzer : SortingAnalyzer
Expand Down Expand Up @@ -952,7 +954,7 @@ def merge_units(
)
all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)

return self._save_or_select_or_merge(
new_analyzer = self._save_or_select_or_merge(
format=format,
folder=folder,
merge_unit_groups=merge_unit_groups,
Expand All @@ -964,6 +966,10 @@ def merge_units(
new_unit_ids=new_unit_ids,
**job_kwargs,
)
if return_new_unit_ids:
return new_analyzer, new_unit_ids
else:
return new_analyzer

def copy(self):
"""
Expand Down
7 changes: 6 additions & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_apply_merges_to_sorting():
spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
)

sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
sorting3, keep_mask, _ = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_extra=True)
spikes3 = sorting3.to_spike_vector()
assert spikes3.size < spikes1.size
assert not keep_mask[1]
Expand Down Expand Up @@ -153,6 +153,11 @@ def test_generate_unit_ids_for_merge_group():
)
assert np.array_equal(new_unit_ids, ["0", "9"])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="join"
)
assert np.array_equal(new_unit_ids, ["0-5", "9-15"])


if __name__ == "__main__":
# test_spike_vector_to_spike_trains()
Expand Down
Loading

0 comments on commit 9e84a62

Please sign in to comment.