Skip to content

Commit

Permalink
Merge pull request #2571 from samuelgarcia/extension_auto_load
Browse files Browse the repository at this point in the history
Proposal for auto import extensions module.
  • Loading branch information
alejoe91 authored Mar 19, 2024
2 parents 8e2bbb8 + efa1923 commit 52617ee
Showing 1 changed file with 59 additions and 27 deletions.
86 changes: 59 additions & 27 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import weakref
import shutil
import warnings
import importlib

import numpy as np

Expand Down Expand Up @@ -911,35 +912,29 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):

def get_saved_extension_names(self):
"""
Get extension saved in folder or zarr that can be loaded.
Get extension names saved in folder or zarr that can be loaded.
This do not load data, this only explores the directory.
"""
assert self.format != "memory"
global _possible_extensions
saved_extension_names = []
if self.format == "binary_folder":
ext_folder = self.folder / "extensions"
if ext_folder.is_dir():
for extension_folder in ext_folder.iterdir():
is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file()
if not is_saved:
continue
saved_extension_names.append(extension_folder.stem)

if self.format == "zarr":
elif self.format == "zarr":
zarr_root = self._get_zarr_root(mode="r")
if "extensions" in zarr_root.keys():
extension_group = zarr_root["extensions"]
else:
extension_group = None
for extension_name in extension_group.keys():
if "params" in extension_group[extension_name].attrs.keys():
saved_extension_names.append(extension_name)

saved_extension_names = []
for extension_class in _possible_extensions:
extension_name = extension_class.extension_name

if self.format == "binary_folder":
extension_folder = self.folder / "extensions" / extension_name
is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file()
elif self.format == "zarr":
if extension_group is not None:
is_saved = (
extension_name in extension_group.keys()
and "params" in extension_group[extension_name].attrs.keys()
)
else:
is_saved = False
if is_saved:
saved_extension_names.append(extension_class.extension_name)
else:
raise ValueError("SortingAnalyzer.get_saved_extension_names() works only with binary_folder and zarr")

return saved_extension_names

Expand Down Expand Up @@ -1060,14 +1055,16 @@ def register_result_extension(extension_class):
_possible_extensions.append(extension_class)


def get_extension_class(extension_name: str):
def get_extension_class(extension_name: str, auto_import=True):
"""
Get extension class from name and check if registered.
Parameters
----------
extension_name: str
The extension name.
auto_import: bool, default True
Auto import the module if the extension class is not registered yet.
Returns
-------
Expand All @@ -1076,9 +1073,20 @@ def get_extension_class(extension_name: str):
"""
global _possible_extensions
extensions_dict = {ext.extension_name: ext for ext in _possible_extensions}
assert (
extension_name in extensions_dict
), f"Extension '{extension_name}' is not registered, please import related module before use"

if extension_name not in extensions_dict:
if extension_name in _builtin_extensions:
module = _builtin_extensions[extension_name]
if auto_import:
imported_module = importlib.import_module(module)
extensions_dict = {ext.extension_name: ext for ext in _possible_extensions}
else:
raise ValueError(
f"Extension '{extension_name}' is not registered, please import related module before use: 'import {module}'"
)
else:
raise ValueError(f"Extension '{extension_name}' is unknown maybe this is an external extension or a typo.")

ext_class = extensions_dict[extension_name]
return ext_class

Expand Down Expand Up @@ -1474,3 +1482,27 @@ def get_pipeline_nodes(self):
def get_data(self, *args, **kwargs):
assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data"
return self._get_data(*args, **kwargs)


# this is a hardcoded list to to improve error message and auto_import mechanism
# this is important because extension are registered when the submodule is imported
_builtin_extensions = {
# from core
"random_spikes": "spikeinterface.core",
"waveforms": "spikeinterface.core",
"templates": "spikeinterface.core",
"fast_templates": "spikeinterface.core",
"noise_levels": "spikeinterface.core",
# from postprocessing
"amplitude_scalings": "spikeinterface.postprocessing",
"correlograms": "spikeinterface.postprocessing",
"isi_histograms": "spikeinterface.postprocessing",
"principal_components": "spikeinterface.postprocessing",
"spike_amplitudes": "spikeinterface.postprocessing",
"spike_locations": "spikeinterface.postprocessing",
"template_metrics": "spikeinterface.postprocessing",
"template_similarity": "spikeinterface.postprocessing",
"unit_locations": "spikeinterface.postprocessing",
# from quality metrics
"quality_metrics": "spikeinterface.qualitymetrics",
}

0 comments on commit 52617ee

Please sign in to comment.