Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save extension class info. #2585

Merged
merged 9 commits into from
Mar 29, 2024
19 changes: 3 additions & 16 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
make_paths_relative,
make_paths_absolute,
check_paths_relative,
get_class_info,
)
from .job_tools import _shared_job_kwargs_doc

Expand Down Expand Up @@ -427,22 +428,8 @@ def to_dict(

kwargs = new_kwargs

module_import_path = self.__class__.__module__
class_name_no_path = self.__class__.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]

imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

dump_dict = {
"class": class_name,
"module": module,
"kwargs": kwargs,
"version": module_version,
}

dump_dict["version"] = module_version # Can be spikeinterface, spikeforest, etc.
dump_dict = get_class_info(self.__class__)
dump_dict["kwargs"] = kwargs

if include_annotations:
dump_dict["annotations"] = self._annotations
Expand Down
22 changes: 22 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import json
from copy import deepcopy
import importlib


import numpy as np
Expand Down Expand Up @@ -476,3 +477,24 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0):
"""

return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2))


def get_class_info(a_class):
Copy link
Collaborator

@h-mayorquin h-mayorquin Mar 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly advise against a name as vague as get_class_info for this function. If I remember correctly, the point of storing these attributes is to record how to instantiate the class. From that perspective, I would suggest something like:

  • get_importing_provenance
  • get_class_location_within_package
  • retrieve_class_path

or a combination of them.

I like retrieve_importing_provenance because I think it describes well what the function does, and it also avoids adding yet another name to the get methods, which are mostly used for user-relevant information (e.g. get_traces).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree.
Thank you for the feedback. Let's go for retrieve_importing_provenance.

"""
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
Get class info as a dict.
"""
module_import_path = a_class.__module__
class_name_no_path = a_class.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]

imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

info = {
"class": class_name,
"module": module,
"version": module_version,
}

return info
20 changes: 19 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes
from .core_tools import check_json
from .core_tools import check_json, get_class_info
from .job_tools import split_job_kwargs
from .numpyextractors import SharedMemorySorting
from .sparsity import ChannelSparsity, estimate_sparsity
Expand Down Expand Up @@ -1317,6 +1317,7 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
# this also reset the folder or zarr group
self._save_params()
self._save_info()

self._run(**kwargs)

Expand All @@ -1325,6 +1326,7 @@ def run(self, save=True, **kwargs):

def save(self, **kwargs):
self._save_params()
self._save_info()
self._save_data(**kwargs)

def _save_data(self, **kwargs):
Expand Down Expand Up @@ -1443,6 +1445,7 @@ def set_params(self, save=True, **params):

if save:
self._save_params()
self._save_info()

def _save_params(self):
params_to_save = self.params.copy()
Expand All @@ -1465,6 +1468,21 @@ def _save_params(self):
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["params"] = check_json(params_to_save)

def _save_info(self):
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
# This saves the class info. It is not useful at the moment, but it could be useful in the future
# if some class changes the data model and we need to provide backwards compatibility.
# We have the same mechanism in base.py for recording and sorting.
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved

info = get_class_info(self.__class__)
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
extension_folder.mkdir(exist_ok=True, parents=True)
info_file = extension_folder / "info.json"
info_file.write_text(json.dumps(info, indent=4), encoding="utf8")
elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["info"] = info

def get_pipeline_nodes(self):
assert (
self.use_nodepipeline
Expand Down
Loading