Skip to content

Commit

Permalink
Merge pull request #2585 from samuelgarcia/analyzer_module_info
Browse files Browse the repository at this point in the history
Save extension class info.
  • Loading branch information
alejoe91 authored Mar 29, 2024
2 parents f22b383 + f92545b commit 5a77dd7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 17 deletions.
19 changes: 3 additions & 16 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
make_paths_relative,
make_paths_absolute,
check_paths_relative,
retrieve_importing_provenance,
)
from .job_tools import _shared_job_kwargs_doc

Expand Down Expand Up @@ -427,22 +428,8 @@ def to_dict(

kwargs = new_kwargs

module_import_path = self.__class__.__module__
class_name_no_path = self.__class__.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]

imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

dump_dict = {
"class": class_name,
"module": module,
"kwargs": kwargs,
"version": module_version,
}

dump_dict["version"] = module_version # Can be spikeinterface, spikeforest, etc.
dump_dict = retrieve_importing_provenance(self.__class__)
dump_dict["kwargs"] = kwargs

if include_annotations:
dump_dict["annotations"] = self._annotations
Expand Down
36 changes: 36 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import json
from copy import deepcopy
import importlib
from math import prod

import numpy as np
Expand Down Expand Up @@ -477,3 +478,38 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0):
"""

return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2))


def retrieve_importing_provenance(a_class):
"""
Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version.
Parameters
----------
a_class : type
The class object for which to retrieve the import provenance.
Returns
-------
dict
A dictionary containing:
- 'class': The module path and the name of the class concatenated (e.g., 'package.subpackage.ClassName').
- 'module': The top-level module name where the class is defined.
- 'version': The version of the module if available, otherwise 'unknown'.
Get class info as a dict.
"""
module_import_path = a_class.__module__
class_name_no_path = a_class.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]

imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

info = {
"class": class_name,
"module": module,
"version": module_version,
}

return info
20 changes: 19 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes
from .core_tools import check_json
from .core_tools import check_json, retrieve_importing_provenance
from .job_tools import split_job_kwargs
from .numpyextractors import SharedMemorySorting
from .sparsity import ChannelSparsity, estimate_sparsity
Expand Down Expand Up @@ -1380,6 +1380,7 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
# this also reset the folder or zarr group
self._save_params()
self._save_importing_provenance()

self._run(**kwargs)

Expand All @@ -1388,6 +1389,7 @@ def run(self, save=True, **kwargs):

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
self._save_data(**kwargs)

def _save_data(self, **kwargs):
Expand Down Expand Up @@ -1506,6 +1508,7 @@ def set_params(self, save=True, **params):

if save:
self._save_params()
self._save_importing_provenance()

def _save_params(self):
params_to_save = self.params.copy()
Expand All @@ -1528,6 +1531,21 @@ def _save_params(self):
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["params"] = check_json(params_to_save)

def _save_importing_provenance(self):
# this saves the class info, this is not uselful at the moment but could be useful in future
# if some class changes the data model and if we need to make backwards compatibility
# we have the same machanism in base.py for recording and sorting

info = retrieve_importing_provenance(self.__class__)
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
extension_folder.mkdir(exist_ok=True, parents=True)
info_file = extension_folder / "info.json"
info_file.write_text(json.dumps(info, indent=4), encoding="utf8")
elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["info"] = info

def get_pipeline_nodes(self):
assert (
self.use_nodepipeline
Expand Down

0 comments on commit 5a77dd7

Please sign in to comment.