diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b78a346c02..85f83b0816 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2,6 +2,7 @@ from typing import Literal, Optional from pathlib import Path +from itertools import chain import os import json import pickle @@ -978,14 +979,17 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): >>> sorting_analyzer.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}}) """ - for extension_name in extensions.keys(): + + sorted_extensions = _sort_extensions_by_dependency(extensions) + + for extension_name in sorted_extensions.keys(): for child in _get_children_dependencies(extension_name): self.delete_extension(child) extensions_with_pipeline = {} extensions_without_pipeline = {} extensions_post_pipeline = {} - for extension_name, extension_params in extensions.items(): + for extension_name, extension_params in sorted_extensions.items(): if extension_name == "quality_metrics": # PATCH: the quality metric is computed after the pipeline, since some of the metrics optionally require # the output of the pipeline extensions (e.g., spike_amplitudes, spike_locations). @@ -1009,6 +1013,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): all_nodes = [] result_routage = [] extension_instances = {} + for extension_name, extension_params in extensions_with_pipeline.items(): extension_class = get_extension_class(extension_name) assert self.has_recording(), f"Extension {extension_name} need the recording" @@ -1024,6 +1029,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): all_nodes.extend(nodes) job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys()) + results = run_node_pipeline( self.recording, all_nodes, @@ -1191,6 +1197,58 @@ def get_default_extension_params(self, extension_name: str): return get_default_analyzer_extension_params(extension_name) +def _sort_extensions_by_dependency(extensions): + """ + Sorts a dictionary of extensions so that the parents of each extension are on the "left" of their children. + Assumes there is a valid ordering of the included extensions. + + Parameters + ---------- + extensions: dict + A dict of extensions. + + Returns + ------- + sorted_extensions: dict + A dict of extensions, with the parents on the left of their children. + """ + + extensions_list = list(extensions.keys()) + extension_params = list(extensions.values()) + + i = 0 + while i < len(extensions_list): + + extension = extensions_list[i] + dependencies = get_extension_class(extension).depend_on + + # Split cases with an "or" in them, and flatten into a list + dependencies = list(chain.from_iterable([dependency.split("|") for dependency in dependencies])) + + # Should only iterate if nothing has happened. + # Otherwise, should check the dependency which has just been moved => at position i + did_nothing = True + for dependency in dependencies: + + # if dependency is on the right, move it left of the current dependency + if dependency in extensions_list[i:]: + + dependency_arg = extensions_list.index(dependency) + + extension_params.pop(dependency_arg) + extension_params.insert(i, extensions[dependency]) + + extensions_list.pop(dependency_arg) + extensions_list.insert(i, dependency) + + did_nothing = False + + if did_nothing: + i += 1 + + return dict(zip(extensions_list, extension_params)) + + global _possible_extensions _possible_extensions = [] diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 7c47920978..8991c959ad 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -224,6 +224,26 @@ def test_delete_on_recompute(): assert sorting_analyzer.get_extension("waveforms") is None +def test_compute_several(): + sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False) + + # should raise an error since waveforms depends on random_spikes, which isn't calculated + with pytest.raises(AssertionError): + sorting_analyzer.compute(["waveforms"]) + + # check that waveforms are calculated + sorting_analyzer.compute(["random_spikes", "waveforms"]) + waveform_data = sorting_analyzer.get_extension("waveforms").get_data() + assert waveform_data is not None + + sorting_analyzer.delete_extension("waveforms") + sorting_analyzer.delete_extension("random_spikes") + + # check that waveforms are calculated as before, even when parent is after child + sorting_analyzer.compute(["waveforms", "random_spikes"]) + assert np.all(waveform_data == sorting_analyzer.get_extension("waveforms").get_data()) + + if __name__ == "__main__": test_ComputeWaveforms(format="memory", sparse=True) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 66b670d956..faed5161c6 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -3,14 +3,18 @@ import shutil -from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import ( + generate_ground_truth_recording, create_sorting_analyzer, load_sorting_analyzer, get_available_analyzer_extensions, get_default_analyzer_extension_params, ) -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.core.sortinganalyzer import ( + register_result_extension, + AnalyzerExtension, + _sort_extensions_by_dependency, +) import numpy as np @@ -255,6 +259,30 @@ def test_extension(): register_result_extension(DummyAnalyzerExtension2) +def test_extensions_sorting(): + + # nothing happens if all parents are on the left of the children + extensions_in_order = {"random_spikes": {"rs": 1}, "waveforms": {"wv": 2}} + sorted_extensions_1 = _sort_extensions_by_dependency(extensions_in_order) + assert list(sorted_extensions_1.keys()) == list(extensions_in_order.keys()) + + extensions_out_of_order = {"waveforms": {"wv": 2}, "random_spikes": {"rs": 1}} + sorted_extensions_2 = _sort_extensions_by_dependency(extensions_out_of_order) + assert list(sorted_extensions_2.keys()) == list(extensions_in_order.keys()) + + # doing two movements + extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}} + extensions_qm_correct = {"waveforms": {}, "templates": {}, "quality_metrics": {}} + sorted_extensions_3 = _sort_extensions_by_dependency(extensions_qm_left) + assert list(sorted_extensions_3.keys()) == list(extensions_qm_correct.keys()) + + # should move parent (waveforms) left of child (quality_metrics), and move grandparent (random_spikes) left of parent + extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}, "random_spikes": {}} + extensions_qm_correct = {"random_spikes": {}, "waveforms": {}, "templates": {}, "quality_metrics": {}} + sorted_extensions_4 = _sort_extensions_by_dependency(extensions_qm_left) + assert list(sorted_extensions_4.keys()) == list(extensions_qm_correct.keys()) + + if __name__ == "__main__": tmp_path = Path("test_SortingAnalyzer") test_SortingAnalyzer_memory(tmp_path)