Skip to content

Commit

Permalink
get_class logic -> dj_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 19, 2024
1 parent 082fa42 commit ba68b59
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 27 deletions.
47 changes: 34 additions & 13 deletions src/spyglass/decoding/decoding_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from non_local_detector.visualization.figurl_2D import create_2D_decode_view

from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1 # noqa: F401
from spyglass.decoding.v1.sorted_spikes import (
SortedSpikesDecodingV1,
) # noqa: F401
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1 # noqa: F401
from spyglass.settings import config
from spyglass.utils import SpyglassMixin, _Merge, logger

Expand All @@ -26,7 +24,7 @@ class DecodingOutput(_Merge, SpyglassMixin):
source: varchar(32)
"""

_source_class_dict = None
_source_class_dict = {}

class ClusterlessDecodingV1(SpyglassMixin, dj.Part):
definition = """
Expand Down Expand Up @@ -86,25 +84,48 @@ def cleanup(self, dry_run=False):
except (PermissionError, FileNotFoundError):
logger.warning(f"Unable to remove {path}, skipping")

@classmethod
def _get_source_class(cls, key):
if cls._source_class_dict is None:
cls._source_class_dict = {}
module = inspect.getmodule(cls)
for part_name in cls.parts():
@property
def source_class_dict(self) -> dict:
"""Dictionary of source class names to source classes
{
'ClusterlessDecodingV1': spy...ClusterlessDecodingV1,
'SortedSpikesDecodingV1': spy...SortedSpikesDecodingV1
}
Returns
-------
dict
Dictionary of source class names to source classes
"""
if not self._source_class_dict:
self._ensure_dependencies_loaded()
module = inspect.getmodule(self)
for part_name in self.parts():
part_name = to_camel_case(part_name.split("__")[-1].strip("`"))
part = getattr(module, part_name)
cls._source_class_dict[part_name] = part
self._source_class_dict[part_name] = part
return self._source_class_dict

source = (cls & key).fetch1("source")
return cls._source_class_dict[source]
@classmethod
def _get_source_class(cls, key):
# CB: By making this a property, we can generate the source_class_dict
# without a key. Previously failed on empty table
# This demonstrates pipeline-specific implementation. See also
# merge_restrict_class edits that centralize this logic.
return cls.source_class_dict[(cls & key).fetch1("source")]

@classmethod
def load_results(cls, key):
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
source_class = cls._get_source_class(key)
return (source_class & decoding_selection_key).load_results()

def load_results_new(cls, key):
# CB: Please test with populated database. If this works, all Merge
# tables can inherit this get_parent_class method for similar
return cls.merge_restrict_class(key).load_results()

@classmethod
def load_model(cls, key):
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
Expand Down
122 changes: 108 additions & 14 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from contextlib import nullcontext
from inspect import getmodule
from itertools import chain as iter_chain
from pprint import pprint

Expand Down Expand Up @@ -52,21 +53,45 @@ def __init__(self):
+ f"\n\tActual : {part.primary_key}"
)
self._analysis_nwbfile = None

@property # CB: This is a property to avoid circular import
def analysis_nwbfile(self):
if self._analysis_nwbfile is None:
from spyglass.common import AnalysisNwbfile # noqa F401

self._analysis_nwbfile = AnalysisNwbfile
return self._analysis_nwbfile
self._source_class_dict = {}

def _remove_comments(self, definition):
"""Use regular expressions to remove comments and blank lines"""
return re.sub( # First remove comments, then blank lines
r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition)
)

def _part_name(self, part=None):
"""Return the CamelCase name of a part table"""
if not isinstance(part, str):
part = part.full_table_name
return to_camel_case(part.split("__")[-1].strip("`"))

def get_source_from_key(self, key: dict) -> str:
"""Return the source of a given key"""
return self._part_name(self & key)

def parts(self, camel_case=False, *args, **kwargs) -> list:
"""Return a list of part tables, add option for CamelCase names.
See DataJoint `parts` for additional arguments. If camel_case is True,
forces return of strings rather than objects.
"""
self._ensure_dependencies_loaded()

if camel_case and kwargs.get("as_objects"):
logger.warning(
"Overriding as_objects=True to return CamelCase part names."
)
kwargs["as_objects"] = False

parts = super().parts(*args, **kwargs)

if camel_case:
parts = [self._part_name(part) for part in parts]

return parts

@classmethod
def _merge_restrict_parts(
cls,
Expand Down Expand Up @@ -294,7 +319,7 @@ def _merge_insert(
keys = [] # empty to-be-inserted key
for part in parts: # check each part
part_parent = part.parents(as_objects=True)[-1]
part_name = to_camel_case(part.table_name.split("__")[-1])
part_name = cls._part_name(part)
if part_parent & row: # if row is in part parent
if keys and mutual_exclusvity: # if key from other part
raise ValueError(
Expand Down Expand Up @@ -475,6 +500,15 @@ def merge_delete_parent(
for part_parent in part_parents:
super().delete(part_parent, **kwargs)

@property
def analysis_nwbfile(self):
"""Return the AnalysisNwbfile table. Avoid circular import."""
if self._analysis_nwbfile is None:
from spyglass.common import AnalysisNwbfile # noqa F401

self._analysis_nwbfile = AnalysisNwbfile
return self._analysis_nwbfile

@classmethod
def fetch_nwb(
cls,
Expand Down Expand Up @@ -527,6 +561,7 @@ def merge_get_part(
join_master: bool = False,
restrict_part=True,
multi_source=False,
return_empties=False,
) -> dj.Table:
"""Retrieve part table from a restricted Merge table.
Expand All @@ -545,6 +580,8 @@ def merge_get_part(
native part table.
multi_source: bool
Return multiple parts. Default False.
return_empties: bool
Default False. Return empty part tables.
Returns
------
Expand All @@ -563,11 +600,11 @@ def merge_get_part(
restricting
"""
sources = [
to_camel_case(n.split("__")[-1].strip("`")) # friendly part name
for n in cls._merge_restrict_parts(
cls._part_name(part) # friendly part name
for part in cls._merge_restrict_parts(
restriction=restriction,
as_objects=False,
return_empties=False,
return_empties=return_empties,
add_invalid_restrict=False,
)
]
Expand Down Expand Up @@ -595,7 +632,8 @@ def merge_get_parent(
cls,
restriction: str = True,
join_master: bool = False,
multi_source=False,
multi_source: bool = False,
return_empties: bool = False,
) -> dj.FreeTable:
"""Returns a list of part parents with restrictions applied.
Expand All @@ -610,6 +648,10 @@ def merge_get_parent(
Default True.
join_master: bool
Default False. Join part with Merge master to show uuid and source
multi_source: bool
Return multiple parents. Default False.
return_empties: bool
Default False. Return empty parent tables.
Returns
------
Expand All @@ -620,7 +662,7 @@ def merge_get_parent(
part_parents = cls._merge_restrict_parents(
restriction=restriction,
as_objects=True,
return_empties=False,
return_empties=return_empties,
add_invalid_restrict=False,
)

Expand All @@ -637,6 +679,58 @@ def merge_get_parent(

return part_parents if multi_source else part_parents[0]

@property
def source_class_dict(self) -> dict:
if not self._source_class_dict:
module = getmodule(self)
self._source_class_dict = {
part_name: getattr(module, part_name)
for part_name in self.parts(camel_case=True)
}
return self._source_class_dict

def merge_get_parent_class(self, source: str) -> dj.Table:
"""Return the class of the parent table for a given CamelCase source.
Parameters
----------
source: Union[str, dict, dj.Table]
Accepts a CamelCase name of the source, or key as a dict, or a part
table.
Returns
-------
dj.Table
Class instance of the parent table, including class methods.
"""

if isinstance(source, dj.Table):
source = self._part_name(source)
if isinstance(source, dict):
source = self.get_source_from_key(source)

ret = self.source_class_dict.get(source)

if not ret:
logger.error(
f"No source class found for {source}: \n\t"
+ f"{self.parts(camel_case=True)}"
)
return ret

def merge_restrict_class(self, key: dict) -> dj.Table:
"""Returns native parent class, restricted with key."""
parent_key = self.merge_get_parent(key).fetch("KEY", as_dict=True)

if len(parent_key) > 1:
raise ValueError(
f"Ambiguous entry. Data has mult rows in parent:\n\tData:{key}"
+ f"\n\t{parent_key}"
)

parent_class = self.merge_get_parent_class(key)
return parent_class & parent_key

@classmethod
def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list:
"""Perform a fetch across all parts. If >1 result, return as a list.
Expand Down

0 comments on commit ba68b59

Please sign in to comment.