From daf751057e021eb86d8d1a5906127d0d4891263f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 15 Jul 2024 20:55:42 +0200 Subject: [PATCH 01/10] Start apply_curation() --- .../curation/curation_format.py | 141 ++++++++++++++++-- .../curation/tests/test_curation_format.py | 24 ++- 2 files changed, 149 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index fc75f74399..770a1052bd 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,5 +1,8 @@ from itertools import combinations +import numpy as np + +from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting supported_curation_format_versions = {"1"} @@ -119,9 +122,9 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo return curation_dict -def curation_label_to_dataframe(curation_dict): +def curation_label_to_vectors(curation_dict): """ - Transform the curation dict into a pandas dataframe. + Transform the curation dict into dict of vectors. For label category with exclusive=True : a column is created and values are the unique label. For label category with exclusive=False : one column per possible is created and values are boolean. @@ -134,30 +137,140 @@ def curation_label_to_dataframe(curation_dict): Returns ------- - labels : pd.DataFrame - dataframe with labels. + labels: dict of numpy vector + """ - import pandas as pd + unit_ids = list(curation_dict["unit_ids"]) + n = len(unit_ids) - labels = pd.DataFrame(index=curation_dict["unit_ids"]) + labels = {} for label_key, label_def in curation_dict["label_definitions"].items(): if label_def["exclusive"]: - assert label_key not in labels.columns, f"{label_key} is already a column" - labels[label_key] = pd.Series(dtype=str) - labels[label_key][:] = "" + assert label_key not in labels, f"{label_key} is already a key" + labels[label_key] = [""] * n for lbl in curation_dict["manual_labels"]: value = lbl.get(label_key, []) if len(value) == 1: - labels.at[lbl["unit_id"], label_key] = value[0] + unit_index = unit_ids.index(lbl["unit_id"]) + labels[label_key][unit_index] = value[0] + labels[label_key] = np.array(labels[label_key]) else: for label_opt in label_def["label_options"]: - assert label_opt not in labels.columns, f"{label_opt} is already a column" - labels[label_opt] = pd.Series(dtype=bool) - labels[label_opt][:] = False + assert label_opt not in labels, f"{label_opt} is already a key" + labels[label_opt] = np.zeros(n, dtype=bool) for lbl in curation_dict["manual_labels"]: values = lbl.get(label_key, []) for value in values: - labels.at[lbl["unit_id"], value] = True + unit_index = unit_ids.index(lbl["unit_id"]) + labels[value][unit_index] = True + + return labels + + +def curation_label_to_dataframe(curation_dict): + """ + Transform the curation dict into a pandas dataframe. + For label category with exclusive=True : a column is created and values are the unique label. + For label category with exclusive=False : one column per possible is created and values are boolean. + + If exclusive=False and the same label appear several times then it raises an error. + + Parameters + ---------- + curation_dict : dict + A curation dictionary + Returns + ------- + labels : pd.DataFrame + dataframe with labels. + """ + import pandas as pd + labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"]) return labels + + + +def apply_curation_labels(sorting, curation_dict): + labels = curation_label_to_vectors(curation_dict) + unit_ids = np.asarray(curation_dict["unit_ids"]) + mask = np.isin(unit_ids, sorting.unit_ids) + for key, values in labels.items(): + sorting.set_property(key, values[mask], unit_ids[mask]) + + +def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_strategy="append", + merging_mode="soft", sparsity_overlap=0.75, verbose=False, + **job_kwargs): + """ + Apply curation dict to a Sorting or an SortingAnalyzer. + + Steps are done this order: + 1. Apply removal using curation_dict["removed_units"] + 2. Apply merges using curation_dict["merge_unit_groups"] + 3. Set labels using curation_dict["manual_labels"] + + A new Sorting or SortingAnalyzer (in memory) is returned. + The user (an adult) has the responsability to save it somewhere (or not). + + Parameters + ---------- + sorting_or_analyzer : Sorting | SortingAnalyzer + The Sorting object to apply merges. + curation_dict : dict + The curation dict. + censor_ms: float | None, default: None + When applying the merges, should be discard consecutive spikes violating a given refractory per + new_id_strategy : "append" | "take_first", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. + + * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + * "take_first" : new_unit_ids will be the first unit_id of every list of merges + merging_mode : "soft" | "hard" + Used for SortingAnalyzer + sparsity_overlap: + + verbose: + + **job_kwargs + + Returns + ------- + sorting_or_analyzer : Sorting | SortingAnalyzer + The curated object. + + + """ + validate_curation_dict(curation_dict) + if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): + raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") + + + if isinstance(sorting_or_analyzer, BaseSorting): + sorting = sorting_or_analyzer + sorting = sorting.remove_units(curation_dict["removed_units"]) + sorting = apply_merges_to_sorting(sorting, curation_dict["merge_unit_groups"], + censor_ms=censor_ms, return_kept=False, new_id_strategy=new_id_strategy) + apply_curation_labels(sorting, curation_dict) + return sorting + + elif isinstance(sorting_or_analyzer, SortingAnalyzer): + analyzer = sorting_or_analyzer + analyzer = analyzer.remove_units(curation_dict["removed_units"]) + analyzer = analyzer.merge_units( + curation_dict["merge_unit_groups"], + censor_ms=censor_ms, + merging_mode=merging_mode, + sparsity_overlap=sparsity_overlap, + new_id_strategy=new_id_strategy, + format="memory", + verbose=verbose, + **job_kwargs, + ) + apply_curation_labels(analyzer.sorting, curation_dict) + return analyzer + else: + raise ValueError("apply_curation() must have a Sorting or a SortingAnalyzer") + + diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 94812ee0aa..b942de27c0 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -6,6 +6,7 @@ from spikeinterface.curation.curation_format import ( validate_curation_dict, convert_from_sortingview_curation_format_v0, + curation_label_to_vectors, curation_label_to_dataframe, ) @@ -141,6 +142,20 @@ def test_convert_from_sortingview_curation_format_v0(): validate_curation_dict(curation_v1) + +def test_curation_label_to_vectors(): + + labels = curation_label_to_vectors(curation_ids_int) + assert "quality" in labels + assert "excitatory" in labels + print(labels) + + labels = curation_label_to_vectors(curation_ids_str) + print(labels) + + + + def test_curation_label_to_dataframe(): df = curation_label_to_dataframe(curation_ids_int) @@ -152,10 +167,15 @@ def test_curation_label_to_dataframe(): # print(df) +def test_apply_curation(): + pass + # TODO + if __name__ == "__main__": # test_curation_format_validation() # test_to_from_json() # test_convert_from_sortingview_curation_format_v0() - # test_curation_label_to_dataframe() + # test_curation_label_to_vectors() + test_curation_label_to_dataframe() - print(json.dumps(curation_ids_str, indent=4)) + # print(json.dumps(curation_ids_str, indent=4)) From b0ea548fa5030878b1b7951f3dcce5994d12f99a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jul 2024 08:10:05 +0200 Subject: [PATCH 02/10] test_apply_curation --- .../curation/curation_format.py | 11 ++++++--- .../curation/tests/test_curation_format.py | 23 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 770a1052bd..aa57923078 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -194,10 +194,15 @@ def curation_label_to_dataframe(curation_dict): def apply_curation_labels(sorting, curation_dict): labels = curation_label_to_vectors(curation_dict) - unit_ids = np.asarray(curation_dict["unit_ids"]) - mask = np.isin(unit_ids, sorting.unit_ids) + # unit_ids = np.asarray(curation_dict["unit_ids"]) + # mask = np.isin(unit_ids, sorting.unit_ids) for key, values in labels.items(): - sorting.set_property(key, values[mask], unit_ids[mask]) + all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) + for unit_ind, unit_id in enumerate(sorting.unit_ids): + if unit_id in curation_dict["unit_ids"]: + ind = curation_dict["unit_ids"].index(unit_id) + all_values[unit_ind] = values[ind] + sorting.set_property(key, all_values) def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_strategy="append", diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index b942de27c0..842994544d 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -2,12 +2,16 @@ from pathlib import Path import json +import numpy as np + +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface.curation.curation_format import ( validate_curation_dict, convert_from_sortingview_curation_format_v0, curation_label_to_vectors, curation_label_to_dataframe, + apply_curation ) @@ -168,14 +172,25 @@ def test_curation_label_to_dataframe(): def test_apply_curation(): - pass - # TODO + recording, sorting = generate_ground_truth_recording(durations=[10.], num_units=9, seed=2205) + sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + sorting_curated = apply_curation(sorting, curation_ids_int) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" + assert sorting_curated.get_property("excitatory", ids=[2])[0] + + analyzer_curated = apply_curation(analyzer, curation_ids_int) + assert "quality" in analyzer_curated.sorting.get_property_keys() + if __name__ == "__main__": # test_curation_format_validation() # test_to_from_json() # test_convert_from_sortingview_curation_format_v0() # test_curation_label_to_vectors() - test_curation_label_to_dataframe() + # test_curation_label_to_dataframe() + + test_apply_curation() - # print(json.dumps(curation_ids_str, indent=4)) From 9b1f59650120a08ca91034b3aded30019869039b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 16 Jul 2024 20:07:44 +0200 Subject: [PATCH 03/10] =?UTF-8?q?apply=5Fsortingview=5Fcuration(=C3=A0=20u?= =?UTF-8?q?sing=20new=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../curation/sortingview_curation.py | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index c4d2a32958..55cca6a558 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -4,11 +4,44 @@ from pathlib import Path from .curationsorting import CurationSorting +from .curation_format import convert_from_sortingview_curation_format_v0, apply_curation -# @alessio -# TODO later : this should be reimplemented using the new curation format -def apply_sortingview_curation( +# TODO discussion on this +def apply_sortingview_curation(sorting, uri_or_json): + + # download + if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): + with open(uri_or_json, "r") as f: + curation_dict = json.load(f) + else: + try: + import kachery_cloud as kcl + except ImportError: + raise ImportError( + "To apply a SortingView manual curation, you need to have sortingview installed: " + ">>> pip install sortingview" + ) + + try: + curation_dict = kcl.load_json(uri=uri_or_json) + except: + raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") + + # convert to new format + if "format_version" not in curation_dict: + curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) + + # apply + sorting_curated = apply_curation(sorting, curation_dict) + + return sorting_curated + + + + +# TODO discussion do we keep this ??? +def apply_sortingview_curation_legacy( sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False ): """ From 49360e89c25844b7e8a2895a02213f06c826695c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 17 Jul 2024 12:42:12 +0200 Subject: [PATCH 04/10] improve back compatibility with apply_sortingview_curation() --- src/spikeinterface/core/sorting_tools.py | 21 +++-- src/spikeinterface/core/sortinganalyzer.py | 15 +++- .../core/tests/test_sorting_tools.py | 9 +- .../curation/curation_format.py | 84 ++++++++++++++++--- .../curation/sortingview_curation.py | 79 +++++++++++++++-- .../tests/test_sortingview_curation.py | 32 +++---- 6 files changed, 193 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 2a2f7b6b5a..b5278a88bb 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -227,7 +227,8 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" + sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, + new_id_strategy="append" ): """ Apply a resolved representation of the merges to a sorting object. @@ -250,8 +251,8 @@ def apply_merges_to_sorting( merged units will have the first unit_id of every lists of merges. censor_ms: float | None, default: None When applying the merges, should be discard consecutive spikes violating a given refractory per - return_kept : bool, default: False - If True, also return also a boolean mask of kept spikes. + return_extra : bool, default: False + If True, also return also a boolean mask of kept spikes and new_unit_ids. new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. @@ -316,8 +317,8 @@ def apply_merges_to_sorting( spikes = spikes[keep_mask] sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) - if return_kept: - return sorting, keep_mask + if return_extra: + return sorting, keep_mask, new_unit_ids else: return sorting @@ -380,11 +381,13 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ new_unit_ids : list | None, default: None Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. If None, new ids will be generated. - new_id_strategy : "append" | "take_first", default: "append" + new_id_strategy : "append" | "take_first" | "join", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + * "join" : new_unit_ids will join unit_ids of groups with a "-". + Only works if unit_ids are str otherwise switch to "append" Returns ------- @@ -419,6 +422,12 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ else: # dtype int new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) + elif new_id_strategy == "join": + if np.issubdtype(dtype, np.character): + new_unit_ids = ["-".join(group) for group in merge_unit_groups] + else: + # dtype int + new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: raise ValueError("wrong new_id_strategy") diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 27a47a31ac..5bee8f7540 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -732,12 +732,12 @@ def _save_or_select_or_merge( else: from spikeinterface.core.sorting_tools import apply_merges_to_sorting - sorting_provenance, keep_mask = apply_merges_to_sorting( + sorting_provenance, keep_mask, _ = apply_merges_to_sorting( sorting=sorting_provenance, merge_unit_groups=merge_unit_groups, new_unit_ids=new_unit_ids, censor_ms=censor_ms, - return_kept=True, + return_extra=True, ) if censor_ms is None: # in this case having keep_mask None is faster instead of having a vector of ones @@ -879,6 +879,7 @@ def merge_units( merging_mode="soft", sparsity_overlap=0.75, new_id_strategy="append", + return_new_unit_ids=False, format="memory", folder=None, verbose=False, @@ -911,6 +912,8 @@ def merge_units( The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + return_new_unit_ids : bool, default False + Alse return new_unit_ids which are the ids of the new units. folder : Path or None The new folder where selected waveforms are copied format : "auto" | "binary_folder" | "zarr" @@ -945,7 +948,7 @@ def merge_units( ) all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids) - return self._save_or_select_or_merge( + new_analyzer = self._save_or_select_or_merge( format=format, folder=folder, merge_unit_groups=merge_unit_groups, @@ -957,6 +960,12 @@ def merge_units( new_unit_ids=new_unit_ids, **job_kwargs, ) + if return_new_unit_ids: + return new_analyzer, new_unit_ids + else: + return new_analyzer + + def copy(self): """ diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 38baf62c35..944ba61827 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -96,7 +96,7 @@ def test_apply_merges_to_sorting(): spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"] ) - sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True) + sorting3, keep_mask, _ = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_extra=True) spikes3 = sorting3.to_spike_vector() assert spikes3.size < spikes1.size assert not keep_mask[1] @@ -153,6 +153,13 @@ def test_generate_unit_ids_for_merge_group(): ) assert np.array_equal(new_unit_ids, ["0", "9"]) + new_unit_ids = generate_unit_ids_for_merge_group( + ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="join" + ) + assert np.array_equal(new_unit_ids, ["0-5", "9-15"]) + + + if __name__ == "__main__": # test_spike_vector_to_spike_trains() diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index aa57923078..2c79f379fa 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -3,6 +3,7 @@ import numpy as np from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting +import copy supported_curation_format_versions = {"1"} @@ -168,6 +169,31 @@ def curation_label_to_vectors(curation_dict): return labels +def clean_curation_dict(curation_dict): + """ + In some cases the curation_dict can have inconsistencies (like in the sorting view format). + For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. + This is ambiguous! + + This cleaner helper units tagged as removed will be revmove from merges lists. + """ + curation_dict = copy.deepcopy(curation_dict) + + clean_merge_unit_groups = [] + for group in curation_dict["merge_unit_groups"]: + clean_group = [] + for unit_id in group: + if unit_id not in curation_dict["removed_units"]: + clean_group.append(unit_id) + if len(clean_group) > 1: + clean_merge_unit_groups.append(clean_group) + + curation_dict["merge_unit_groups"] = clean_merge_unit_groups + return curation_dict + + + + def curation_label_to_dataframe(curation_dict): """ Transform the curation dict into a pandas dataframe. @@ -192,17 +218,51 @@ def curation_label_to_dataframe(curation_dict): -def apply_curation_labels(sorting, curation_dict): - labels = curation_label_to_vectors(curation_dict) - # unit_ids = np.asarray(curation_dict["unit_ids"]) - # mask = np.isin(unit_ids, sorting.unit_ids) - for key, values in labels.items(): +def apply_curation_labels(sorting, new_unit_ids, curation_dict): + """ + Apply manual labels after merges. + + Rules: + * label for non merge is applied first + * for merged group, when exclusive=True, if all have the same label then this label is applied + * for merged group, when exclusive=False, if one unit has the label then the new one have also it + """ + + # Please note that manual_labels is done on the unit_ids before the merge!!! + manual_labels = curation_label_to_vectors(curation_dict) + + # apply on non merged + for key, values in manual_labels.items(): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): - if unit_id in curation_dict["unit_ids"]: + if unit_id not in new_unit_ids: ind = curation_dict["unit_ids"].index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) + + for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]): + for label_key, label_def in curation_dict["label_definitions"].items(): + if label_def["exclusive"]: + group_values = [] + for unit_id in old_group_ids: + ind = curation_dict["unit_ids"].index(unit_id) + value = manual_labels[label_key][ind] + if value != '': + group_values.append(value) + if len(set(group_values)) == 1: + # all group has the same label or empty + sorting.set_property(key, values=group_values, ids=[new_unit_id]) + else: + + for key in label_def["label_options"]: + group_values = [] + for unit_id in old_group_ids: + ind = curation_dict["unit_ids"].index(unit_id) + value = manual_labels[key][ind] + group_values.append(value) + new_value = np.any(group_values) + sorting.set_property(key, values=[new_value], ids=[new_unit_id]) + def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_strategy="append", @@ -255,25 +315,27 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer sorting = sorting.remove_units(curation_dict["removed_units"]) - sorting = apply_merges_to_sorting(sorting, curation_dict["merge_unit_groups"], - censor_ms=censor_ms, return_kept=False, new_id_strategy=new_id_strategy) - apply_curation_labels(sorting, curation_dict) + sorting, _, new_unit_ids = apply_merges_to_sorting(sorting, curation_dict["merge_unit_groups"], + censor_ms=censor_ms, return_extra=True, + new_id_strategy=new_id_strategy) + apply_curation_labels(sorting, new_unit_ids, curation_dict) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer analyzer = analyzer.remove_units(curation_dict["removed_units"]) - analyzer = analyzer.merge_units( + analyzer, new_unit_ids = analyzer.merge_units( curation_dict["merge_unit_groups"], censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, new_id_strategy=new_id_strategy, + return_new_unit_ids=True, format="memory", verbose=verbose, **job_kwargs, ) - apply_curation_labels(analyzer.sorting, curation_dict) + apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: raise ValueError("apply_curation() must have a Sorting or a SortingAnalyzer") diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 55cca6a558..22373a882c 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -1,15 +1,51 @@ from __future__ import annotations + +import warnings + import json import numpy as np from pathlib import Path from .curationsorting import CurationSorting -from .curation_format import convert_from_sortingview_curation_format_v0, apply_curation +from .curation_format import convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, clean_curation_dict + + +def apply_sortingview_curation( + sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None + ): + """ + Apply curation from SortingView manual legacy curation format (before the official "curation_format") + First, merges (if present) are applied. Then labels are loaded and units + are optionally filtered based on exclude_labels and include_labels. + + Parameters + ---------- + sorting_or_analyzer : Sorting | SortingAnalyzer + The sorting or analyzer to be curated + uri_or_json : str or Path + The URI curation link from SortingView or the path to the curation json file + exclude_labels : list, default: None + Optional list of labels to exclude (e.g. ["reject", "noise"]). + Mutually exclusive with include_labels + include_labels : list, default: None + Optional list of labels to include (e.g. ["accept"]). + Mutually exclusive with exclude_labels, by default None + skip_merge : bool, default: False + If True, merges are not applied (only labels) + verbose : None + Deprecated + + + Returns + ------- + sorting_or_analyzer_curated : BaseSorting + The curated sorting or analyzer + """ + + if verbose is not None: + warnings.warn("versobe in apply_sortingview_curation() is deprecated") -# TODO discussion on this -def apply_sortingview_curation(sorting, uri_or_json): - # download if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): with open(uri_or_json, "r") as f: @@ -32,15 +68,46 @@ def apply_sortingview_curation(sorting, uri_or_json): if "format_version" not in curation_dict: curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) + + unit_ids = sorting_or_analyzer.unit_ids + + # this is a hack because it was not in the old format + curation_dict["unit_ids"] = list(unit_ids) + + if exclude_labels is not None: + assert include_labels is None, "Use either `include_labels` or `exclude_labels` to filter units." + manual_labels = curation_label_to_vectors(curation_dict) + removed_units = [] + for k in exclude_labels: + remove_mask = manual_labels[k] + removed_units.extend(unit_ids[remove_mask]) + removed_units = np.unique(removed_units) + curation_dict["removed_units"] = removed_units + + if include_labels is not None: + manual_labels = curation_label_to_vectors(curation_dict) + removed_units = [] + for k in include_labels: + remove_mask = ~manual_labels[k] + removed_units.extend(unit_ids[remove_mask]) + removed_units = np.unique(removed_units) + curation_dict["removed_units"] = removed_units + + if skip_merge: + curation_dict["merge_unit_groups"] = [] + + # cleaner to ensure validity + curation_dict = clean_curation_dict(curation_dict) + # apply - sorting_curated = apply_curation(sorting, curation_dict) + sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") return sorting_curated -# TODO discussion do we keep this ??? +# TODO @alessio you remove this after testing def apply_sortingview_curation_legacy( sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False ): diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 00721ff34d..bb152e7f71 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -52,7 +52,7 @@ def test_gh_curation(): # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" - sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) + sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri) assert len(sorting_curated_gh.unit_ids) == 9 assert 1, 2 in sorting_curated_gh.unit_ids @@ -81,7 +81,7 @@ def test_sha1_curation(): # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 sha1_uri = "sha1://449a428e8824eef9ad9bcc3241e45a2cee02d381" - sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) + sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri) # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 @@ -109,7 +109,7 @@ def test_json_curation(): # from curation.json json_file = parent_folder / "sv-sorting-curation.json" # print(f"Sorting: {sorting.get_unit_ids()}") - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file) assert len(sorting_curated_json.unit_ids) == 9 assert 1, 2 in sorting_curated_json.unit_ids @@ -146,7 +146,7 @@ def test_false_positive_curation(): # print("Sorting: {}".format(sorting.get_unit_ids())) json_file = parent_folder / "sv-sorting-curation-false-positive.json" - sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file) # print("Curated:", sorting_curated_json.get_unit_ids()) # Assertions @@ -190,17 +190,13 @@ def test_label_inheritance_int(): assert not sorting_merge.get_unit_property(unit_id=10, key="noise") assert sorting_merge.get_unit_property(unit_id=10, key="accept") - # Assertions for exclude_labels + # Assertions for exclude_labels should all be False sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") - assert 9 not in sorting_exclude_noise.get_unit_ids() + assert np.all(~sorting_exclude_noise.get_property("noise")) # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") - assert 8 not in sorting_include_accept.get_unit_ids() - assert 9 not in sorting_include_accept.get_unit_ids() - assert 10 in sorting_include_accept.get_unit_ids() + assert np.all(sorting_include_accept.get_property("accept")) def test_label_inheritance_str(): @@ -219,7 +215,7 @@ def test_label_inheritance_str(): # Apply curation json_file = parent_folder / "sv-sorting-curation-str.json" - sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) # Assertions for merged units # print(f"Merge only: {sorting_merge.get_unit_ids()}") @@ -238,22 +234,18 @@ def test_label_inheritance_str(): assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") - # Assertions for exclude_labels + # Assertions for exclude_labels should all be False sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) - # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") - assert "c-d" not in sorting_exclude_noise.get_unit_ids() + assert np.all(~sorting_exclude_noise.get_property("noise")) # Assertions for include_labels sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) - # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") - assert "a-b" not in sorting_include_accept.get_unit_ids() - assert "c-d" not in sorting_include_accept.get_unit_ids() - assert "e-f" in sorting_include_accept.get_unit_ids() + assert np.all(sorting_include_accept.get_property("accept")) if __name__ == "__main__": # generate_sortingview_curation_dataset() - test_sha1_curation() + # test_sha1_curation() test_gh_curation() test_json_curation() test_false_positive_curation() From 096caa8f4125c35b67ef556641f60aa8a7b49c69 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 19 Jul 2024 12:41:05 +0200 Subject: [PATCH 05/10] Merci Zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- .../curation/curation_format.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 2c79f379fa..61e6591467 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -175,7 +175,7 @@ def clean_curation_dict(curation_dict): For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. This is ambiguous! - This cleaner helper units tagged as removed will be revmove from merges lists. + This cleaner helper function ensures units tagged as `removed_units` are removed from the `merge_unit_groups` """ curation_dict = copy.deepcopy(curation_dict) @@ -200,7 +200,7 @@ def curation_label_to_dataframe(curation_dict): For label category with exclusive=True : a column is created and values are the unique label. For label category with exclusive=False : one column per possible is created and values are boolean. - If exclusive=False and the same label appear several times then it raises an error. + If exclusive=False and the same label appears several times then an error is raised. Parameters ---------- @@ -269,9 +269,9 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st merging_mode="soft", sparsity_overlap=0.75, verbose=False, **job_kwargs): """ - Apply curation dict to a Sorting or an SortingAnalyzer. + Apply curation dict to a Sorting or a SortingAnalyzer. - Steps are done this order: + Steps are done in this order: 1. Apply removal using curation_dict["removed_units"] 2. Apply merges using curation_dict["merge_unit_groups"] 3. Set labels using curation_dict["manual_labels"] @@ -286,15 +286,20 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st curation_dict : dict The curation dict. censor_ms: float | None, default: None - When applying the merges, should be discard consecutive spikes violating a given refractory per + When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of + as the desired refractory period. If `censor_ms=None`, no spikes are discarded. new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. - * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges - merging_mode : "soft" | "hard" - Used for SortingAnalyzer - sparsity_overlap: + merging_mode : "soft" | "hard", default: "soft" + How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of + the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately + performed, reloading waveforms if needed + sparsity_overlap : float, default 0.75 + The percentage of overlap that units should share in order to accept merges. If this criteria is not + achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. verbose: @@ -338,6 +343,6 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: - raise ValueError("apply_curation() must have a Sorting or a SortingAnalyzer") + raise TypeError(f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)") From 52953ee2bf4b0e542257c0957861837bb103f511 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:43:25 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 3 +-- src/spikeinterface/core/sortinganalyzer.py | 4 +--- .../core/tests/test_sorting_tools.py | 2 -- .../curation/curation_format.py | 16 ++++++------- .../curation/sortingview_curation.py | 24 ++++++++++--------- .../curation/tests/test_curation_format.py | 8 ++----- 6 files changed, 24 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 1798e77b2c..5f33350820 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -227,8 +227,7 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, - new_id_strategy="append" + sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append" ): """ Apply a resolved representation of the merges to a sorting object. diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0ee4e46031..ac142405ab 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -954,7 +954,7 @@ def merge_units( ) all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids) - new_analyzer = self._save_or_select_or_merge( + new_analyzer = self._save_or_select_or_merge( format=format, folder=folder, merge_unit_groups=merge_unit_groups, @@ -970,8 +970,6 @@ def merge_units( return new_analyzer, new_unit_ids else: return new_analyzer - - def copy(self): """ diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 0e7285eec7..34bb3a221d 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -159,8 +159,6 @@ def test_generate_unit_ids_for_merge_group(): assert np.array_equal(new_unit_ids, ["0-5", "9-15"]) - - if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 61e6591467..831a27a868 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -187,7 +187,7 @@ def clean_curation_dict(curation_dict): clean_group.append(unit_id) if len(clean_group) > 1: clean_merge_unit_groups.append(clean_group) - + curation_dict["merge_unit_groups"] = clean_merge_unit_groups return curation_dict @@ -239,7 +239,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): ind = curation_dict["unit_ids"].index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - + for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]): for label_key, label_def in curation_dict["label_definitions"].items(): if label_def["exclusive"]: @@ -287,21 +287,21 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st The curation dict. censor_ms: float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of - as the desired refractory period. If `censor_ms=None`, no spikes are discarded. + as the desired refractory period. If `censor_ms=None`, no spikes are discarded. new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges merging_mode : "soft" | "hard", default: "soft" - How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of - the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately + How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of + the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately performed, reloading waveforms if needed sparsity_overlap : float, default 0.75 The percentage of overlap that units should share in order to accept merges. If this criteria is not achieved, soft merging will not be possible and an error will be raised. This is for use with a SortingAnalyzer input. - verbose: + verbose: **job_kwargs @@ -325,7 +325,7 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st new_id_strategy=new_id_strategy) apply_curation_labels(sorting, new_unit_ids, curation_dict) return sorting - + elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer analyzer = analyzer.remove_units(curation_dict["removed_units"]) @@ -344,5 +344,3 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st return analyzer else: raise TypeError(f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)") - - diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 22373a882c..1e36183eed 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -7,12 +7,17 @@ from pathlib import Path from .curationsorting import CurationSorting -from .curation_format import convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, clean_curation_dict +from .curation_format import ( + convert_from_sortingview_curation_format_v0, + apply_curation, + curation_label_to_vectors, + clean_curation_dict, +) def apply_sortingview_curation( - sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None - ): + sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None +): """ Apply curation from SortingView manual legacy curation format (before the official "curation_format") @@ -35,14 +40,14 @@ def apply_sortingview_curation( If True, merges are not applied (only labels) verbose : None Deprecated - + Returns ------- sorting_or_analyzer_curated : BaseSorting The curated sorting or analyzer """ - + if verbose is not None: warnings.warn("versobe in apply_sortingview_curation() is deprecated") @@ -68,10 +73,9 @@ def apply_sortingview_curation( if "format_version" not in curation_dict: curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) - unit_ids = sorting_or_analyzer.unit_ids - # this is a hack because it was not in the old format + # this is a hack because it was not in the old format curation_dict["unit_ids"] = list(unit_ids) if exclude_labels is not None: @@ -92,7 +96,7 @@ def apply_sortingview_curation( removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) curation_dict["removed_units"] = removed_units - + if skip_merge: curation_dict["merge_unit_groups"] = [] @@ -101,10 +105,8 @@ def apply_sortingview_curation( # apply sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") - - return sorting_curated - + return sorting_curated # TODO @alessio you remove this after testing diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 842994544d..af9d8e1eac 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -11,7 +11,7 @@ convert_from_sortingview_curation_format_v0, curation_label_to_vectors, curation_label_to_dataframe, - apply_curation + apply_curation, ) @@ -146,7 +146,6 @@ def test_convert_from_sortingview_curation_format_v0(): validate_curation_dict(curation_v1) - def test_curation_label_to_vectors(): labels = curation_label_to_vectors(curation_ids_int) @@ -158,8 +157,6 @@ def test_curation_label_to_vectors(): print(labels) - - def test_curation_label_to_dataframe(): df = curation_label_to_dataframe(curation_ids_int) @@ -172,7 +169,7 @@ def test_curation_label_to_dataframe(): def test_apply_curation(): - recording, sorting = generate_ground_truth_recording(durations=[10.], num_units=9, seed=2205) + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) @@ -193,4 +190,3 @@ def test_apply_curation(): # test_curation_label_to_dataframe() test_apply_curation() - From d9b108268ba068be712715501bfe8d41c5e0be61 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 19 Jul 2024 12:46:02 +0200 Subject: [PATCH 07/10] oups --- src/spikeinterface/curation/curation_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 831a27a868..0f9ab60f6b 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -343,4 +343,4 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: - raise TypeError(f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)") + raise TypeError(f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}") From 7caf1e8973fd574314ba8b41b5cf273a784f723d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:46:47 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/curation_format.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 0f9ab60f6b..88190a9bab 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -192,8 +192,6 @@ def clean_curation_dict(curation_dict): return curation_dict - - def curation_label_to_dataframe(curation_dict): """ Transform the curation dict into a pandas dataframe. @@ -213,11 +211,11 @@ def curation_label_to_dataframe(curation_dict): dataframe with labels. """ import pandas as pd + labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"]) return labels - def apply_curation_labels(sorting, new_unit_ids, curation_dict): """ Apply manual labels after merges. @@ -247,7 +245,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): for unit_id in old_group_ids: ind = curation_dict["unit_ids"].index(unit_id) value = manual_labels[label_key][ind] - if value != '': + if value != "": group_values.append(value) if len(set(group_values)) == 1: # all group has the same label or empty @@ -264,10 +262,16 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): sorting.set_property(key, values=[new_value], ids=[new_unit_id]) - -def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_strategy="append", - merging_mode="soft", sparsity_overlap=0.75, verbose=False, - **job_kwargs): +def apply_curation( + sorting_or_analyzer, + curation_dict, + censor_ms=None, + new_id_strategy="append", + merging_mode="soft", + sparsity_overlap=0.75, + verbose=False, + **job_kwargs, +): """ Apply curation dict to a Sorting or a SortingAnalyzer. @@ -316,13 +320,16 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") - if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer sorting = sorting.remove_units(curation_dict["removed_units"]) - sorting, _, new_unit_ids = apply_merges_to_sorting(sorting, curation_dict["merge_unit_groups"], - censor_ms=censor_ms, return_extra=True, - new_id_strategy=new_id_strategy) + sorting, _, new_unit_ids = apply_merges_to_sorting( + sorting, + curation_dict["merge_unit_groups"], + censor_ms=censor_ms, + return_extra=True, + new_id_strategy=new_id_strategy, + ) apply_curation_labels(sorting, new_unit_ids, curation_dict) return sorting @@ -343,4 +350,6 @@ def apply_curation(sorting_or_analyzer, curation_dict, censor_ms=None, new_id_st apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: - raise TypeError(f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}") + raise TypeError( + f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}" + ) From 042dfa6ab3c2c94b286fc31520ac162c328f8ede Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 19 Jul 2024 13:21:15 +0200 Subject: [PATCH 09/10] Finalize curation docs --- doc/api.rst | 12 +++++++++--- doc/modules/curation.rst | 22 +++++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 1ac37e4740..4d8bd2f329 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -324,15 +324,21 @@ spikeinterface.curation ------------------------ .. automodule:: spikeinterface.curation - .. autoclass:: CurationSorting - .. autoclass:: MergeUnitsSorting - .. autoclass:: SplitUnitSorting + .. autofunction:: apply_curation .. autofunction:: get_potential_auto_merge .. autofunction:: find_redundant_units .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + +Deprecated +~~~~~~~~~~ +.. automodule:: spikeinterface.curation + .. autofunction:: apply_sortingview_curation + .. autoclass:: CurationSorting + .. autoclass:: MergeUnitsSorting + .. autoclass:: SplitUnitSorting spikeinterface.generation diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 3cdf5c170b..0ac57ff87a 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -260,11 +260,23 @@ format is the definition; the second part of the format is manual action): } -.. note:: - The curation format was recently introduced (v0.101.0), and we are still working on - properly integrating it into the SpikeInterface ecosystem. - Soon there will be functions vailable, in the curation module, to apply this - standardized curation format to ``SortingAnalyzer`` and a ``BaseSorting`` objects. +The curation format can be loaded into a dictionary and directly applied to +a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function. + +.. code-block:: python + + from spikeinterface.curation import apply_curation + + # load the curation JSON file + curation_json = "path/to/curation.json" + with open(curation_json, 'r') as f: + curation_dict = json.load(f) + + # apply the curation to the sorting output + clean_sorting = apply_curation(sorting=sorting, curation_dict=curation_dict) + + # apply the curation to the sorting analyzer + clean_sorting_analyzer = apply_curation(sorting_analyzer=sorting_analyzer, curation_dict=curation_dict) Using the ``SpikeInterface GUI`` From 0eacd1ae4b9cdefa50ae5270720f9ae8462fac2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 19 Jul 2024 13:26:13 +0200 Subject: [PATCH 10/10] Update doc/modules/curation.rst --- doc/modules/curation.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 2c92e31558..d115b33e4a 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -274,10 +274,10 @@ a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterf curation_dict = json.load(f) # apply the curation to the sorting output - clean_sorting = apply_curation(sorting=sorting, curation_dict=curation_dict) + clean_sorting = apply_curation(sorting, curation_dict=curation_dict) # apply the curation to the sorting analyzer - clean_sorting_analyzer = apply_curation(sorting_analyzer=sorting_analyzer, curation_dict=curation_dict) + clean_sorting_analyzer = apply_curation(sorting_analyzer, curation_dict=curation_dict) Using the ``SpikeInterface GUI``