Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save extension class info. #2585

Merged
merged 9 commits into from
Mar 29, 2024
19 changes: 3 additions & 16 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
make_paths_relative,
make_paths_absolute,
check_paths_relative,
retrieve_importing_provenance,
)
from .job_tools import _shared_job_kwargs_doc

Expand Down Expand Up @@ -427,22 +428,8 @@ def to_dict(

kwargs = new_kwargs

module_import_path = self.__class__.__module__
class_name_no_path = self.__class__.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]

imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

dump_dict = {
"class": class_name,
"module": module,
"kwargs": kwargs,
"version": module_version,
}

dump_dict["version"] = module_version # Can be spikeinterface, spikeforest, etc.
dump_dict = retrieve_importing_provenance(self.__class__)
dump_dict["kwargs"] = kwargs

if include_annotations:
dump_dict["annotations"] = self._annotations
Expand Down
36 changes: 36 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import json
from copy import deepcopy
import importlib


import numpy as np
Expand Down Expand Up @@ -476,3 +477,38 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0):
"""

return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2))


def retrieve_importing_provenance(a_class):
    """
    Retrieve the import provenance of a class: its fully qualified name
    (module path + class name), its top-level module, and that module's version.

    Parameters
    ----------
    a_class : type
        The class object for which to retrieve the import provenance.

    Returns
    -------
    dict
        A dictionary containing:
        - 'class': the module path and the name of the class concatenated
          (e.g. 'package.subpackage.ClassName').
        - 'module': the top-level module name where the class is defined.
        - 'version': the version of that module if available, otherwise 'unknown'.
    """
    class_name = f"{a_class.__module__}.{a_class.__name__}"  # e.g. 'spikeinterface.core.generate.AClass'
    top_level_module = class_name.split(".")[0]

    # Version lives on the top-level package (spikeinterface, spikeforest, ...);
    # fall back to 'unknown' when the package does not define __version__.
    imported_module = importlib.import_module(top_level_module)
    module_version = getattr(imported_module, "__version__", "unknown")

    return {
        "class": class_name,
        "module": top_level_module,
        "version": module_version,
    }
20 changes: 19 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes
from .core_tools import check_json
from .core_tools import check_json, retrieve_importing_provenance
from .job_tools import split_job_kwargs
from .numpyextractors import SharedMemorySorting
from .sparsity import ChannelSparsity, estimate_sparsity
Expand Down Expand Up @@ -1317,6 +1317,7 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
# this also reset the folder or zarr group
self._save_params()
self._save_info()

self._run(**kwargs)

Expand All @@ -1325,6 +1326,7 @@ def run(self, save=True, **kwargs):

def save(self, **kwargs):
self._save_params()
self._save_info()
self._save_data(**kwargs)

def _save_data(self, **kwargs):
Expand Down Expand Up @@ -1443,6 +1445,7 @@ def set_params(self, save=True, **params):

if save:
self._save_params()
self._save_info()

def _save_params(self):
params_to_save = self.params.copy()
Expand All @@ -1465,6 +1468,21 @@ def _save_params(self):
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["params"] = check_json(params_to_save)

def _save_info(self):
    # Persist the extension class provenance (import path, top-level module, version)
    # alongside the extension data. Not strictly needed at the moment, but could be
    # useful in the future if a class changes its data model and we need to handle
    # backward compatibility. The same mechanism exists in base.py for Recording
    # and Sorting objects.
    info = retrieve_importing_provenance(self.__class__)
    if self.format == "binary_folder":
        extension_folder = self._get_binary_extension_folder()
        extension_folder.mkdir(exist_ok=True, parents=True)
        info_file = extension_folder / "info.json"
        info_file.write_text(json.dumps(info, indent=4), encoding="utf8")
    elif self.format == "zarr":
        extension_group = self._get_zarr_extension_group(mode="r+")
        extension_group.attrs["info"] = info

def get_pipeline_nodes(self):
assert (
self.use_nodepipeline
Expand Down
11 changes: 11 additions & 0 deletions src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ def split(
possible_labels = clust.labels_
is_split = np.setdiff1d(possible_labels, [-1]).size > 1
elif clusterer == "isocut5":
print(final_features[:, 0].shape)
unique_feat, counts_feat = np.unique(final_features[:, 0], return_counts=True)
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.hist(final_features[:, 0])
fig, ax = plt.subplots()
ax.plot(unique_feat, counts_feat)
plt.show()

print(np.sum(counts_feat > 1))
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
dipscore, cutpoint = isocut5(final_features[:, 0])
possible_labels = np.zeros(final_features.shape[0])
if dipscore > 1.5:
Expand Down
Loading