From ba68b5915d48d82a7d80434a4a51020364f1366f Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 19 Jan 2024 11:45:34 -0600 Subject: [PATCH] get_class logic -> dj_merge --- src/spyglass/decoding/decoding_merge.py | 47 ++++++--- src/spyglass/utils/dj_merge_tables.py | 122 +++++++++++++++++++++--- 2 files changed, 142 insertions(+), 27 deletions(-) diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index e04ce1563..1c494a444 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -9,9 +9,7 @@ from non_local_detector.visualization.figurl_2D import create_2D_decode_view from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1 # noqa: F401 -from spyglass.decoding.v1.sorted_spikes import ( - SortedSpikesDecodingV1, -) # noqa: F401 +from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1 # noqa: F401 from spyglass.settings import config from spyglass.utils import SpyglassMixin, _Merge, logger @@ -26,7 +24,7 @@ class DecodingOutput(_Merge, SpyglassMixin): source: varchar(32) """ - _source_class_dict = None + _source_class_dict = {} class ClusterlessDecodingV1(SpyglassMixin, dj.Part): definition = """ @@ -86,18 +84,36 @@ def cleanup(self, dry_run=False): except (PermissionError, FileNotFoundError): logger.warning(f"Unable to remove {path}, skipping") - @classmethod - def _get_source_class(cls, key): - if cls._source_class_dict is None: - cls._source_class_dict = {} - module = inspect.getmodule(cls) - for part_name in cls.parts(): + @property + def source_class_dict(self) -> dict: + """Dictionary of source class names to source classes + + { + 'ClusterlessDecodingV1': spy...ClusterlessDecodingV1, + 'SortedSpikesDecodingV1': spy...SortedSpikesDecodingV1 + } + + Returns + ------- + dict + Dictionary of source class names to source classes + """ + if not self._source_class_dict: + self._ensure_dependencies_loaded() + module = inspect.getmodule(self) + for part_name in self.parts(): part_name = to_camel_case(part_name.split("__")[-1].strip("`")) part = getattr(module, part_name) - cls._source_class_dict[part_name] = part + self._source_class_dict[part_name] = part + return self._source_class_dict - source = (cls & key).fetch1("source") - return cls._source_class_dict[source] + @classmethod + def _get_source_class(cls, key): + # CB: By making this a property, we can generate the source_class_dict + # without a key. Previously failed on empty table + # This demonstrates pipeline-specific implementation. See also + # merge_restrict_class edits that centralize this logic. + return cls.source_class_dict[(cls & key).fetch1("source")] @classmethod def load_results(cls, key): @@ -105,6 +121,11 @@ def load_results(cls, key): source_class = cls._get_source_class(key) return (source_class & decoding_selection_key).load_results() + def load_results_new(cls, key): + # CB: Please test with populated database. If this works, all Merge + # tables can inherit this get_parent_class method for similar + return cls.merge_restrict_class(key).load_results() + @classmethod def load_model(cls, key): decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY") diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 82c205d48..0ab71f49e 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -1,5 +1,6 @@ import re from contextlib import nullcontext +from inspect import getmodule from itertools import chain as iter_chain from pprint import pprint @@ -52,14 +53,7 @@ def __init__(self): + f"\n\tActual : {part.primary_key}" ) self._analysis_nwbfile = None - - @property # CB: This is a property to avoid circular import - def analysis_nwbfile(self): - if self._analysis_nwbfile is None: - from spyglass.common import AnalysisNwbfile # noqa F401 - - self._analysis_nwbfile = AnalysisNwbfile - return self._analysis_nwbfile + self._source_class_dict = {} def _remove_comments(self, definition): """Use regular expressions to remove comments and blank lines""" @@ -67,6 +61,37 @@ def _remove_comments(self, definition): r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition) ) + def _part_name(self, part=None): + """Return the CamelCase name of a part table""" + if not isinstance(part, str): + part = part.full_table_name + return to_camel_case(part.split("__")[-1].strip("`")) + + def get_source_from_key(self, key: dict) -> str: + """Return the source of a given key""" + return self._part_name(self & key) + + def parts(self, camel_case=False, *args, **kwargs) -> list: + """Return a list of part tables, add option for CamelCase names. + + See DataJoint `parts` for additional arguments. If camel_case is True, + forces return of strings rather than objects. + """ + self._ensure_dependencies_loaded() + + if camel_case and kwargs.get("as_objects"): + logger.warning( + "Overriding as_objects=True to return CamelCase part names." + ) + kwargs["as_objects"] = False + + parts = super().parts(*args, **kwargs) + + if camel_case: + parts = [self._part_name(part) for part in parts] + + return parts + @classmethod def _merge_restrict_parts( cls, @@ -294,7 +319,7 @@ def _merge_insert( keys = [] # empty to-be-inserted key for part in parts: # check each part part_parent = part.parents(as_objects=True)[-1] - part_name = to_camel_case(part.table_name.split("__")[-1]) + part_name = cls._part_name(part) if part_parent & row: # if row is in part parent if keys and mutual_exclusvity: # if key from other part raise ValueError( @@ -475,6 +500,15 @@ def merge_delete_parent( for part_parent in part_parents: super().delete(part_parent, **kwargs) + @property + def analysis_nwbfile(self): + """Return the AnalysisNwbfile table. Avoid circular import.""" + if self._analysis_nwbfile is None: + from spyglass.common import AnalysisNwbfile # noqa F401 + + self._analysis_nwbfile = AnalysisNwbfile + return self._analysis_nwbfile + @classmethod def fetch_nwb( cls, @@ -527,6 +561,7 @@ def merge_get_part( join_master: bool = False, restrict_part=True, multi_source=False, + return_empties=False, ) -> dj.Table: """Retrieve part table from a restricted Merge table. @@ -545,6 +580,8 @@ def merge_get_part( native part table. multi_source: bool Return multiple parts. Default False. + return_empties: bool + Default False. Return empty part tables. Returns ------ @@ -563,11 +600,11 @@ def merge_get_part( restricting """ sources = [ - to_camel_case(n.split("__")[-1].strip("`")) # friendly part name - for n in cls._merge_restrict_parts( + cls._part_name(part) # friendly part name + for part in cls._merge_restrict_parts( restriction=restriction, as_objects=False, - return_empties=False, + return_empties=return_empties, add_invalid_restrict=False, ) ] @@ -595,7 +632,8 @@ def merge_get_parent( cls, restriction: str = True, join_master: bool = False, - multi_source=False, + multi_source: bool = False, + return_empties: bool = False, ) -> dj.FreeTable: """Returns a list of part parents with restrictions applied. @@ -610,6 +648,10 @@ def merge_get_parent( Default True. join_master: bool Default False. Join part with Merge master to show uuid and source + multi_source: bool + Return multiple parents. Default False. + return_empties: bool + Default False. Return empty parent tables. Returns ------ @@ -620,7 +662,7 @@ def merge_get_parent( part_parents = cls._merge_restrict_parents( restriction=restriction, as_objects=True, - return_empties=False, + return_empties=return_empties, add_invalid_restrict=False, ) @@ -637,6 +679,58 @@ def merge_get_parent( return part_parents if multi_source else part_parents[0] + @property + def source_class_dict(self) -> dict: + if not self._source_class_dict: + module = getmodule(self) + self._source_class_dict = { + part_name: getattr(module, part_name) + for part_name in self.parts(camel_case=True) + } + return self._source_class_dict + + def merge_get_parent_class(self, source: str) -> dj.Table: + """Return the class of the parent table for a given CamelCase source. + + Parameters + ---------- + source: Union[str, dict, dj.Table] + Accepts a CamelCase name of the source, or key as a dict, or a part + table. + + Returns + ------- + dj.Table + Class instance of the parent table, including class methods. + """ + + if isinstance(source, dj.Table): + source = self._part_name(source) + if isinstance(source, dict): + source = self.get_source_from_key(source) + + ret = self.source_class_dict.get(source) + + if not ret: + logger.error( + f"No source class found for {source}: \n\t" + + f"{self.parts(camel_case=True)}" + ) + return ret + + def merge_restrict_class(self, key: dict) -> dj.Table: + """Returns native parent class, restricted with key.""" + parent_key = self.merge_get_parent(key).fetch("KEY", as_dict=True) + + if len(parent_key) > 1: + raise ValueError( + f"Ambiguous entry. Data has mult rows in parent:\n\tData:{key}" + + f"\n\t{parent_key}" + ) + + parent_class = self.merge_get_parent_class(key) + return parent_class & parent_key + @classmethod def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list: """Perform a fetch across all parts. If >1 result, return as a list.