Skip to content

Commit

Permalink
Merge pull request #2705 from chrishalcrow/sort_input_extensions
Browse files Browse the repository at this point in the history
Sort input extensions
  • Loading branch information
samuelgarcia authored May 21, 2024
2 parents 6bb7567 + b0cc5ad commit 26c145c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 4 deletions.
62 changes: 60 additions & 2 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal, Optional

from pathlib import Path
from itertools import chain
import os
import json
import pickle
Expand Down Expand Up @@ -978,14 +979,17 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
>>> sorting_analyzer.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}})
"""
for extension_name in extensions.keys():

sorted_extensions = _sort_extensions_by_dependency(extensions)

for extension_name in sorted_extensions.keys():
for child in _get_children_dependencies(extension_name):
self.delete_extension(child)

extensions_with_pipeline = {}
extensions_without_pipeline = {}
extensions_post_pipeline = {}
for extension_name, extension_params in extensions.items():
for extension_name, extension_params in sorted_extensions.items():
if extension_name == "quality_metrics":
# PATCH: the quality metric is computed after the pipeline, since some of the metrics optionally require
# the output of the pipeline extensions (e.g., spike_amplitudes, spike_locations).
Expand All @@ -1009,6 +1013,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
all_nodes = []
result_routage = []
extension_instances = {}

for extension_name, extension_params in extensions_with_pipeline.items():
extension_class = get_extension_class(extension_name)
assert self.has_recording(), f"Extension {extension_name} need the recording"
Expand All @@ -1024,6 +1029,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
all_nodes.extend(nodes)

job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys())

results = run_node_pipeline(
self.recording,
all_nodes,
Expand Down Expand Up @@ -1191,6 +1197,58 @@ def get_default_extension_params(self, extension_name: str):
return get_default_analyzer_extension_params(extension_name)


def _sort_extensions_by_dependency(extensions):
"""
Sorts a dictionary of extensions so that the parents of each extension are on the "left" of their children.
Assumes there is a valid ordering of the included extensions.
Parameters
----------
extensions: dict
A dict of extensions.
Returns
-------
sorted_extensions: dict
A dict of extensions, with the parents on the left of their children.
"""

extensions_list = list(extensions.keys())
extension_params = list(extensions.values())

i = 0
while i < len(extensions_list):

extension = extensions_list[i]
dependencies = get_extension_class(extension).depend_on

# Split cases with an "or" in them, and flatten into a list
dependencies = list(chain.from_iterable([dependency.split("|") for dependency in dependencies]))

# Should only iterate if nothing has happened.
# Otherwise, should check the dependency which has just been moved => at position i
did_nothing = True
for dependency in dependencies:

# if dependency is on the right, move it left of the current dependency
if dependency in extensions_list[i:]:

dependency_arg = extensions_list.index(dependency)

extension_params.pop(dependency_arg)
extension_params.insert(i, extensions[dependency])

extensions_list.pop(dependency_arg)
extensions_list.insert(i, dependency)

did_nothing = False

if did_nothing:
i += 1

return dict(zip(extensions_list, extension_params))


global _possible_extensions
_possible_extensions = []

Expand Down
20 changes: 20 additions & 0 deletions src/spikeinterface/core/tests/test_analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,26 @@ def test_delete_on_recompute():
assert sorting_analyzer.get_extension("waveforms") is None


def test_compute_several():
sorting_analyzer = get_sorting_analyzer(format="memory", sparse=False)

# should raise an error since waveforms depends on random_spikes, which isn't calculated
with pytest.raises(AssertionError):
sorting_analyzer.compute(["waveforms"])

# check that waveforms are calculated
sorting_analyzer.compute(["random_spikes", "waveforms"])
waveform_data = sorting_analyzer.get_extension("waveforms").get_data()
assert waveform_data is not None

sorting_analyzer.delete_extension("waveforms")
sorting_analyzer.delete_extension("random_spikes")

# check that waveforms are calculated as before, even when parent is after child
sorting_analyzer.compute(["waveforms", "random_spikes"])
assert np.all(waveform_data == sorting_analyzer.get_extension("waveforms").get_data())


if __name__ == "__main__":

test_ComputeWaveforms(format="memory", sparse=True)
Expand Down
32 changes: 30 additions & 2 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@

import shutil

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
load_sorting_analyzer,
get_available_analyzer_extensions,
get_default_analyzer_extension_params,
)
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
from spikeinterface.core.sortinganalyzer import (
register_result_extension,
AnalyzerExtension,
_sort_extensions_by_dependency,
)

import numpy as np

Expand Down Expand Up @@ -255,6 +259,30 @@ def test_extension():
register_result_extension(DummyAnalyzerExtension2)


def test_extensions_sorting():

# nothing happens if all parents are on the left of the children
extensions_in_order = {"random_spikes": {"rs": 1}, "waveforms": {"wv": 2}}
sorted_extensions_1 = _sort_extensions_by_dependency(extensions_in_order)
assert list(sorted_extensions_1.keys()) == list(extensions_in_order.keys())

extensions_out_of_order = {"waveforms": {"wv": 2}, "random_spikes": {"rs": 1}}
sorted_extensions_2 = _sort_extensions_by_dependency(extensions_out_of_order)
assert list(sorted_extensions_2.keys()) == list(extensions_in_order.keys())

# doing two movements
extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}}
extensions_qm_correct = {"waveforms": {}, "templates": {}, "quality_metrics": {}}
sorted_extensions_3 = _sort_extensions_by_dependency(extensions_qm_left)
assert list(sorted_extensions_3.keys()) == list(extensions_qm_correct.keys())

# should move parent (waveforms) left of child (quality_metrics), and move grandparent (random_spikes) left of parent
extensions_qm_left = {"quality_metrics": {}, "waveforms": {}, "templates": {}, "random_spikes": {}}
extensions_qm_correct = {"random_spikes": {}, "waveforms": {}, "templates": {}, "quality_metrics": {}}
sorted_extensions_4 = _sort_extensions_by_dependency(extensions_qm_left)
assert list(sorted_extensions_4.keys()) == list(extensions_qm_correct.keys())


if __name__ == "__main__":
tmp_path = Path("test_SortingAnalyzer")
test_SortingAnalyzer_memory(tmp_path)
Expand Down

0 comments on commit 26c145c

Please sign in to comment.