Commit

Overwrite `delete`. Mixin add cached_property decorator
CBroz1 committed Jan 30, 2024
1 parent 013d085 commit 895e656
Showing 3 changed files with 77 additions and 124 deletions.
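The core refactor here is mechanical but worth spelling out: properties that previously cached their result in hand-maintained sentinel attributes (e.g. _merge_table_cache, _session_connection_cache) now use functools.cached_property, which computes the value on first access, stores it in the instance __dict__, and only recomputes if the attribute is deleted. A minimal sketch of the before/after pattern, with expensive_lookup standing in for the real table lookups (none of the code below is Spyglass code):

from functools import cached_property


def expensive_lookup(log):
    # Stand-in for an expensive computation (hypothetical helper).
    log.append("computed")
    return len(log)


class Before:
    """Old pattern: sentinel attribute checked inside a @property."""

    def __init__(self, log):
        self._log = log
        self._result_cache = None

    @property
    def result(self):
        if self._result_cache is None:
            self._result_cache = expensive_lookup(self._log)
        return self._result_cache


class After:
    """New pattern: functools.cached_property caches per instance."""

    def __init__(self, log):
        self._log = log

    @cached_property
    def result(self):
        # Runs once; the value then lives in the instance __dict__.
        return expensive_lookup(self._log)


log = []
obj = After(log)
obj.result
obj.result
assert log == ["computed"]  # computed once, then served from the cache
del obj.result              # invalidation, as delete_downstream_merge now does
                            # with `del self._merge_tables`; deleting before the
                            # first access would raise AttributeError
obj.result
assert log == ["computed", "computed"]

One side effect of the switch: class-level defaults such as _merge_table_cache = {} were mutable attributes shared across instances, whereas cached_property values are strictly per instance.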
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@
- Add `cautious_delete` to Mixin class
- Initial implementation. #711, #762
- More robust caching of join to downstream tables. #806
- Overwrite datajoint `delete` method to use `cautious_delete`. #806
- Add `deprecation_factory` to facilitate table migration. #717
- Add Spyglass logger. #730
- IntervalList: Add secondary key `pipeline` #742
2 changes: 1 addition & 1 deletion src/spyglass/settings.py
@@ -7,7 +7,7 @@
import yaml
from pymysql.err import OperationalError

from spyglass.utils import logger
from spyglass.utils.logging import logger


class SpyglassConfig:
198 changes: 75 additions & 123 deletions src/spyglass/utils/dj_mixin.py
@@ -1,16 +1,16 @@
from functools import cached_property
from time import time
from typing import Dict, List

import datajoint as dj
import networkx as nx
from datajoint.table import logger as dj_logger
from datajoint.user_tables import Table, TableMeta
from datajoint.logging import logger as dj_logger
from datajoint.table import Table
from datajoint.utils import get_master, user_choice

from spyglass.utils.database_settings import SHARED_MODULES
from spyglass.settings import test_mode
from spyglass.utils.dj_helper_fn import fetch_nwb
from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
from spyglass.utils.dj_merge_tables import Merge
from spyglass.utils.logging import logger


@@ -47,77 +47,64 @@ class SpyglassMixin:
Alias for delete_downstream_merge.
"""

_nwb_table_dict = {} # Dict mapping NWBFile table to path attribute name.
# _nwb_table = None # NWBFile table class, defined at the table level
_nwb_table_resolved = None # NWBFile table class, resolved here from above
_delete_dependencies = [] # Session, LabMember, LabTeam, delay import

# pks for delete permission check, assumed to be one field
_session_pk = None # Session primary key. Mixin is agnostic to Session pk
_member_pk = None # LabMember primary key. Mixin is agnostic to table structure
_merge_table_cache = {} # Cache of merge tables downstream of self
_merge_chains_cache = {} # Cache of table chains to merges
_session_connection_cache = None # Cache of path from Session to self
_test_mode_cache = None # Cache of test mode setting for delete
_usage_table_cache = None # Temporary inclusion for usage tracking

# ------------------------------- fetch_nwb -------------------------------

@property
@cached_property
def _table_dict(self):
"""Dict mapping NWBFile table to path attribute name.
Used to delay import of NWBFile tables until needed, avoiding circular
imports.
"""
if not self._nwb_table_dict:
from spyglass.common.common_nwbfile import ( # noqa F401
AnalysisNwbfile,
Nwbfile,
)
from spyglass.common.common_nwbfile import (
AnalysisNwbfile,
Nwbfile,
) # noqa F401

self._nwb_table_dict = {
AnalysisNwbfile: "analysis_file_abs_path",
Nwbfile: "nwb_file_abs_path",
}
return self._nwb_table_dict
return {
AnalysisNwbfile: "analysis_file_abs_path",
Nwbfile: "nwb_file_abs_path",
}

@property
@cached_property
def _nwb_table_tuple(self):
"""NWBFile table class.
Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb.
Multiple copies for different purposes.
- _nwb_table may be user-set. Don't overwrite.
- _nwb_table_resolved is set here from either _nwb_table or definition.
- _nwb_table_tuple is used to cache result of _nwb_table_resolved and
return the appropriate path_attr from _table_dict above.
"""
if not self._nwb_table_resolved:
from spyglass.common.common_nwbfile import ( # noqa F401
AnalysisNwbfile,
Nwbfile,
from spyglass.common.common_nwbfile import (
AnalysisNwbfile,
Nwbfile,
) # noqa F401

if hasattr(self, "_nwb_table"):
resolved = self._nwb_table

else:
resolved = (
AnalysisNwbfile
if "-> AnalysisNwbfile" in self.definition
else Nwbfile if "-> Nwbfile" in self.definition else None
)

if hasattr(self, "_nwb_table"):
self._nwb_table_resolved = self._nwb_table

if not hasattr(self, "_nwb_table"):
self._nwb_table_resolved = (
AnalysisNwbfile
if "-> AnalysisNwbfile" in self.definition
else Nwbfile if "-> Nwbfile" in self.definition else None
)

if getattr(self, "_nwb_table_resolved", None) is None:
raise NotImplementedError(
f"{self.__class__.__name__} does not have a "
"(Analysis)Nwbfile foreign key or _nwb_table attribute."
)
if not resolved:
raise NotImplementedError(
f"{self.__class__.__name__} does not have a "
"(Analysis)Nwbfile foreign key or _nwb_table attribute."
)

return (
self._nwb_table_resolved,
self._table_dict[self._nwb_table_resolved],
resolved,
self._table_dict[resolved],
)

def fetch_nwb(self, *attrs, **kwargs):
@@ -130,65 +117,50 @@ def fetch_nwb(self, *attrs, **kwargs):
'-> AnalysisNwbfile' in its definition can use a _nwb_table attribute to
specify which table to use.
"""
nwb_table, path_attr = self._nwb_table_tuple

return fetch_nwb(self, (nwb_table, path_attr), *attrs, **kwargs)
return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs)
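For readers outside the codebase, the resolution rule that _nwb_table_tuple and fetch_nwb encode can be restated as a standalone, database-free sketch; the class and attribute names mirror the real ones, but FakeResults and everything else below is purely illustrative:

class Nwbfile: ...            # stand-ins for the spyglass.common tables
class AnalysisNwbfile: ...


PATH_ATTRS = {
    AnalysisNwbfile: "analysis_file_abs_path",
    Nwbfile: "nwb_file_abs_path",
}


def resolve_nwb_table(table):
    """An explicit _nwb_table attribute wins; otherwise look for the foreign
    key in the definition; otherwise fetch_nwb is unsupported."""
    if hasattr(table, "_nwb_table"):
        resolved = table._nwb_table
    elif "-> AnalysisNwbfile" in table.definition:
        resolved = AnalysisNwbfile
    elif "-> Nwbfile" in table.definition:
        resolved = Nwbfile
    else:
        raise NotImplementedError(
            f"{type(table).__name__} does not have a "
            "(Analysis)Nwbfile foreign key or _nwb_table attribute."
        )
    return resolved, PATH_ATTRS[resolved]


class FakeResults:  # hypothetical table with an AnalysisNwbfile foreign key
    definition = "-> AnalysisNwbfile\n---\nresult: blob"


assert resolve_nwb_table(FakeResults()) == (
    AnalysisNwbfile,
    "analysis_file_abs_path",
)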

# -------------------------------- delete ---------------------------------

@property
@cached_property
def _delete_deps(self) -> list:
"""List of tables required for delete permission check.
Used to delay import of tables until needed, avoiding circular imports.
Each of these tables inherits SpyglassMixin.
"""
if not self._delete_dependencies:
from spyglass.common import LabMember, LabTeam, Session # noqa F401

self._delete_dependencies = [LabMember, LabTeam, Session]
self._session_pk = Session.primary_key[0]
self._member_pk = LabMember.primary_key[0]
return self._delete_dependencies
from spyglass.common import LabMember, LabTeam, Session # noqa F401

@property
def _test_mode(self) -> bool:
"""Return True if test mode is enabled."""
if not self._test_mode_cache:
from spyglass.settings import test_mode

self._test_mode_cache = test_mode
return self._test_mode_cache
self._session_pk = Session.primary_key[0]
self._member_pk = LabMember.primary_key[0]
return [LabMember, LabTeam, Session]

@property
@cached_property
def _merge_tables(self) -> Dict[str, dj.FreeTable]:
"""Dict of merge tables downstream of self.
Cache of items in parents of self.descendants(as_objects=True) that
have a merge primary key.
"""
if self._merge_table_cache:
return self._merge_table_cache

def has_merge_pk(table):
return MERGE_PK in table.heading.names

self.connection.dependencies.load()
merge_tables = {}
for desc in self.descendants(as_objects=True):
if not has_merge_pk(desc):
continue
if not (master_name := get_master(desc.full_table_name)):
if MERGE_PK not in desc.heading.names or not (
master_name := get_master(desc.full_table_name)
):
continue
master = dj.FreeTable(self.connection, master_name)
if has_merge_pk(master):
self._merge_table_cache[master_name] = master
if MERGE_PK in master.heading.names:
merge_tables[master_name] = master

logger.info(
f"Building merge cache for {self.table_name}.\n\t"
+ f"Found {len(self._merge_table_cache)} downstream merge tables"
+ f"Found {len(merge_tables)} downstream merge tables"
)

return self._merge_table_cache
return merge_tables
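The filter in that loop leans on two things: the reserved merge primary key (MERGE_PK, imported above as RESERVED_PRIMARY_KEY) appearing in a table's heading, and datajoint's get_master, a pure string helper that recovers a part table's master from its full name. A quick sketch of the string side, with made-up table names:

from datajoint.utils import get_master

part_table = "`my_schema`.`position_output__dlc_pose`"   # hypothetical part
plain_table = "`my_schema`.`position_output`"             # hypothetical master

print(get_master(part_table))   # `my_schema`.`position_output`
print(get_master(plain_table))  # '' -> falsy, so the loop skips it

# In _merge_tables, a descendant is kept only when MERGE_PK is in its heading
# and get_master() returns a non-empty master name; the master is then wrapped
# in dj.FreeTable and cached under its full name, provided it also carries
# MERGE_PK.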

@property
@cached_property
def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]:
"""Dict of merge links downstream of self.
Expand All @@ -197,14 +169,12 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]:
to recompute whenever delete_downstream_merge is called with a new
restriction.
"""
if self._merge_chains_cache:
return self._merge_chains_cache

merge_chains = {}
for name, merge_table in self._merge_tables.items():
chains = TableChains(self, merge_table, connection=self.connection)
if len(chains):
self._merge_chains_cache[name] = chains
return self._merge_chains_cache
merge_chains[name] = chains
return merge_chains

def _commit_merge_deletes(self, merge_join_dict, **kwargs):
"""Commit merge deletes.
@@ -247,8 +217,8 @@ def delete_downstream_merge(
Passed to datajoint.table.Table.delete.
"""
if reload_cache:
self._merge_table_cache = {}
self._merge_chains_cache = {}
del self._merge_tables
del self._merge_chains

restriction = restriction or self.restriction or True

@@ -315,18 +285,14 @@ def _get_exp_summary(self):

return exp_missing + exp_present

@property
@cached_property
def _session_connection(self) -> dj.expression.QueryExpression:
"""Path from Session table to self.
None if not yet cached, False if no connection found.
"""
if self._session_connection_cache is None:
connection = TableChain(parent=self._delete_deps[-1], child=self)
self._session_connection_cache = (
connection if connection.has_link else False
)
return self._session_connection_cache
connection = TableChain(parent=self._delete_deps[-1], child=self)
return connection if connection.has_link else False

def _check_delete_permission(self) -> None:
"""Check user name against lab team assoc. w/ self * Session.
@@ -382,14 +348,12 @@ def _check_delete_permission(self) -> None:
)
logger.info(f"Queueing delete for session(s):\n{sess_summary}")

@property
@cached_property
def _usage_table(self):
"""Temporary inclusion for usage tracking."""
if not self._usage_table_cache:
from spyglass.common.common_usage import CautiousDelete
from spyglass.common.common_usage import CautiousDelete

self._usage_table_cache = CautiousDelete
return self._usage_table_cache
return CautiousDelete

def _log_use(self, start, merge_deletes=None):
"""Log use of cautious_delete."""
@@ -403,9 +367,8 @@ def _log_use(self, start, merge_deletes=None):
)
)

# Rename to `delete` when we're ready to use it
# TODO: Intercept datajoint delete confirmation prompt for merge deletes
def cautious_delete(self, force_permission: bool = False, *args, **kwargs):
def delete(self, force_permission: bool = False, *args, **kwargs):
"""Delete table rows after checking user permission.
Permission is granted to users listed as admin in LabMember table or to
@@ -443,7 +406,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs):
count = sum([len(part) for part in content])
dj_logger.info(f"Merge: Deleting {count} rows from {table}")
if (
not self._test_mode
not test_mode
or not safemode
or user_choice("Commit deletes?", default="no") == "yes"
):
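Read as a standalone predicate, the short-circuit order of that condition means the user_choice prompt is only reached when both test_mode and safemode are truthy; otherwise the merge deletes commit directly. A small sketch with ask standing in for datajoint.utils.user_choice:

def should_commit(test_mode: bool, safemode: bool, ask=lambda: "no") -> bool:
    # Mirrors the condition above; `ask` stands in for
    # user_choice("Commit deletes?", default="no").
    return not test_mode or not safemode or ask() == "yes"


assert should_commit(test_mode=False, safemode=True)        # prompt skipped
assert should_commit(test_mode=True, safemode=False)        # safemode disabled
assert not should_commit(True, True, ask=lambda: "no")      # user declined
assert should_commit(True, True, ask=lambda: "yes")         # user confirmed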
@@ -509,9 +472,6 @@ def __init__(self, parent: Table, child: Table, connection=None):
self._link_symbol = " -> "
self.parent = parent
self.child = child
self._repr = None
self._names = None # full table names of tables in chain
self._objects = None # free tables in chain
self._has_link = child.full_table_name in parent.descendants()

def __str__(self):
@@ -527,15 +487,12 @@ def __str__(self):

def __repr__(self):
"""Return full representation of chain: parent -> {links} -> child."""
if self._repr:
return self._repr
self._repr = (
return (
"Chain: "
+ self._link_symbol.join([t.table_name for t in self.objects])
if self.names
else "No link"
)
return self._repr

def __len__(self):
"""Return number of tables in chain."""
@@ -550,37 +507,32 @@ def has_link(self) -> bool:
"""
return self._has_link

@property
@cached_property
def names(self) -> List[str]:
"""Return list of full table names in chain.
Uses networkx.shortest_path.
"""
if not self._has_link:
return None
if self._names:
return self._names
try:
self._names = nx.shortest_path(
return nx.shortest_path(
self.parent.connection.dependencies,
self.parent.full_table_name,
self.child.full_table_name,
)
return self._names
except nx.NetworkXNoPath:
self._has_link = False
return None
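The names property relies on the loaded connection.dependencies being a networkx directed graph keyed by full table names, so the chain is just the shortest directed path from parent to child. A toy graph (invented table names) shows both the success case and the NetworkXNoPath branch:

import networkx as nx

deps = nx.DiGraph()  # stand-in for connection.dependencies after .load()
deps.add_edges_from(
    [
        ("`common`.`session`", "`lfp`.`l_f_p_selection`"),
        ("`lfp`.`l_f_p_selection`", "`lfp`.`__l_f_p`"),
        ("`common`.`session`", "`behav`.`__position`"),  # unrelated branch
    ]
)

print(nx.shortest_path(deps, "`common`.`session`", "`lfp`.`__l_f_p`"))
# ['`common`.`session`', '`lfp`.`l_f_p_selection`', '`lfp`.`__l_f_p`']

try:  # edges are directed, so the reverse lookup has no path
    nx.shortest_path(deps, "`lfp`.`__l_f_p`", "`common`.`session`")
except nx.NetworkXNoPath:
    print("no path: has_link flips to False, as in the except branch above")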

@property
@cached_property
def objects(self) -> List[dj.FreeTable]:
"""Return list of FreeTable objects for each table in chain."""
if not self._objects:
self._objects = (
[dj.FreeTable(self._connection, name) for name in self.names]
if self.names
else None
)
return self._objects
return (
[dj.FreeTable(self._connection, name) for name in self.names]
if self.names
else None
)

def join(self, restriction: str = None) -> dj.expression.QueryExpression:
"""Return join of tables in chain with restriction applied to parent."""