From a6c4767f74cb608f2d1161ea93a1afeeb1ad09a1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 15 Mar 2024 13:59:18 +0100 Subject: [PATCH 1/8] Save extension class info. --- src/spikeinterface/core/base.py | 19 +++---------------- src/spikeinterface/core/core_tools.py | 21 +++++++++++++++++++++ src/spikeinterface/core/sortinganalyzer.py | 22 +++++++++++++++++++++- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 80341811b9..8f39491a8a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -23,6 +23,7 @@ make_paths_relative, make_paths_absolute, check_paths_relative, + get_class_info ) from .job_tools import _shared_job_kwargs_doc @@ -427,22 +428,8 @@ def to_dict( kwargs = new_kwargs - module_import_path = self.__class__.__module__ - class_name_no_path = self.__class__.__name__ - class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass' - module = class_name.split(".")[0] - - imported_module = importlib.import_module(module) - module_version = getattr(imported_module, "__version__", "unknown") - - dump_dict = { - "class": class_name, - "module": module, - "kwargs": kwargs, - "version": module_version, - } - - dump_dict["version"] = module_version # Can be spikeinterface, spikeforest, etc. + dump_dict = get_class_info(self.__class__) + dump_dict["kwargs"] = kwargs if include_annotations: dump_dict["annotations"] = self._annotations diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 3b82436d5c..bb6c52c2c8 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -6,6 +6,7 @@ import datetime import json from copy import deepcopy +import importlib import numpy as np @@ -476,3 +477,23 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): """ return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2)) + +def get_class_info(a_class): + """ + Get class info as a dict. + """ + module_import_path = a_class.__module__ + class_name_no_path = a_class.__name__ + class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass' + module = class_name.split(".")[0] + + imported_module = importlib.import_module(module) + module_version = getattr(imported_module, "__version__", "unknown") + + info = { + "class": class_name, + "module": module, + "version": module_version, + } + + return info diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ab5fc92c84..081e7a0850 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -20,7 +20,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes -from .core_tools import check_json +from .core_tools import check_json, get_class_info from .job_tools import split_job_kwargs from .numpyextractors import SharedMemorySorting from .sparsity import ChannelSparsity, estimate_sparsity @@ -1317,6 +1317,7 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # this also reset the folder or zarr group self._save_params() + self._save_info() self._run(**kwargs) @@ -1325,6 +1326,7 @@ def run(self, save=True, **kwargs): def save(self, **kwargs): self._save_params() + self._save_info() self._save_data(**kwargs) def _save_data(self, **kwargs): @@ -1443,6 +1445,7 @@ def set_params(self, save=True, **params): if save: self._save_params() + self._save_info() def _save_params(self): params_to_save = self.params.copy() @@ -1465,6 +1468,23 @@ def _save_params(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["params"] = check_json(params_to_save) + def _save_info(self): + # this save class info, this is not uselfull at the moment but this could be usefull in futur + # if some class change the data model and if we need to make backwards compatibility + # we have the same machanism in base.py for recording and sorting + + info = get_class_info(self.__class__) + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + extension_folder.mkdir(exist_ok=True, parents=True) + info_file = extension_folder / "info.json" + info_file.write_text(json.dumps(info, indent=4), encoding="utf8") + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["info"] = info + + + def get_pipeline_nodes(self): assert ( self.use_nodepipeline From eb42b3edcb6efe3e7aa5c45ed206f2831006092f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Mar 2024 12:59:59 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/base.py | 2 +- src/spikeinterface/core/core_tools.py | 1 + src/spikeinterface/core/sortinganalyzer.py | 2 -- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8f39491a8a..628bb7750b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -23,7 +23,7 @@ make_paths_relative, make_paths_absolute, check_paths_relative, - get_class_info + get_class_info, ) from .job_tools import _shared_job_kwargs_doc diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index bb6c52c2c8..24ca4febfb 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -478,6 +478,7 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2)) + def get_class_info(a_class): """ Get class info as a dict. diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 081e7a0850..ae5a3e7f2f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1483,8 +1483,6 @@ def _save_info(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["info"] = info - - def get_pipeline_nodes(self): assert ( self.use_nodepipeline From d23c1c36e0743e40ff36a6161ba5aadc4f4cddb5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Mar 2024 09:17:28 +0100 Subject: [PATCH 3/8] Merci Ramon for feedback --- src/spikeinterface/core/base.py | 4 ++-- src/spikeinterface/core/core_tools.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 4 ++-- .../sortingcomponents/clustering/split.py | 11 +++++++++++ 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 628bb7750b..92b6018915 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -23,7 +23,7 @@ make_paths_relative, make_paths_absolute, check_paths_relative, - get_class_info, + retrieve_importing_provenance, ) from .job_tools import _shared_job_kwargs_doc @@ -428,7 +428,7 @@ def to_dict( kwargs = new_kwargs - dump_dict = get_class_info(self.__class__) + dump_dict = retrieve_importing_provenance(self.__class__) dump_dict["kwargs"] = kwargs if include_annotations: diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 24ca4febfb..34c9c3f946 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -479,7 +479,7 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((x - mu) ** 2) / (2 * sigma**2)) -def get_class_info(a_class): +def retrieve_importing_provenance(a_class): """ Get class info as a dict. """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ae5a3e7f2f..05e286c78b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -20,7 +20,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes -from .core_tools import check_json, get_class_info +from .core_tools import check_json, retrieve_importing_provenance from .job_tools import split_job_kwargs from .numpyextractors import SharedMemorySorting from .sparsity import ChannelSparsity, estimate_sparsity @@ -1473,7 +1473,7 @@ def _save_info(self): # if some class change the data model and if we need to make backwards compatibility # we have the same machanism in base.py for recording and sorting - info = get_class_info(self.__class__) + info = retrieve_importing_provenance(self.__class__) if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() extension_folder.mkdir(exist_ok=True, parents=True) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 3861e7fe83..8b45f27c4d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -231,6 +231,17 @@ def split( possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": + print(final_features[:, 0].shape) + unique_feat, counts_feat = np.unique(final_features[:, 0],return_counts=True) + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.hist(final_features[:, 0]) + fig, ax = plt.subplots() + ax.plot(unique_feat, counts_feat) + plt.show() + + + print(np.sum(counts_feat>1)) dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) if dipscore > 1.5: From d0d7df570bfd4e837110a684b1a0369ac99e82d4 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 20 Mar 2024 09:17:51 +0100 Subject: [PATCH 4/8] Update src/spikeinterface/core/core_tools.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/core/core_tools.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 34c9c3f946..aca953c67f 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -481,6 +481,20 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): def retrieve_importing_provenance(a_class): """ + Retrieve the import provenance of a class, including its an import name that consists of the class name and the module, the top-level module, and the module version. + + Parameters + ---------- + a_class : type + The class object for which to retrieve the import provenance. + + Returns + ------- + dict + A dictionary containing: + - 'class': The module path and the name of the class concatenated (e.g., 'package.subpackage.ClassName'). + - 'module': The top-level module name where the class is defined. + - 'version': The version of the module if available, otherwise 'unknown'. Get class info as a dict. """ module_import_path = a_class.__module__ From 2f7190bbcf40752547b46377d7cb2c77e88c7b15 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 08:17:55 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/split.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 8b45f27c4d..ba3bc1fcec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -232,16 +232,16 @@ def split( is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": print(final_features[:, 0].shape) - unique_feat, counts_feat = np.unique(final_features[:, 0],return_counts=True) + unique_feat, counts_feat = np.unique(final_features[:, 0], return_counts=True) import matplotlib.pyplot as plt + fig, ax = plt.subplots() ax.hist(final_features[:, 0]) fig, ax = plt.subplots() ax.plot(unique_feat, counts_feat) plt.show() - - print(np.sum(counts_feat>1)) + print(np.sum(counts_feat > 1)) dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) if dipscore > 1.5: From 0ceb68d2021ee10b62eb595400f74d9fab8aed86 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 20 Mar 2024 13:03:56 +0100 Subject: [PATCH 6/8] Apply suggestions from code review oups. Co-authored-by: Alessio Buccino --- src/spikeinterface/core/core_tools.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index aca953c67f..c680c5823b 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -481,7 +481,7 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): def retrieve_importing_provenance(a_class): """ - Retrieve the import provenance of a class, including its an import name that consists of the class name and the module, the top-level module, and the module version. + Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version. Parameters ---------- diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 05e286c78b..2a765eda7f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1469,8 +1469,8 @@ def _save_params(self): extension_group.attrs["params"] = check_json(params_to_save) def _save_info(self): - # this save class info, this is not uselfull at the moment but this could be usefull in futur - # if some class change the data model and if we need to make backwards compatibility + # this saves the class info, this is not uselful at the moment but could be useful in future + # if some class changes the data model and if we need to make backwards compatibility # we have the same machanism in base.py for recording and sorting info = retrieve_importing_provenance(self.__class__) From d1598b1f46113e66abfbd0346ade072d8e8eacb1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Mar 2024 13:06:11 +0100 Subject: [PATCH 7/8] oups --- .../sortingcomponents/clustering/split.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index ba3bc1fcec..3861e7fe83 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -231,17 +231,6 @@ def split( possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": - print(final_features[:, 0].shape) - unique_feat, counts_feat = np.unique(final_features[:, 0], return_counts=True) - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - ax.hist(final_features[:, 0]) - fig, ax = plt.subplots() - ax.plot(unique_feat, counts_feat) - plt.show() - - print(np.sum(counts_feat > 1)) dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) if dipscore > 1.5: From 82c4ebb42e8bf0557a7799b1119fd875cc831656 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Mar 2024 13:09:49 +0100 Subject: [PATCH 8/8] oups --- src/spikeinterface/core/sortinganalyzer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 2a765eda7f..d561b3267f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1317,7 +1317,7 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # this also reset the folder or zarr group self._save_params() - self._save_info() + self._save_importing_provenance() self._run(**kwargs) @@ -1326,7 +1326,7 @@ def run(self, save=True, **kwargs): def save(self, **kwargs): self._save_params() - self._save_info() + self._save_importing_provenance() self._save_data(**kwargs) def _save_data(self, **kwargs): @@ -1445,7 +1445,7 @@ def set_params(self, save=True, **params): if save: self._save_params() - self._save_info() + self._save_importing_provenance() def _save_params(self): params_to_save = self.params.copy() @@ -1468,7 +1468,7 @@ def _save_params(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["params"] = check_json(params_to_save) - def _save_info(self): + def _save_importing_provenance(self): # this saves the class info, this is not uselful at the moment but could be useful in future # if some class changes the data model and if we need to make backwards compatibility # we have the same machanism in base.py for recording and sorting