Skip to content

Commit

Permalink
Merge branch 'SpikeInterface:main' into fix_split
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Jan 15, 2025
2 parents 61d3e2f + 01d1479 commit 169e83d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
4 changes: 3 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@

# for sphinx gallery plugin
sphinx_gallery_conf = {
'only_warn_on_example_error': True,
# This is the default but including here explicitly. Should build all docs and fail on gallery failures only.
# other option would be abort_on_example_error, but this fails on first failure. So we decided against this.
'only_warn_on_example_error': False,
'examples_dirs': ['../examples/tutorials'],
'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples
'subsection_order': ExplicitOrder([
Expand Down
40 changes: 24 additions & 16 deletions src/spikeinterface/curation/curation_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict):
if not removed_units_set.issubset(unit_set):
raise ValueError("Curation format: some removed units are not in the unit list")

for group in curation_dict["merge_unit_groups"]:
if len(group) < 2:
raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements")

all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
for gp_1, gp_2 in combinations(all_merging_groups, 2):
if len(gp_1.intersection(gp_2)) != 0:
raise ValueError("Some units belong to multiple merge groups")
raise ValueError("Curation format: some units belong to multiple merge groups")
if len(removed_units_set.intersection(merged_units_set)) != 0:
raise ValueError("Some units were merged and deleted")
raise ValueError("Curation format: some units were merged and deleted")

# Check the labels exclusivity
for lbl in curation_dict["manual_labels"]:
Expand Down Expand Up @@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
for unit_ind, unit_id in enumerate(sorting.unit_ids):
if unit_id not in new_unit_ids:
ind = curation_dict["unit_ids"].index(unit_id)
ind = list(curation_dict["unit_ids"]).index(unit_id)
all_values[unit_ind] = values[ind]
sorting.set_property(key, all_values)

Expand All @@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
group_values.append(value)
if len(set(group_values)) == 1:
# all group has the same label or empty
sorting.set_property(key, values=group_values, ids=[new_unit_id])
sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
else:

for key in label_def["label_options"]:
Expand Down Expand Up @@ -339,18 +343,22 @@ def apply_curation(

elif isinstance(sorting_or_analyzer, SortingAnalyzer):
analyzer = sorting_or_analyzer
analyzer = analyzer.remove_units(curation_dict["removed_units"])
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
if len(curation_dict["removed_units"]) > 0:
analyzer = analyzer.remove_units(curation_dict["removed_units"])
if len(curation_dict["merge_unit_groups"]) > 0:
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
else:
new_unit_ids = []
apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict)
return analyzer
else:
Expand Down

0 comments on commit 169e83d

Please sign in to comment.